archimedes.tree.structure

archimedes.tree.structure(tree: PyTree, is_leaf: Callable[[Any], bool] | None = None) → PyTreeDef

Extract the structure of a pytree without the leaf values.

This function returns a PyTreeDef that describes the structure of the pytree. Passing that treedef to unflatten() together with a list of new leaf values reconstructs a pytree with the same structure.
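The structure/unflatten round trip can be illustrated with a small pure-Python sketch. The names toy_structure, toy_leaves and toy_unflatten are hypothetical, and this is not how archimedes implements PyTreeDef. The sketch assumes that dicts, lists and tuples are the only containers, that everything else is a leaf, and that dict keys are visited in sorted order.

```python
from typing import Any, Iterator, List

# Hypothetical marker for a leaf slot (plays the role of "*" in a printed PyTreeDef).
LEAF = "*"


def toy_structure(tree: Any) -> Any:
    """Describe the container layout of `tree` as a hashable nested tuple."""
    if isinstance(tree, dict):
        # Assumption: keys are visited in sorted order, independent of insertion order.
        return ("dict", tuple((k, toy_structure(tree[k])) for k in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(toy_structure(c) for c in tree))
    return LEAF  # anything that is not a known container is treated as a leaf


def toy_leaves(tree: Any) -> List[Any]:
    """Collect the leaves in the same order that toy_structure visits them."""
    if isinstance(tree, dict):
        return [x for k in sorted(tree) for x in toy_leaves(tree[k])]
    if isinstance(tree, (list, tuple)):
        return [x for c in tree for x in toy_leaves(c)]
    return [tree]


def toy_unflatten(treedef: Any, leaves: List[Any]) -> Any:
    """Rebuild a tree with the layout `treedef`, filling leaf slots in order."""
    it: Iterator[Any] = iter(leaves)
    missing = object()  # sentinel so that None is still a valid leaf value

    def build(node: Any) -> Any:
        if node == LEAF:
            leaf = next(it, missing)
            if leaf is missing:
                raise ValueError("not enough leaves for this structure")
            return leaf
        kind, children = node
        if kind == "dict":
            return {k: build(c) for k, c in children}
        items = [build(c) for c in children]
        return items if kind == "list" else tuple(items)

    tree = build(treedef)
    if next(it, missing) is not missing:
        raise ValueError("too many leaves for this structure")
    return tree


# Usage: record the layout once, then refill it with new values.
state = {"vel": (2.0, 3.0), "pos": (0.0, 1.0)}
treedef = toy_structure(state)
print(toy_leaves(state))  # [0.0, 1.0, 2.0, 3.0]
print(toy_unflatten(treedef, [10.0, 11.0, 20.0, 21.0]))
# {'pos': (10.0, 11.0), 'vel': (20.0, 21.0)}
```

Keeping the layout separate from the leaves is what allows one structure to be refilled with zeros, perturbed values, or any other list of the right length; the length checks in toy_unflatten catch a mismatched leaf list early.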

Parameters:
  • tree (PyTree) – A pytree whose structure is to be determined.

  • is_leaf (callable, optional) – A function that takes a pytree node as input and returns a boolean indicating whether it should be considered a leaf. If not provided, the default leaf types (arrays and scalars) are used.

Returns:

treedef – A tree definition that describes the structure of the input pytree.

Return type:

PyTreeDef

Notes

When to use:

  • When you need to extract just the structure of a pytree for later use

  • When you want to create a template structure that can be filled with different leaf values

  • When you need to compare the structures of two pytrees
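Comparing the structures of two pytrees amounts to an equality check on their structure objects. The sketch below uses a hypothetical pure-Python toy_structure (assuming dicts, lists and tuples are the only containers and that dict keys are visited in sorted order) to show which differences such a comparison detects, and how an is_leaf predicate changes the result. It does not call archimedes. Before relying on == for real PyTreeDef objects, confirm that the library supports it.

```python
from typing import Any, Callable, Optional

LEAF = "*"  # hypothetical marker for a leaf slot


def toy_structure(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
    """Hashable layout of `tree`; a toy stand-in, not archimedes' PyTreeDef."""
    if is_leaf is not None and is_leaf(tree):
        return LEAF  # the predicate stops recursion: the whole node becomes one leaf
    if isinstance(tree, dict):
        # Assumption: keys are visited in sorted order.
        return ("dict", tuple((k, toy_structure(tree[k], is_leaf)) for k in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(toy_structure(c, is_leaf) for c in tree))
    return LEAF


def tuples_are_leaves(node: Any) -> bool:
    """Hypothetical is_leaf predicate: treat every tuple as a single leaf."""
    return isinstance(node, tuple)


a = {"pos": (0.0, 1.0), "vel": (2.0, 3.0)}
b = {"vel": (9.0, 9.0), "pos": (5.0, 5.0)}  # same layout, different values and key order
c = {"pos": (0.0, 1.0, 2.0), "vel": (2.0, 3.0)}  # "pos" has one extra element

print(toy_structure(a) == toy_structure(b))  # True: leaf values are ignored
print(toy_structure(a) == toy_structure(c))  # False: tuple lengths differ
# With tuples treated as leaves, the length difference is no longer part of the structure.
print(toy_structure(a, tuples_are_leaves) == toy_structure(c, tuples_are_leaves))  # True
```

The last comparison shows that the choice of is_leaf determines what "same structure" means: anything the predicate marks as a leaf is compared only by position, never by its contents.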

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Create a structured state
>>> state = {"pos": np.array([0.0, 1.0]), "vel": np.array([2.0, 3.0])}
>>>
>>> # Extract the structure
>>> treedef = arc.tree.structure(state)
>>> print(treedef)
PyTreeDef({'pos': *, 'vel': *})
>>>
>>> # Create a new state with the same structure but different values
>>> zeros = [np.zeros_like(leaf) for leaf in arc.tree.leaves(state)]
>>> initial_state = arc.tree.unflatten(treedef, zeros)
>>> print(initial_state)
{'pos': array([0., 0.]), 'vel': array([0., 0.])}

See also

flatten

Flatten a pytree into a list of leaves and a treedef

unflatten

Reconstruct a pytree from leaves and a treedef

ravel

Flatten a pytree into a single 1D array