archimedes.tree.register_pytree_node

archimedes.tree.register_pytree_node(ty: Any, to_iter: Callable, from_iter: Callable) -> None

Register a custom type as a pytree node.

This function lets Archimedes' pytree functions recognize and traverse instances of a custom type. You supply two functions: one that splits an instance into its children and auxiliary data, and one that rebuilds an instance from those parts.

Parameters:
  • ty (type) – The type to register as a pytree node.

  • to_iter (callable) –

    A function that accepts an instance of type ty and returns a tuple of (children, aux_data), where:

    • children is an iterable of the pytree node’s children

    • aux_data is any auxiliary metadata needed to reconstruct the node that should not be treated as a child: it is carried through tree operations rather than traversed or transformed by them

  • from_iter (callable) – A function that accepts aux_data and an iterable of children and returns a reconstructed instance of type ty.
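As a sketch of the to_iter/from_iter contract, the pair below splits a hypothetical LabeledVector class into one child (its numeric values) and static aux_data (its label string). The class and function names are illustrative only, not part of Archimedes, and the sketch runs without Archimedes installed.

```python
class LabeledVector:
    """Hypothetical container: `values` is data, `label` is static metadata."""

    def __init__(self, values, label):
        self.values = values
        self.label = label


def labeled_to_iter(v):
    # Children: the parts that tree operations should traverse and transform.
    children = (v.values,)
    # aux_data: static information needed only to rebuild the object.
    aux_data = v.label
    return children, aux_data


def labeled_from_iter(aux_data, children):
    # Unpack the children in the same order labeled_to_iter produced them.
    (values,) = children
    return LabeledVector(values, aux_data)


v = LabeledVector([1.0, 2.0], label="velocity")
children, aux = labeled_to_iter(v)
rebuilt = labeled_from_iter(aux, children)
print(children, aux)                   # ([1.0, 2.0],) velocity
print(rebuilt.values, rebuilt.label)   # [1.0, 2.0] velocity
```

Passing this pair to arc.tree.register_pytree_node(LabeledVector, labeled_to_iter, labeled_from_iter) would, per the description above, let pytree functions operate on values while label travels along as aux_data.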

Return type:

None

Notes

When to use:

  • When you have custom container types that should be traversed by pytree operations

  • To enable pytree transformations on your own data structures

  • When creating reusable components that need to be compatible with Archimedes’ pytree-based operations

The to_iter function should extract every part of your data structure needed to rebuild it, either as a child or in aux_data, so that from_iter(aux_data, children) reconstructs an equivalent instance.
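The exact-reconstruction requirement can be checked mechanically. The helper below, round_trips, is a hypothetical testing aid (not an Archimedes function) that flattens an object with to_iter, rebuilds it with from_iter, and compares instance attributes. The second pair deliberately drops a field to show the kind of bug such a check catches.

```python
def round_trips(obj, to_iter, from_iter):
    """Return True if from_iter rebuilds an object equal to `obj`.

    "Equal" here means same type and same instance attributes (vars()),
    a reasonable check for simple classes but not a universal one.
    """
    children, aux_data = to_iter(obj)
    rebuilt = from_iter(aux_data, children)
    return type(rebuilt) is type(obj) and vars(rebuilt) == vars(obj)


class Interval:
    def __init__(self, lo, hi, closed=True):
        self.lo = lo
        self.hi = hi
        self.closed = closed


def interval_to_iter(iv):
    # `closed` is static metadata, so it goes in aux_data.
    return (iv.lo, iv.hi), iv.closed


def interval_from_iter(closed, children):
    lo, hi = children
    return Interval(lo, hi, closed)


def lossy_to_iter(iv):
    # Deliberate bug: `closed` is dropped, so the default True comes back.
    return (iv.lo, iv.hi), None


def lossy_from_iter(aux_data, children):
    lo, hi = children
    return Interval(lo, hi)


iv = Interval(0.0, 1.0, closed=False)
print(round_trips(iv, interval_to_iter, interval_from_iter))  # True
print(round_trips(iv, lossy_to_iter, lossy_from_iter))        # False
```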

For most classes, the struct.pytree_node decorator is the more convenient choice: it handles registration automatically for dataclass-like structures, and it uses this function internally to register the decorated classes. Call register_pytree_node directly when you need low-level control over how a custom class is flattened and unflattened, or over which of its data is treated as static.
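To give a rough idea of the boilerplate a decorator can automate, the sketch below builds a to_iter/from_iter pair from lists of data and static field names. make_iter_fns and the Pendulum class are hypothetical; this is not how struct.pytree_node is implemented, only an illustration of the kind of pair that would be handed to register_pytree_node.

```python
from dataclasses import dataclass


def make_iter_fns(cls, data_fields, static_fields):
    """Build a (to_iter, from_iter) pair for `cls` from field names (hypothetical)."""

    def to_iter(obj):
        # Data fields become children; static fields are bundled into aux_data.
        children = tuple(getattr(obj, name) for name in data_fields)
        aux_data = tuple(getattr(obj, name) for name in static_fields)
        return children, aux_data

    def from_iter(aux_data, children):
        # Re-pair each value with its field name and call the constructor.
        kwargs = dict(zip(data_fields, children))
        kwargs.update(zip(static_fields, aux_data))
        return cls(**kwargs)

    return to_iter, from_iter


@dataclass
class Pendulum:
    length: float   # data: traversed by tree operations
    damping: float  # data
    name: str       # static: carried in aux_data


to_iter, from_iter = make_iter_fns(Pendulum, ["length", "damping"], ["name"])
p = Pendulum(1.0, 0.1, "simple")
children, aux_data = to_iter(p)
print(children, aux_data)                  # (1.0, 0.1) ('simple',)
print(from_iter(aux_data, children) == p)  # True
```

With Archimedes, a pair like this could then be registered with arc.tree.register_pytree_node(Pendulum, to_iter, from_iter), although for a dataclass such as this one the decorator is the more direct route.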

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Define a custom container class
>>> class Point3D:
...     def __init__(self, x, y, z):
...         self.x = x
...         self.y = y
...         self.z = z
...
...     def __repr__(self):
...         return f"Point3D({self.x}, {self.y}, {self.z})"
>>>
>>> # Define functions to convert to/from iterables
>>> def point_to_iter(point):
...     children = (point.x, point.y, point.z)
...     aux_data = None  # No static auxiliary data needed
...     return children, aux_data
>>>
>>> def point_from_iter(aux_data, children):
...     x, y, z = children
...     return Point3D(x, y, z)
>>>
>>> # Register the class as a pytree node
>>> arc.tree.register_pytree_node(Point3D, point_to_iter, point_from_iter)
>>>
>>> # Now Point3D works with pytree operations
>>> p = Point3D(np.array([1.0]), np.array([2.0]), np.array([3.0]))
>>> doubled = arc.tree.map(lambda x: x * 2, p)
>>> print(doubled)
Point3D([2.], [4.], [6.])
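Conceptually, a function like arc.tree.map uses the registered pair to take a node apart, recurse into its children, and rebuild the node. The toy registry below illustrates that idea in plain Python, reusing the Point3D functions from the example. _REGISTRY, toy_register, and toy_tree_map are hypothetical names, and Archimedes' actual implementation may differ (for example, in how it handles built-in containers, None, or subclasses).

```python
# Hypothetical registry mapping a type to its (to_iter, from_iter) pair.
_REGISTRY = {}


def toy_register(ty, to_iter, from_iter):
    _REGISTRY[ty] = (to_iter, from_iter)


def toy_tree_map(fn, tree):
    """Apply fn to every leaf, rebuilding registered nodes (toy version)."""
    # Exact-type lookup: subclasses of a registered type are not matched here.
    entry = _REGISTRY.get(type(tree))
    if entry is None:
        # Unregistered values are treated as leaves.
        return fn(tree)
    to_iter, from_iter = entry
    children, aux_data = to_iter(tree)
    new_children = [toy_tree_map(fn, child) for child in children]
    return from_iter(aux_data, new_children)


class Point3D:
    def __init__(self, x, y, z):
        self.x, self.y, self.z = x, y, z

    def __repr__(self):
        return f"Point3D({self.x}, {self.y}, {self.z})"


def point_to_iter(point):
    return (point.x, point.y, point.z), None


def point_from_iter(aux_data, children):
    x, y, z = children
    return Point3D(x, y, z)


toy_register(Point3D, point_to_iter, point_from_iter)
# Register tuples too, so nested structures are traversed.
toy_register(tuple, lambda t: (t, None), lambda aux, ch: tuple(ch))

nested = (Point3D(1.0, 2.0, 3.0), 4.0)
print(toy_tree_map(lambda v: v * 2, nested))  # (Point3D(2.0, 4.0, 6.0), 8.0)
```

Note that toy_tree_map passes the rebuilt children to from_iter as a list rather than the tuple that to_iter produced, which is why a from_iter function should accept any iterable of children.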

See also

struct.pytree_node

Decorator for creating pytree-compatible classes

register_dataclass

Register a dataclass as a pytree node