archimedes.tree.map

archimedes.tree.map(f: Callable, tree: T, *rest: tuple[T, ...], is_leaf: Callable[[Any], bool] | None = None) → T

Apply a function to each leaf in a pytree.

This function traverses the pytree and applies f to each leaf, returning a new pytree with the same structure but transformed leaf values. If additional pytrees are provided, f is called with the corresponding leaves of all pytrees as positional arguments: the leaf from tree first, followed by the matching leaf from each tree in rest, in order.
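The traversal can be pictured with a minimal pure-Python sketch. It assumes that dicts, lists, and tuples are the only container nodes and treats every other object as a leaf; the real function may recognize additional node types, and `tree_map_sketch` is a hypothetical name used only for illustration, not part of archimedes.

```python
from typing import Any, Callable


def tree_map_sketch(f: Callable, tree: Any, *rest: Any) -> Any:
    """Hypothetical stand-in for arc.tree.map on dict/list/tuple trees.

    Assumes every tree in `rest` already matches the structure of `tree`;
    no structure validation is done here.
    """
    if isinstance(tree, dict):
        # Recurse into each value, taking the same key from every other tree.
        return {
            key: tree_map_sketch(f, value, *(other[key] for other in rest))
            for key, value in tree.items()
        }
    if isinstance(tree, (list, tuple)):
        # Recurse position by position, taking the same index from every other tree.
        children = [
            tree_map_sketch(f, child, *(other[i] for other in rest))
            for i, child in enumerate(tree)
        ]
        # Rebuild the same kind of container (plain list or plain tuple).
        return children if isinstance(tree, list) else tuple(children)
    # Anything else is a leaf: call f on the corresponding leaves of all trees.
    return f(tree, *rest)


state = {"pos": [1.0, 2.0], "vel": (3.0, 4.0)}
print(tree_map_sketch(lambda x: x * 2, state))
# {'pos': [2.0, 4.0], 'vel': (6.0, 8.0)}

state2 = {"pos": [5.0, 6.0], "vel": (7.0, 8.0)}
print(tree_map_sketch(lambda x, y: x + y, state, state2))
# {'pos': [6.0, 8.0], 'vel': (10.0, 12.0)}
```

Written this way, the multi-tree case is a single recursion that walks all trees in lockstep, which is why every additional tree must share the main tree's structure.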

Parameters:
  • f (callable) – A function to apply to each leaf. When multiple trees are provided, this function should accept as many arguments as there are trees.

  • tree (Any) – The main pytree whose structure will be followed.

  • *rest (Any) – Additional pytrees with exactly the same structure as the first tree.

  • is_leaf (callable, optional) – A function that takes a pytree node as input and returns a boolean indicating whether it should be considered a leaf. If not provided, the default leaf types (arrays and scalars) are used.
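The effect of is_leaf can be illustrated with a pure-Python sketch that collects the objects a tree map would pass to f. It assumes, as in JAX-style pytree libraries, that the predicate is checked before a node is expanded, so an accepted node is handed over whole; dicts, lists, and tuples are the only containers here, and `leaves_sketch` is a hypothetical helper, not an archimedes function.

```python
from typing import Any, Callable, List, Optional


def leaves_sketch(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> List[Any]:
    """Hypothetical helper: collect the objects a tree map would pass to f."""
    # The predicate is consulted before descending, so a node it accepts is
    # returned whole instead of being traversed.
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in leaves_sketch(value, is_leaf)]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in leaves_sketch(child, is_leaf)]
    # Non-container objects are leaves by default.
    return [tree]


limits = {"pos": (-1.0, 1.0), "vel": (-5.0, 5.0)}

# Default rule: each tuple is a container, so its elements are the leaves.
print(leaves_sketch(limits))  # [-1.0, 1.0, -5.0, 5.0]

# Treating tuples as leaves yields one (lower, upper) pair per entry instead.
print(leaves_sketch(limits, is_leaf=lambda node: isinstance(node, tuple)))
# [(-1.0, 1.0), (-5.0, 5.0)]
```

Assuming arc.tree.map applies the predicate the same way, passing `is_leaf=lambda node: isinstance(node, tuple)` would make f receive each `(lower, upper)` tuple as a single argument rather than its two elements separately.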

Returns:

mapped_tree – A new pytree with the same structure as tree, in which each leaf is the value returned by f.

Return type:

Any

Notes

When to use:

  • To transform data in a structured object without changing its structure

  • To perform element-wise operations on corresponding elements of multiple pytrees

  • As an alternative to manually looping through nested structures

  • To apply the same operation to all arrays in a complex model

Raises:

ValueError – If additional pytrees do not have exactly the same structure as the main tree.
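What "exactly the same structure" means can be made concrete with a sketch that compares leaf-free skeletons: container types, dict keys, and sequence lengths must all agree, while leaf values are ignored. This is an assumption about how the check behaves, modeled on JAX-style pytrees; `structure_sketch` and `check_same_structure` are hypothetical names, not archimedes functions.

```python
from typing import Any


def structure_sketch(tree: Any) -> Any:
    """Hypothetical helper: the tree's shape with every leaf replaced by '*'."""
    if isinstance(tree, dict):
        # Sort keys (assumed to be mutually comparable, e.g. all strings) so
        # that insertion order does not affect the comparison.
        return ("dict", tuple((key, structure_sketch(tree[key])) for key in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        # The container type is part of the structure: a list never matches a tuple.
        return (type(tree).__name__, tuple(structure_sketch(child) for child in tree))
    return "*"


def check_same_structure(tree: Any, *rest: Any) -> None:
    """Raise ValueError if any tree in `rest` differs in shape from `tree`."""
    expected = structure_sketch(tree)
    for i, other in enumerate(rest):
        if structure_sketch(other) != expected:
            raise ValueError(f"rest[{i}] does not have the same structure as tree")


a = {"pos": [1.0, 2.0], "vel": [3.0, 4.0]}
check_same_structure(a, {"pos": [5.0, 6.0], "vel": [7.0, 8.0]})  # passes: only leaf values differ

try:
    check_same_structure(a, {"pos": [5.0, 6.0, 7.0], "vel": [7.0, 8.0]})
except ValueError as err:
    print(err)  # rest[0] does not have the same structure as tree
```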

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Single pytree example
>>> state = {"pos": np.array([1.0, 2.0]), "vel": np.array([3.0, 4.0])}
>>>
>>> # Double all values
>>> doubled = arc.tree.map(lambda x: x * 2, state)
>>> print(doubled)
{'pos': array([2., 4.]), 'vel': array([6., 8.])}
>>>
>>> # Multiple pytrees example
>>> state1 = {"pos": np.array([1.0, 2.0]), "vel": np.array([3.0, 4.0])}
>>> state2 = {"pos": np.array([5.0, 6.0]), "vel": np.array([7.0, 8.0])}
>>>
>>> # Add corresponding leaves
>>> combined = arc.tree.map(lambda x, y: x + y, state1, state2)
>>> print(combined)
{'pos': array([6., 8.]), 'vel': array([10., 12.])}

See also

flatten – Flatten a pytree into a list of leaves and a treedef.

leaves – Extract just the leaf values from a pytree.