archimedes.tree.map

archimedes.tree.map(
f: Callable,
tree: T,
*rest: tuple[T, ...],
is_leaf: Callable[[Any], bool] | None = None,
) -> T

Apply a function to each leaf in a tree.

Traverses the tree and applies the function f to each leaf, returning a new tree with the same structure but transformed leaf values. If additional trees are provided, the function is applied to corresponding leaves from all trees.
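The traversal described above can be pictured with a small pure-Python stand-in. This is only a sketch, not the library's implementation: `map_sketch` is a hypothetical helper that assumes trees are built from plain dicts, lists, and tuples, and it does no structure checking.

```python
from typing import Any, Callable, Optional


def map_sketch(f: Callable, tree: Any, *rest: Any,
               is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
    """Hypothetical stand-in for a tree map; handles dict, list, and tuple only."""
    # A node is a leaf if is_leaf says so, or if it is not a known container.
    forced_leaf = is_leaf is not None and is_leaf(tree)
    if forced_leaf or not isinstance(tree, (dict, list, tuple)):
        return f(tree, *rest)
    if isinstance(tree, dict):
        # Recurse key by key, picking the matching entry from every extra tree.
        return {k: map_sketch(f, v, *(r[k] for r in rest), is_leaf=is_leaf)
                for k, v in tree.items()}
    # Lists and tuples: recurse position by position, keeping the container type.
    items = [map_sketch(f, v, *(r[i] for r in rest), is_leaf=is_leaf)
             for i, v in enumerate(tree)]
    return type(tree)(items)


state = {"pos": [1.0, 2.0], "vel": [3.0, 4.0]}
delta = {"pos": [0.5, 0.5], "vel": [1.0, 1.0]}

# Two trees: f receives one leaf from each tree.
summed = map_sketch(lambda x, dx: x + dx, state, delta)

# With is_leaf, each list is handed to f whole instead of element by element.
totals = map_sketch(sum, state, is_leaf=lambda node: isinstance(node, list))
```

Here `summed` keeps the dict-of-lists shape of `state`, while `totals` has one number per key because `is_leaf` stops the traversal at each list.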

Parameters:
  • f (callable) – A function to apply to each leaf. When multiple trees are provided, this function should accept as many arguments as there are trees.

  • tree (Any) – The main tree whose structure will be followed.

  • *rest (Any) – Additional trees with exactly the same structure as the first tree.

  • is_leaf (callable, optional) – A function that takes a tree node as input and returns a boolean indicating whether it should be considered a leaf.

Returns:

mapped_tree – A new tree with the same structure as tree but with leaf values transformed by function f.

Return type:

Any

Raises:

ValueError – If additional trees do not have exactly the same structure as the main tree.
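As an illustration of what "exactly the same structure" could mean, the sketch below compares container types, dict keys, and sequence lengths while ignoring leaf values. The helpers `structure_of` and `check_same_structure` are hypothetical, and the library's actual check and error message may differ.

```python
from typing import Any


def structure_of(tree: Any) -> Any:
    """Describe container types and keys of a tree, with every leaf shown as '*'."""
    if isinstance(tree, dict):
        return ("dict", tuple((k, structure_of(tree[k])) for k in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(structure_of(v) for v in tree))
    return "*"


def check_same_structure(tree: Any, *rest: Any) -> None:
    """Raise ValueError if any tree in rest differs in structure from tree."""
    expected = structure_of(tree)
    for i, other in enumerate(rest):
        if structure_of(other) != expected:
            raise ValueError(
                f"tree {i + 1} does not match the structure of the first tree"
            )


state = {"pos": [1.0, 2.0], "vel": [3.0, 4.0]}
good = {"pos": [0.0, 0.0], "vel": [0.0, 0.0]}
bad = {"pos": [0.0, 0.0]}  # missing the "vel" key

check_same_structure(state, good)  # same keys and lengths: no error

try:
    check_same_structure(state, bad)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```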

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Apply a function to each leaf
>>> state = {"pos": np.array([1.0, 2.0]), "vel": np.array([3.0, 4.0])}
>>> doubled = arc.tree.map(lambda x: x * 2, state)

See also

tree_flatten

Flatten a tree into a list of leaves and a treedef

tree_leaves

Extract just the leaf values from a tree