archimedes.tree.register_struct

archimedes.tree.register_struct(
    ty: Any,
    to_iter: Callable,
    from_iter: Callable,
) -> None

Register a custom type as a tree-compatible struct.

This function lets Archimedes’ tree functions (for example, arc.tree.map) recognize and traverse instances of a custom type. You provide two functions: one that splits an instance into its components, and one that rebuilds an instance from those components.

Parameters:
  • ty (type) – The type to register as a tree node.

  • to_iter (callable) –

    A function that accepts an instance of type ty and returns a tuple of (children, aux_data), where:

    • children is an iterable of the tree node’s children

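    • aux_data is any auxiliary metadata needed to reconstruct the node that is not itself part of the tree structure (for example, a label or a fixed configuration value); it is handed to from_iter rather than treated as a child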

  • from_iter (callable) – A function that accepts aux_data and an iterable of children, in that order (the reverse of the order returned by to_iter), and returns a reconstructed instance of type ty.
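When a type has fields that should not be treated as children, those belong in aux_data. The sketch below uses a hypothetical LabeledVector class (along with hypothetical labeled_to_iter and labeled_from_iter functions) to illustrate that split. It only defines and exercises the two conversion functions in plain Python; the registration call itself is shown as a comment.

```python
class LabeledVector:
    """Hypothetical container: `values` holds data, `label` is static metadata."""

    def __init__(self, values, label):
        self.values = values
        self.label = label


def labeled_to_iter(obj):
    # Only `values` is a child, so tree operations will visit it.
    children = (obj.values,)
    # The label is metadata: it is carried along but not treated as a child.
    aux_data = obj.label
    return children, aux_data


def labeled_from_iter(aux_data, children):
    # Note the argument order: aux_data first, then children.
    (values,) = children
    return LabeledVector(values, aux_data)


# With Archimedes installed, registration would be:
#   arc.tree.register_struct(LabeledVector, labeled_to_iter, labeled_from_iter)

v = LabeledVector([1.0, 2.0], label="velocity")
children, aux_data = labeled_to_iter(v)
rebuilt = labeled_from_iter(aux_data, children)
print(rebuilt.values, rebuilt.label)  # [1.0, 2.0] velocity
```

Keeping the label in aux_data means a function mapped over the tree would only ever see the numeric values, never the string.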

Return type:

None

Notes

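The to_iter function should extract every part of your data structure needed to rebuild it, and from_iter should invert it exactly: for any instance obj, if children, aux_data = to_iter(obj), then from_iter(aux_data, children) should return an equivalent instance.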
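One way to check that requirement is a round-trip test: flatten, rebuild, flatten again, and compare. The check_roundtrip helper below is hypothetical (it is not part of Archimedes), and the Point3D functions mirror the ones in the example further down.

```python
def check_roundtrip(obj, to_iter, from_iter, eq=lambda a, b: a == b):
    """Hypothetical helper: flatten, rebuild, flatten again, and compare.

    Returns True if the rebuilt object yields the same aux_data (compared
    with ==) and the same children (compared pairwise with `eq`).
    """
    children, aux_data = to_iter(obj)
    children = tuple(children)
    rebuilt = from_iter(aux_data, children)
    children2, aux_data2 = to_iter(rebuilt)
    children2 = tuple(children2)
    if aux_data != aux_data2 or len(children) != len(children2):
        return False
    return all(eq(a, b) for a, b in zip(children, children2))


class Point3D:
    def __init__(self, x, y, z):
        self.x, self.y, self.z = x, y, z


def point_to_iter(point):
    return (point.x, point.y, point.z), None


def point_from_iter(aux_data, children):
    x, y, z = children
    return Point3D(x, y, z)


def lossy_from_iter(aux_data, children):
    # Deliberately broken: drops z, so the round trip fails.
    x, y, _ = children
    return Point3D(x, y, 0.0)


print(check_roundtrip(Point3D(1.0, 2.0, 3.0), point_to_iter, point_from_iter))  # True
print(check_roundtrip(Point3D(1.0, 2.0, 3.0), point_to_iter, lossy_from_iter))  # False
```

For NumPy array children, pass eq=np.array_equal, since == on arrays returns an array rather than a single bool.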

Usually, instead of using this function directly, you’ll want to use the @struct decorator for classes, which automatically handles registration for dataclass-like structures. Internally, @struct uses this function to register the decorated classes; calling it directly gives lower-level control over flattening/unflattening behavior and over which data is treated as static.

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Define a custom container class
>>> class Point3D:
...     def __init__(self, x, y, z):
...         self.x = x
...         self.y = y
...         self.z = z
...
...     def __repr__(self):
...         return f"Point3D({self.x}, {self.y}, {self.z})"
>>>
>>> # Define functions to convert to/from iterables
>>> def point_to_iter(point):
...     children = (point.x, point.y, point.z)
...     aux_data = None  # No static auxiliary data needed
...     return children, aux_data
>>>
>>> def point_from_iter(aux_data, children):
...     x, y, z = children
...     return Point3D(x, y, z)
>>>
>>> # Register the class as a struct
>>> arc.tree.register_struct(Point3D, point_to_iter, point_from_iter)
>>>
>>> # Now Point3D works with tree operations
>>> p = Point3D(np.array([1.0]), np.array([2.0]), np.array([3.0]))
>>> doubled = arc.tree.map(lambda x: x * 2, p)
>>> print(doubled)
Point3D([2.], [4.], [6.])
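To make the role of the two functions concrete, here is a minimal sketch of how a registry of (to_iter, from_iter) pairs can drive a tree map, assuming the usual flatten, recurse, rebuild scheme. The names toy_register_struct and toy_tree_map are hypothetical, and this is not Archimedes’ internal implementation; in particular, this toy version treats every unregistered value (including lists, tuples, and dicts) as a leaf.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry mapping a type to its (to_iter, from_iter) pair.
# Illustration only, not Archimedes' internal implementation.
_REGISTRY: Dict[type, Tuple[Callable, Callable]] = {}


def toy_register_struct(ty, to_iter, from_iter):
    _REGISTRY[ty] = (to_iter, from_iter)


def toy_tree_map(f, node):
    """Apply `f` to every leaf, rebuilding registered nodes on the way back up."""
    entry = _REGISTRY.get(type(node))
    if entry is None:
        # Anything unregistered is treated as a leaf in this toy version.
        return f(node)
    to_iter, from_iter = entry
    children, aux_data = to_iter(node)
    new_children = tuple(toy_tree_map(f, child) for child in children)
    return from_iter(aux_data, new_children)


class Point3D:
    def __init__(self, x, y, z):
        self.x, self.y, self.z = x, y, z

    def __repr__(self):
        return f"Point3D({self.x}, {self.y}, {self.z})"


toy_register_struct(
    Point3D,
    lambda p: ((p.x, p.y, p.z), None),
    lambda aux_data, children: Point3D(*children),
)

# Nested points are handled by the recursion.
nested = Point3D(Point3D(1.0, 2.0, 3.0), 4.0, 5.0)
print(toy_tree_map(lambda x: x * 2, nested))
# Point3D(Point3D(2.0, 4.0, 6.0), 8.0, 10.0)
```

Because from_iter receives the already-transformed children, the rebuilt object has the same structure as the input, with only its leaves changed.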

See also

struct

Decorator for creating tree-compatible dataclasses

register_dataclass

Register a dataclass as a struct