archimedes.tree.leaves
- archimedes.tree.leaves(tree: PyTree, is_leaf: Callable[[Any], bool] | None = None) -> list[ArrayLike]
Extract all leaf values from a pytree.
This function traverses the pytree and returns a list of all leaf values without the structure information.
- Parameters:
tree (PyTree) – A pytree from which to extract leaves. A pytree is a nested structure of containers (lists, tuples, dicts, etc.) and leaves (arrays, scalars, objects not registered as pytrees).
is_leaf (callable, optional) – A function that takes a pytree node as input and returns a boolean indicating whether it should be considered a leaf. If not provided, the default leaf types (arrays and scalars) are used.
- Returns:
leaves – A list of all leaf values from the pytree.
- Return type:
list
Examples
>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Create a structured data object
>>> data = {"params": {"w": np.array([1.0, 2.0]), "b": 0.5},
...         "state": np.array([3.0, 4.0])}
>>>
>>> # Extract all leaf values
>>> leaf_values = arc.tree.leaves(data)
>>> print(leaf_values)
[0.5, array([1., 2.]), array([3., 4.])]
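The traversal and the role of `is_leaf` described above can be illustrated with a minimal pure-Python sketch. This is not the Archimedes implementation: `sketch_leaves` is a hypothetical helper written for illustration, it handles only dicts, lists, and tuples as containers, and it assumes dict keys are visited in sorted order (which is consistent with the example output listing `b` before `w`).

```python
from typing import Any, Callable, Optional


def sketch_leaves(tree: Any,
                  is_leaf: Optional[Callable[[Any], bool]] = None) -> list:
    """Hypothetical stand-in for a leaf-extraction traversal (not library code)."""
    # A user-supplied predicate can stop the traversal at any node.
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if isinstance(tree, dict):
        # Assumption: dict children are visited in sorted-key order.
        children = [tree[key] for key in sorted(tree)]
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    else:
        # Anything that is not a known container is treated as a leaf.
        return [tree]
    result = []
    for child in children:
        result.extend(sketch_leaves(child, is_leaf))
    return result


# Plain scalars stand in for arrays so the sketch needs no dependencies.
data = {"params": {"w": 1.5, "b": 0.5}, "state": [3.0, 4.0]}

# Default behaviour: descend through every container.
default_leaves = sketch_leaves(data)

# Custom predicate: treat any dict containing "w" as a single leaf.
grouped = sketch_leaves(
    data, is_leaf=lambda node: isinstance(node, dict) and "w" in node
)
```

With the default behaviour every container is flattened, while the custom predicate returns the `params` dict intact as one leaf. The same idea applies when a nested container should be handed to downstream code as a unit rather than as separate values.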