archimedes.tree.leaves

archimedes.tree.leaves(tree: PyTree, is_leaf: Callable[[Any], bool] | None = None) -> list[ArrayLike]

Extract all leaf values from a pytree.

This function traverses the pytree and returns a list of all leaf values without the structure information.

Parameters:
  • tree (PyTree) – A pytree from which to extract leaves. A pytree is a nested structure of containers (lists, tuples, dicts, etc.) and leaves (arrays, scalars, and objects not registered as pytrees).

  • is_leaf (callable, optional) – A function that takes a pytree node as input and returns a boolean indicating whether it should be considered a leaf. If not provided, the default leaf types (arrays and scalars) are used.

Returns:

leaves – A list of all leaf values from the pytree.

Return type:

list

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Create a structured data object
>>> data = {"params": {"w": np.array([1.0, 2.0]), "b": 0.5},
...         "state": np.array([3.0, 4.0])}
>>>
>>> # Extract all leaf values
>>> leaf_values = arc.tree.leaves(data)
>>> print(leaf_values)
[0.5, array([1., 2.]), array([3., 4.])]
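
The traversal and the role of is_leaf can be illustrated with a small pure-Python sketch. This is not the archimedes implementation: sketch_leaves is a hypothetical helper that handles only dicts, lists, and tuples, and it assumes dict entries are visited in sorted-key order, which is consistent with the printed output above ("b" before "w").

```python
from typing import Any, Callable, Optional


def sketch_leaves(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> list:
    """Hypothetical stand-in for leaf extraction (not archimedes code)."""
    # A user-supplied predicate can stop the traversal at any node.
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if isinstance(tree, dict):
        # Assumption: dict entries are visited in sorted-key order.
        children = [tree[key] for key in sorted(tree)]
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    else:
        # Anything that is not a known container is treated as a leaf.
        return [tree]
    out = []
    for child in children:
        out.extend(sketch_leaves(child, is_leaf))
    return out


data = {"params": {"w": [1.0, 2.0], "b": 0.5}, "state": (3.0, 4.0)}

# Default behavior: recurse through every container down to the scalars.
flat = sketch_leaves(data)
print(flat)  # [0.5, 1.0, 2.0, 3.0, 4.0]

# Custom predicate: treat each list or tuple as a single opaque leaf.
grouped = sketch_leaves(data, is_leaf=lambda node: isinstance(node, (list, tuple)))
print(grouped)  # [0.5, [1.0, 2.0], (3.0, 4.0)]
```

In this sketch the predicate is checked before a node is unpacked, so a node that is_leaf accepts is returned whole even though it would otherwise be traversed as a container.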

See also

flatten

Flatten a pytree into a list of leaves and a treedef

ravel

Flatten a pytree into a single 1D array

map

Apply a function to each leaf in a pytree