archimedes.tree.register_pytree_node
- archimedes.tree.register_pytree_node(ty: Any, to_iter: Callable, from_iter: Callable) → None
Register a custom type as a pytree node.
This function allows custom types to be recognized and processed by Archimedes’ pytree functions. You supply two functions: one that decomposes an instance of your type into its components, and one that rebuilds an instance from those components.
- Parameters:
ty (type) – The type to register as a pytree node.
to_iter (callable) –
A function that accepts an instance of type ty and returns a tuple of (children, aux_data), where:
- children is an iterable of the pytree node’s children
- aux_data is any auxiliary metadata needed to reconstruct the node but not part of the pytree structure itself
from_iter (callable) – A function that accepts aux_data and an iterable of children and returns a reconstructed instance of type ty.
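To make the calling convention concrete, here is a deliberately simplified, pure-Python model of how a pytree library might consume a registered pair of functions. The registry and the names toy_register, toy_tree_map and Pair are hypothetical and do not reflect how Archimedes is implemented; the sketch only illustrates the contract described above: to_iter splits a node into (children, aux_data), the transformed children are passed back as from_iter(aux_data, children), and unregistered objects are treated as leaves.

```python
from typing import Any, Callable

# Hypothetical registry for illustration only (not Archimedes' internals).
_REGISTRY: dict = {}


def toy_register(ty: type, to_iter: Callable, from_iter: Callable) -> None:
    """Record the flatten/unflatten pair for ``ty`` (toy version)."""
    _REGISTRY[ty] = (to_iter, from_iter)


def toy_tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply ``fn`` to every leaf, rebuilding registered nodes on the way out."""
    entry = _REGISTRY.get(type(tree))
    if entry is None:
        # Anything not registered is treated as a leaf.
        return fn(tree)
    to_iter, from_iter = entry
    children, aux_data = to_iter(tree)
    new_children = [toy_tree_map(fn, child) for child in children]
    # from_iter receives aux_data first, then the (transformed) children.
    return from_iter(aux_data, new_children)


class Pair:
    """Hypothetical two-field container used to exercise the toy registry."""

    def __init__(self, first, second):
        self.first = first
        self.second = second


toy_register(
    Pair,
    lambda p: ((p.first, p.second), None),  # to_iter: no static data
    lambda aux_data, children: Pair(*children),  # from_iter
)

nested = Pair(1.0, Pair(2.0, 3.0))
doubled = toy_tree_map(lambda x: 2 * x, nested)
print(doubled.first, doubled.second.first, doubled.second.second)  # 2.0 4.0 6.0
```

Because toy_tree_map recurses into each child, registered types can be nested inside one another. Full pytree libraries (JAX, for example) also recurse into built-in containers such as lists, tuples and dicts; the toy version omits that to keep the two-function contract in view.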
- Return type:
None
Notes
When to use:
- You have custom container types that pytree operations should traverse
- You want pytree transformations to work on your own data structures
- You are creating reusable components that must be compatible with Archimedes’ pytree-based operations
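The to_iter function should place every value that pytree operations are meant to act on in children and put any remaining information needed for reconstruction in aux_data. The from_iter function must then be able to rebuild an equivalent instance from those two pieces alone.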
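As an illustration of how data can be divided between children and aux_data, the sketch below uses a hypothetical Signal class: its numeric samples are the only child, while its name and sample rate travel in aux_data. It checks the round trip by hand, without calling Archimedes, so the two functions can be tested before they are registered. The class and function names are assumptions made for this example.

```python
class Signal:
    """Hypothetical container: numeric samples plus static metadata."""

    def __init__(self, values, name, sample_rate):
        self.values = values            # numeric data -> pytree child
        self.name = name                # static label -> aux_data
        self.sample_rate = sample_rate  # static setting -> aux_data


def signal_to_iter(sig):
    children = (sig.values,)
    aux_data = (sig.name, sig.sample_rate)
    return children, aux_data


def signal_from_iter(aux_data, children):
    name, sample_rate = aux_data
    (values,) = children  # works for any iterable holding exactly one child
    return Signal(values, name, sample_rate)


# Round trip by hand: flatten, then rebuild from (aux_data, children).
sig = Signal([0.0, 1.0, 0.5], "pressure", 100.0)
children, aux_data = signal_to_iter(sig)
rebuilt = signal_from_iter(aux_data, children)
print(rebuilt.values, rebuilt.name, rebuilt.sample_rate)  # [0.0, 1.0, 0.5] pressure 100.0

# With Archimedes installed, the same pair could then be registered:
# import archimedes as arc
# arc.tree.register_pytree_node(Signal, signal_to_iter, signal_from_iter)
```

Keeping name and sample_rate out of children means pytree operations such as a map over leaves leave them untouched and simply pass them back to signal_from_iter.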
In most cases you will want to use the struct.pytree_node decorator for classes instead of calling this function directly; the decorator handles registration automatically for dataclass-like structures and uses this function internally. Call register_pytree_node yourself when you need low-level control over how a custom class is flattened and unflattened, or over which of its data is treated as static.
Examples
>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Define a custom container class
>>> class Point3D:
...     def __init__(self, x, y, z):
...         self.x = x
...         self.y = y
...         self.z = z
...
...     def __repr__(self):
...         return f"Point3D({self.x}, {self.y}, {self.z})"
>>>
>>> # Define functions to convert to/from iterables
>>> def point_to_iter(point):
...     children = (point.x, point.y, point.z)
...     aux_data = None  # No static auxiliary data needed
...     return children, aux_data
>>>
>>> def point_from_iter(aux_data, children):
...     x, y, z = children
...     return Point3D(x, y, z)
>>>
>>> # Register the class as a pytree node
>>> arc.tree.register_pytree_node(Point3D, point_to_iter, point_from_iter)
>>>
>>> # Now Point3D works with pytree operations
>>> p = Point3D(np.array([1.0]), np.array([2.0]), np.array([3.0]))
>>> doubled = arc.tree.map(lambda x: x * 2, p)
>>> print(doubled)
Point3D([2.], [4.], [6.])
See also
struct.pytree_node
Decorator for creating pytree-compatible classes
register_dataclass
Register a dataclass as a pytree node