archimedes.tree.register_struct

archimedes.tree.register_struct(
ty: Any,
to_iter: Callable,
from_iter: Callable,
) → None

Register a custom type as a tree-compatible node.

This function allows custom types to be recognized and processed by the tree functions. You provide functions that convert between your type and its components.

Parameters:
  • ty (type) – The type to register as a tree node.

  • to_iter (callable) –

    A function that accepts an instance of type ty and returns a tuple of (children, aux_data), where:

    • children is an iterable of the tree node’s children

    • aux_data is any auxiliary metadata needed to reconstruct the node but not part of the tree structure itself

  • from_iter (callable) – A function that accepts aux_data and an iterable of children and returns a reconstructed instance of type ty.

Return type:

None

Notes

The to_iter function should extract the relevant parts of your data structure, and the from_iter function should be able to reconstruct it exactly.
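As a minimal sketch of the contract described above, the following registers a hand-written class. The `Point3D` class, its `label` field, and the helper names `point_to_iter` and `point_from_iter` are hypothetical and exist only for illustration. The registration call is guarded so the round-trip check at the end still runs where Archimedes is not installed.

```python
class Point3D:
    """Hypothetical container used only for illustration."""

    def __init__(self, x, y, z, label="origin"):
        self.x, self.y, self.z = x, y, z
        self.label = label  # static metadata, not a tree child

    def __eq__(self, other):
        return (
            isinstance(other, Point3D)
            and (self.x, self.y, self.z, self.label)
            == (other.x, other.y, other.z, other.label)
        )


def point_to_iter(point):
    # Children are the values that tree functions should traverse.
    children = (point.x, point.y, point.z)
    # aux_data carries what is needed to rebuild the node but is not a child.
    aux_data = point.label
    return children, aux_data


def point_from_iter(aux_data, children):
    # Note the argument order: aux_data first, then the children.
    x, y, z = children
    return Point3D(x, y, z, label=aux_data)


try:
    import archimedes

    archimedes.tree.register_struct(Point3D, point_to_iter, point_from_iter)
except ImportError:
    # Archimedes is not installed; the manual round trip below still works.
    pass

# Round-trip check: from_iter should rebuild exactly what to_iter took apart.
p = Point3D(1.0, 2.0, 3.0, label="sensor")
children, aux_data = point_to_iter(p)
restored = point_from_iter(aux_data, children)
assert restored == p
```

Keeping `label` in `aux_data` rather than in `children` is one possible design choice: it treats the label as static data that is restored on reconstruction but not traversed as part of the tree.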

Usually, instead of using this function directly, you’ll want to use the @struct decorator for classes, which automatically handles registration for dataclass-like structures. This function is used internally to register the decorated classes. It is also available as an alternative interface for low-level control of flattening/unflattening behavior and static data for custom classes.

The (children, aux_data) convention mirrors JAX’s pytree node registration (jax.tree_util.register_pytree_node), in which the flatten function returns (children, aux_data) and the unflatten function accepts (aux_data, children).

See also

struct

Decorator for creating tree-compatible dataclasses

register_dataclass

Register a dataclass as a struct