archimedes.tree.map
archimedes.tree.map(f: Callable, tree: T, *rest: tuple[T, ...], is_leaf: Callable[[Any], bool] | None = None) → T
Apply a function to each leaf in a pytree.
This function traverses the pytree and applies the function f to each leaf, returning a new pytree with the same structure but transformed leaf values. If additional pytrees are provided, the function is applied to corresponding leaves from all pytrees.
Parameters:
f (callable): A function to apply to each leaf. When multiple trees are provided, this function should accept as many arguments as there are trees.
tree (Any): The main pytree whose structure will be followed.
*rest (Any): Additional pytrees with exactly the same structure as the first tree.
is_leaf (callable, optional): A function that takes a pytree node as input and returns a boolean indicating whether it should be considered a leaf. If not provided, the default leaf types (arrays and scalars) are used.
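The effect of is_leaf is easiest to see by listing a tree's leaves with and without it. The helper below, collect_leaves, is hypothetical and not part of archimedes; it is a pure-Python sketch that assumes the predicate is checked before a node is descended into, so returning True for a container turns that whole container into a single leaf. In the sketch, anything that is not a dict, list, or tuple is a leaf by default, standing in for the "arrays and scalars" default described for is_leaf.

```python
from typing import Any, Callable, List, Optional


def collect_leaves(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> List[Any]:
    """Hypothetical helper: list the leaves of a dict/list/tuple pytree in traversal order."""
    # The user predicate is consulted first, so it can stop traversal at any container.
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if type(tree) is dict:
        # Visit entries in insertion order (a real library may order keys differently).
        return [leaf for key in tree for leaf in collect_leaves(tree[key], is_leaf)]
    if type(tree) in (list, tuple):
        return [leaf for child in tree for leaf in collect_leaves(child, is_leaf)]
    # Anything that is not a supported container is a leaf by default.
    return [tree]


# A (lower, upper) bound pair stored as a tuple inside a parameter dict.
params = {"gain": 2.0, "bounds": (0.0, 10.0)}

print(collect_leaves(params))
# [2.0, 0.0, 10.0]  (default: the tuple is a container, so its elements are leaves)
print(collect_leaves(params, is_leaf=lambda node: isinstance(node, tuple)))
# [2.0, (0.0, 10.0)]  (the tuple is now treated as a single leaf)
```

With a mapping function, the same predicate would cause f to receive the whole (lower, upper) pair instead of each bound separately.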
Returns:
mapped_tree: A new pytree with the same structure as tree but with leaf values transformed by the function f.
Return type:
T
Notes
When to use:
- To transform data in a structured object without changing its structure
- To perform element-wise operations on corresponding elements of multiple pytrees
- As an alternative to manually looping through nested structures
- To apply the same operation to all arrays in a complex model
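The traversal this function performs can be sketched in pure Python. This is not the archimedes implementation: tree_map_sketch is a hypothetical name, only dict, list, and tuple are treated as containers (so NumPy arrays and scalars fall through as leaves), and the "exactly the same structure" rule is approximated by raising ValueError when container types, dict keys, or lengths disagree. The last lines illustrate the final use case in the list, applying one operation to every leaf of a nested model.

```python
from typing import Any, Callable, Optional

# Container types this sketch traverses; everything else is a leaf.
CONTAINERS = (dict, list, tuple)


def _is_leaf(node: Any, is_leaf: Optional[Callable[[Any], bool]]) -> bool:
    # A user-supplied predicate wins; otherwise only non-containers are leaves.
    if is_leaf is not None and is_leaf(node):
        return True
    return type(node) not in CONTAINERS


def tree_map_sketch(
    f: Callable[..., Any],
    tree: Any,
    *rest: Any,
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Any:
    """Hypothetical pytree map over nested dicts, lists, and tuples."""
    if _is_leaf(tree, is_leaf):
        # Corresponding nodes in the other trees must be leaves too.
        if any(not _is_leaf(other, is_leaf) for other in rest):
            raise ValueError("a leaf in the main tree matches a container in another tree")
        return f(tree, *rest)

    # Every other tree must use the same container type at this node.
    for other in rest:
        if type(other) is not type(tree):
            raise ValueError(
                f"container mismatch: {type(tree).__name__} vs {type(other).__name__}"
            )

    if type(tree) is dict:
        for other in rest:
            if other.keys() != tree.keys():
                raise ValueError(f"dict keys differ: {list(tree)} vs {list(other)}")
        # Recurse key by key, pulling the matching child out of every other tree.
        return {
            key: tree_map_sketch(f, tree[key], *(o[key] for o in rest), is_leaf=is_leaf)
            for key in tree
        }

    # list or tuple: lengths must match, then recurse position by position.
    for other in rest:
        if len(other) != len(tree):
            raise ValueError(f"length mismatch: {len(tree)} vs {len(other)}")
    children = [
        tree_map_sketch(f, child, *(o[i] for o in rest), is_leaf=is_leaf)
        for i, child in enumerate(tree)
    ]
    return type(tree)(children)  # rebuild a list or tuple of the same type


# Double every leaf of a nested "model" without writing any loops by hand.
model = {"layer1": {"w": [1.0, 2.0], "b": 0.5}, "layer2": ([3.0], 4.0)}
print(tree_map_sketch(lambda x: x * 2, model))
# {'layer1': {'w': [2.0, 4.0], 'b': 1.0}, 'layer2': ([6.0], 8.0)}
```

Passing a second tree, as in tree_map_sketch(lambda x, y: x + y, {"a": [1.0]}, {"a": [2.0]}), returns {"a": [3.0]}, mirroring the multiple-pytree case in the Examples.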
Raises:
ValueError: If additional pytrees do not have exactly the same structure as the main tree.
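What "the same structure" means can be sketched by comparing a structure signature that keeps container types, dict keys, and lengths but discards leaf values. The helpers structure_of and check_same_structure are hypothetical, handle only dict, list, and tuple, and assume dict keys are mutually sortable (for example, all strings). Whether archimedes compares anything beyond these properties is not stated on this page, so treat this as an approximation of the rule behind the ValueError.

```python
from typing import Any, Hashable


def structure_of(tree: Any) -> Hashable:
    """Hypothetical helper: a hashable description of a pytree's shape, ignoring leaf values."""
    if type(tree) is dict:
        # Keys are part of the structure; sorting makes insertion order irrelevant.
        return ("dict", tuple((key, structure_of(tree[key])) for key in sorted(tree)))
    if type(tree) in (list, tuple):
        # Container type and length are part of the structure.
        return (type(tree).__name__, tuple(structure_of(child) for child in tree))
    # Every leaf looks the same, whatever its value.
    return "*"


def check_same_structure(tree: Any, *rest: Any) -> None:
    """Raise ValueError if any tree in rest differs in structure from tree."""
    expected = structure_of(tree)
    for i, other in enumerate(rest):
        if structure_of(other) != expected:
            raise ValueError(f"tree {i + 1} in rest does not match the structure of the main tree")


state1 = {"pos": [1.0, 2.0], "vel": [3.0, 4.0]}
state2 = {"pos": [5.0, 6.0], "vel": [7.0, 8.0]}
check_same_structure(state1, state2)  # passes: only leaf values differ

try:
    check_same_structure(state1, {"pos": [5.0, 6.0]})  # "vel" is missing
except ValueError as err:
    print(err)  # tree 1 in rest does not match the structure of the main tree
```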
Examples
>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Single pytree example
>>> state = {"pos": np.array([1.0, 2.0]), "vel": np.array([3.0, 4.0])}
>>>
>>> # Double all values
>>> doubled = arc.tree.map(lambda x: x * 2, state)
>>> print(doubled)
{'pos': array([2., 4.]), 'vel': array([6., 8.])}
>>>
>>> # Multiple pytrees example
>>> state1 = {"pos": np.array([1.0, 2.0]), "vel": np.array([3.0, 4.0])}
>>> state2 = {"pos": np.array([5.0, 6.0]), "vel": np.array([7.0, 8.0])}
>>>
>>> # Add corresponding leaves
>>> combined = arc.tree.map(lambda x, y: x + y, state1, state2)
>>> print(combined)
{'pos': array([6., 8.]), 'vel': array([10., 12.])}