archimedes.tree.register_struct
archimedes.tree.register_struct(ty: Any, to_iter: Callable, from_iter: Callable) -> None
Register a custom type as a tree-compatible struct.
This function lets Archimedes' tree functions (for example, arc.tree.map) recognize and process instances of a custom type. You supply two functions: one that flattens an instance into its components and one that rebuilds an instance from those components.
- Parameters:
ty (type) – The type to register as a tree node.
to_iter (callable) –
A function that accepts an instance of type ty and returns a tuple of (children, aux_data), where:
children is an iterable of the tree node’s children
aux_data is any auxiliary metadata needed to reconstruct the node but not part of the tree structure itself
from_iter (callable) – A function that accepts aux_data and an iterable of children and returns a reconstructed instance of type ty.
- Return type:
None
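To make the split between children and aux_data concrete, here is a pure-Python sketch that runs without Archimedes installed. The Sensor class, its label field, and the toy_map helper are hypothetical. toy_map only mimics, one level deep, what a tree function is expected to do with a registered type (flatten, transform the children, pass aux_data through unchanged, rebuild); it is not Archimedes' implementation.

```python
from typing import Any, Callable, Iterable, Tuple


class Sensor:
    """Hypothetical node: gain and offset are data, label is static metadata."""

    def __init__(self, gain: float, offset: float, label: str):
        self.gain = gain
        self.offset = offset
        self.label = label

    def __repr__(self) -> str:
        return f"Sensor({self.gain}, {self.offset}, {self.label!r})"


def sensor_to_iter(s: Sensor) -> Tuple[Tuple[float, float], str]:
    # Children are the values tree functions should visit;
    # the label rides along as aux_data instead.
    children = (s.gain, s.offset)
    aux_data = s.label
    return children, aux_data


def sensor_from_iter(aux_data: str, children: Iterable[float]) -> Sensor:
    # Argument order matches the docs: aux_data first, then the children.
    gain, offset = children
    return Sensor(gain, offset, aux_data)


def toy_map(f: Callable[[Any], Any], node: Sensor) -> Sensor:
    # Hypothetical one-level stand-in for a tree map:
    # flatten, apply f to each child, rebuild with the untouched aux_data.
    children, aux_data = sensor_to_iter(node)
    return sensor_from_iter(aux_data, [f(c) for c in children])


s = Sensor(2.0, 0.5, "imu")
print(toy_map(lambda v: v * 10, s))  # Sensor(20.0, 5.0, 'imu')
```

The label is never passed to f, which is the reason to put it in aux_data. If Sensor were registered with arc.tree.register_struct(Sensor, sensor_to_iter, sensor_from_iter), tree functions would be expected to treat label the same way.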
Notes
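The to_iter function should extract every part of your data structure that tree functions need to visit or that reconstruction requires, and from_iter must invert it: calling from_iter(aux_data, children) on the outputs of to_iter should reproduce the original instance exactly.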
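One way to test the exact-reconstruction requirement before registering a type is a round trip: flatten, rebuild, compare. The check_round_trip helper below is hypothetical (it is not part of archimedes.tree) and compares attributes with vars(), which assumes a plain class whose state lives in __dict__. It re-declares the Point3D class and conversion functions from the Examples section, using floats instead of NumPy arrays so that == returns a plain bool, and adds a deliberately buggy from_iter to show what a failure looks like.

```python
from typing import Any, Callable


class Point3D:
    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z


def point_to_iter(point):
    return (point.x, point.y, point.z), None


def point_from_iter(aux_data, children):
    x, y, z = children
    return Point3D(x, y, z)


def buggy_from_iter(aux_data, children):
    # Deliberate bug: unpacks the children in the wrong order.
    z, y, x = children
    return Point3D(x, y, z)


def check_round_trip(obj: Any, to_iter: Callable, from_iter: Callable) -> bool:
    # Hypothetical helper: flatten, rebuild, then compare type and attributes.
    children, aux_data = to_iter(obj)
    rebuilt = from_iter(aux_data, children)
    return type(rebuilt) is type(obj) and vars(rebuilt) == vars(obj)


p = Point3D(1.0, 2.0, 3.0)
print(check_round_trip(p, point_to_iter, point_from_iter))  # True
print(check_round_trip(p, point_to_iter, buggy_from_iter))  # False
```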
In most cases you will want the @struct decorator instead of calling this function directly; it handles registration automatically for dataclass-like classes, and it uses register_struct internally to do so. Call register_struct yourself when you need low-level control over flattening/unflattening behavior or over which data a custom class treats as static.
Examples
>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Define a custom container class
>>> class Point3D:
...     def __init__(self, x, y, z):
...         self.x = x
...         self.y = y
...         self.z = z
...
...     def __repr__(self):
...         return f"Point3D({self.x}, {self.y}, {self.z})"
>>>
>>> # Define functions to convert to/from iterables
>>> def point_to_iter(point):
...     children = (point.x, point.y, point.z)
...     aux_data = None  # No static auxiliary data needed
...     return children, aux_data
>>>
>>> def point_from_iter(aux_data, children):
...     x, y, z = children
...     return Point3D(x, y, z)
>>>
>>> # Register the class as a struct
>>> arc.tree.register_struct(Point3D, point_to_iter, point_from_iter)
>>>
>>> # Now Point3D works with tree operations
>>> p = Point3D(np.array([1.0]), np.array([2.0]), np.array([3.0]))
>>> doubled = arc.tree.map(lambda x: x * 2, p)
>>> print(doubled)
Point3D([2.], [4.], [6.])
See also
struct
Decorator for creating tree-compatible dataclasses
register_dataclass
Register a dataclass as a struct