archimedes.tree.structure
- archimedes.tree.structure(tree: PyTree, is_leaf: Callable[[Any], bool] | None = None) → PyTreeDef
Extract the structure of a pytree without the leaf values.
This function returns a PyTreeDef that describes the structure of the pytree. Pass it to unflatten() together with a sequence of leaves to reconstruct a pytree with new leaf values.
- Parameters:
tree (PyTree) – A pytree whose structure is to be determined.
is_leaf (callable, optional) – A function that takes a pytree node and returns a boolean indicating whether that node should be treated as a leaf instead of being traversed further. If not provided, the default leaf types (arrays and scalars) are used.
- Returns:
treedef – A tree definition that describes the structure of the input pytree.
- Return type:
PyTreeDef
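To make the structure / leaves / unflatten round trip concrete, here is a toy, pure-Python sketch. It is not archimedes' implementation: the helpers toy_structure, toy_leaves, and toy_unflatten are hypothetical, handle only dict, list, and tuple containers, and represent the treedef as the same nested containers with "*" in each leaf slot rather than as a PyTreeDef object. Dict keys are sorted to fix the leaf order, as JAX pytrees do; this page does not state whether archimedes orders dict leaves the same way.

```python
from typing import Any, Callable, Optional

LEAF = "*"  # placeholder for a leaf position in the toy treedef


def toy_structure(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
    """Return the container skeleton of `tree`, with every leaf replaced by LEAF."""
    if is_leaf is not None and is_leaf(tree):
        return LEAF  # caller asked to treat this whole node as one leaf
    if isinstance(tree, dict):
        # Sorted keys give a deterministic leaf order
        return {k: toy_structure(tree[k], is_leaf) for k in sorted(tree)}
    if isinstance(tree, (list, tuple)):
        return type(tree)(toy_structure(x, is_leaf) for x in tree)
    return LEAF  # anything that is not a dict/list/tuple is a leaf


def toy_leaves(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> list:
    """Collect leaves in the same depth-first order that toy_structure uses."""
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in toy_leaves(tree[k], is_leaf)]
    if isinstance(tree, (list, tuple)):
        return [leaf for x in tree for leaf in toy_leaves(x, is_leaf)]
    return [tree]


def toy_unflatten(treedef: Any, leaves: list) -> Any:
    """Fill the LEAF slots of `treedef` with `leaves`, in order."""
    if len(leaves) != len(toy_leaves(treedef)):
        raise ValueError("number of leaves does not match the structure")
    it = iter(leaves)

    def build(node: Any) -> Any:
        if isinstance(node, dict):
            return {k: build(node[k]) for k in sorted(node)}
        if isinstance(node, (list, tuple)):
            return type(node)(build(x) for x in node)
        return next(it)  # a LEAF slot: take the next value

    return build(treedef)


state = {"vel": [2.0, 3.0], "pos": (0.0, 1.0)}
treedef = toy_structure(state)
print(treedef)  # {'pos': ('*', '*'), 'vel': ['*', '*']}
zeros = [0.0 for _ in toy_leaves(state)]
print(toy_unflatten(treedef, zeros))  # {'pos': (0.0, 0.0), 'vel': [0.0, 0.0]}
# is_leaf: keep the "vel" list intact as a single leaf
print(toy_structure(state, is_leaf=lambda n: isinstance(n, list)))  # {'pos': ('*', '*'), 'vel': '*'}
```

The last line shows what is_leaf changes: the list under "vel" is no longer traversed, so it occupies a single leaf slot in the structure.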
Notes
When to use:
- When you need to extract just the structure of a pytree for later use
- When you want to create a template structure that can be filled with different leaf values
- When you need to compare the structures of two pytrees
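Comparing two pytrees' structures means comparing their treedefs while ignoring the leaf values. This page does not say whether archimedes' PyTreeDef supports == (JAX's PyTreeDef does). The sketch below therefore uses a hypothetical pure-Python helper, toy_treedef, that builds a hashable skeleton of a dict/list/tuple tree. It is an illustration of the idea, not the library's comparison logic.

```python
from typing import Any


def toy_treedef(tree: Any) -> Any:
    """Hashable skeleton of a dict/list/tuple tree; every leaf becomes None."""
    if isinstance(tree, dict):
        # Record sorted keys so {"a": .., "b": ..} and {"b": .., "a": ..} compare equal
        return ("dict", tuple((k, toy_treedef(tree[k])) for k in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        # The container type is part of the structure: [x, y] differs from (x, y)
        return (type(tree).__name__, tuple(toy_treedef(x) for x in tree))
    return None  # a leaf: its value is deliberately ignored


a = {"pos": [0.0, 1.0], "vel": [2.0, 3.0]}
b = {"vel": [7.0, 8.0], "pos": [5.0, 6.0]}  # same layout, different values
c = {"pos": [0.0, 1.0], "vel": (2.0, 3.0)}  # tuple instead of list

print(toy_treedef(a) == toy_treedef(b))  # True
print(toy_treedef(a) == toy_treedef(c))  # False
```

Because the skeleton is made of tuples, it is also hashable, so structures can be used as dictionary keys, for example to cache work per structure.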
Examples
>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Create a structured state
>>> state = {"pos": np.array([0.0, 1.0]), "vel": np.array([2.0, 3.0])}
>>>
>>> # Extract the structure
>>> treedef = arc.tree.structure(state)
>>> print(treedef)
PyTreeDef({'pos': *, 'vel': *})
>>>
>>> # Create a new state with the same structure but different values
>>> zeros = [np.zeros_like(leaf) for leaf in arc.tree.leaves(state)]
>>> initial_state = arc.tree.unflatten(treedef, zeros)
>>> print(initial_state)
{'pos': array([0., 0.]), 'vel': array([0., 0.])}