archimedes.tree.structure

archimedes.tree.structure(tree: Tree, is_leaf: Callable[[Any], bool] | None = None) → TreeDef

Extract the structure of a tree without the leaf values.

This function returns a TreeDef that describes the structure of the tree. The TreeDef can later be passed to unflatten(), together with a list of new leaf values, to reconstruct a tree with the same structure.
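To make the split between structure and leaves concrete, here is a small pure-Python toy that handles nested dicts, lists, and tuples. It is an illustration under our own assumptions, not how Archimedes implements structure(): the names `toy_structure`, `toy_leaves`, and `toy_unflatten` are hypothetical, a nested-tuple "skeleton" stands in for a real TreeDef, and lists are treated as containers here rather than as array-like leaves.

```python
from typing import Any, List

# Hypothetical toy, NOT the Archimedes implementation: a "treedef" here is a
# nested tuple skeleton in which the string "*" marks each leaf position.
LEAF = "*"


def toy_structure(tree: Any) -> Any:
    """Return the container skeleton of `tree` with every leaf replaced by LEAF."""
    if isinstance(tree, dict):
        # Keys are sorted so the skeleton does not depend on insertion order
        # (a choice made for this toy only).
        return ("dict", tuple((k, toy_structure(tree[k])) for k in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(toy_structure(x) for x in tree))
    return LEAF  # anything that is not a dict/list/tuple counts as a leaf


def toy_leaves(tree: Any) -> List[Any]:
    """Return the leaves of `tree` in the same order toy_structure visits them."""
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in toy_leaves(tree[k])]
    if isinstance(tree, (list, tuple)):
        return [leaf for x in tree for leaf in toy_leaves(x)]
    return [tree]


def toy_unflatten(treedef: Any, leaves: List[Any]) -> Any:
    """Rebuild a tree from a skeleton and a flat list of leaves."""
    it = iter(leaves)
    missing = object()

    def build(node: Any) -> Any:
        if node == LEAF:
            leaf = next(it, missing)
            if leaf is missing:
                raise ValueError("too few leaves for this structure")
            return leaf
        kind, children = node
        if kind == "dict":
            return {k: build(child) for k, child in children}
        items = [build(child) for child in children]
        return tuple(items) if kind == "tuple" else items

    tree = build(treedef)
    if next(it, missing) is not missing:
        raise ValueError("too many leaves for this structure")
    return tree


# Usage: keep the skeleton, swap in new leaf values.
state = {"pos": [0.0, 1.0], "vel": (2.0, 3.0)}
treedef = toy_structure(state)
zeros = [0.0 for _ in toy_leaves(state)]
print(toy_unflatten(treedef, zeros))  # {'pos': [0.0, 0.0], 'vel': (0.0, 0.0)}
```

The key invariant in this sketch is that leaf extraction and reconstruction walk the tree in the same order, which is what lets a flat list of leaves be paired back up with the skeleton.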

Parameters:
  • tree (Tree) – A tree whose structure is to be determined.

  • is_leaf (callable, optional) – A function that takes a tree node and returns True if that node should be treated as a leaf rather than traversed further. If not provided, the default leaf types (arrays and scalars) are used.
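The effect of `is_leaf` can be illustrated with the same kind of toy. In this hypothetical pure-Python sketch (not Archimedes code; `toy_structure` and the `"*"` leaf marker are our own names), the callback is consulted before a node's children are visited, so returning True stops traversal at that node and the whole node becomes a single leaf. We assume this matches the intent of the parameter as described above.

```python
from typing import Any, Callable, Optional

LEAF = "*"  # marks a leaf position in the toy skeleton


def toy_structure(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
    """Hypothetical toy: skeleton of nested dicts/lists/tuples, honoring is_leaf."""
    # The callback gets the first say: a node it accepts is never traversed.
    if is_leaf is not None and is_leaf(tree):
        return LEAF
    if isinstance(tree, dict):
        return ("dict", tuple((k, toy_structure(tree[k], is_leaf)) for k in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(toy_structure(x, is_leaf) for x in tree))
    return LEAF  # default: non-containers are leaves


params = {"gain": 2.0, "limits": (-1.0, 1.0)}

# Default: the tuple is a container, so it contributes two leaf slots.
print(toy_structure(params))
# ('dict', (('gain', '*'), ('limits', ('tuple', ('*', '*')))))

# With is_leaf, each tuple is kept whole as a single leaf.
print(toy_structure(params, is_leaf=lambda x: isinstance(x, tuple)))
# ('dict', (('gain', '*'), ('limits', '*')))
```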

Returns:

treedef – A tree definition that describes the structure of the input tree.

Return type:

TreeDef

Notes

When to use:

  • When you need to extract just the structure of a tree for later use

  • When you want to create a template structure that can be filled with different leaf values

  • When you need to compare the structures of two trees
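For the comparison use case, the question is whether two trees have the same containers and keys, regardless of their leaf values. The hypothetical pure-Python sketch below (again a toy, with `toy_structure` and `same_structure` as illustrative names) compares skeletons with `==`. Whether Archimedes TreeDef objects support `==` in the same way, and whether they treat dict key order the same way this toy does, are assumptions to verify against the library before relying on them.

```python
from typing import Any

LEAF = "*"  # marks a leaf position in the toy skeleton


def toy_structure(tree: Any) -> Any:
    """Hypothetical toy: skeleton of nested dicts/lists/tuples with leaves erased."""
    if isinstance(tree, dict):
        # Sorted keys make this toy insensitive to dict insertion order.
        return ("dict", tuple((k, toy_structure(tree[k])) for k in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(toy_structure(x) for x in tree))
    return LEAF


def same_structure(a: Any, b: Any) -> bool:
    """True if a and b differ at most in their leaf values."""
    return toy_structure(a) == toy_structure(b)


x = {"pos": [0.0, 1.0], "vel": [2.0, 3.0]}
y = {"vel": [9.0, 9.0], "pos": [5.0, 5.0]}  # same keys and lengths, new values
z = {"pos": [0.0, 1.0], "acc": [0.0, 0.0]}  # a different key

print(same_structure(x, y))  # True
print(same_structure(x, z))  # False
```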

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Create a structured state
>>> state = {"pos": np.array([0.0, 1.0]), "vel": np.array([2.0, 3.0])}
>>>
>>> # Extract the structure
>>> treedef = arc.tree.structure(state)
>>> print(treedef)
TreeDef({'pos': *, 'vel': *})
>>>
>>> # Create a new state with the same structure but different values
>>> zeros = [np.zeros_like(leaf) for leaf in arc.tree.leaves(state)]
>>> initial_state = arc.tree.unflatten(treedef, zeros)
>>> print(initial_state)
{'pos': array([0., 0.]), 'vel': array([0., 0.])}

See also

  • flatten – Flatten a tree into a list of leaves and a treedef.

  • unflatten – Reconstruct a tree from leaves and a treedef.

  • ravel – Flatten a tree into a single 1D array.