archimedes.tree.map
archimedes.tree.map(
    f: Callable,
    tree: T,
    *rest: tuple[T, ...],
    is_leaf: Callable[[Any], bool] | None = None,
)
Apply a function to each leaf in a tree.
Traverses the tree and applies the function f to each leaf, returning a new tree with the same structure but transformed leaf values. If additional trees are provided, the function is applied to corresponding leaves from all trees.
- Parameters:
f (callable) – A function to apply to each leaf. When multiple trees are provided, this function should accept as many arguments as there are trees.
tree (Any) – The main tree whose structure will be followed.
*rest (Any) – Additional trees with exactly the same structure as the first tree.
is_leaf (callable, optional) – A function that takes a tree node as input and returns a boolean indicating whether it should be considered a leaf.
- Returns:
mapped_tree – A new tree with the same structure as tree but with leaf values transformed by function f.
- Return type:
Any
- Raises:
ValueError – If additional trees do not have exactly the same structure as the main tree.
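The structure requirement above can be illustrated with a small pure-Python sketch. It is not the Archimedes implementation: the helpers `structure` and `check_same_structure` are hypothetical names, the error message is made up, and only dicts, lists, and tuples are treated as containers.

```python
from typing import Any


def structure(tree: Any) -> Any:
    """Describe the container layout only, replacing every leaf with "*"."""
    if isinstance(tree, dict):
        return {key: structure(value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(structure(value) for value in tree)
    return "*"  # anything else counts as a leaf in this sketch


def check_same_structure(tree: Any, *rest: Any) -> None:
    """Raise ValueError if any additional tree differs in layout from `tree`."""
    expected = structure(tree)
    for index, other in enumerate(rest):
        if structure(other) != expected:
            raise ValueError(
                f"tree {index + 1} does not match the structure of the main tree"
            )


state = {"pos": [1.0, 2.0], "vel": [3.0, 4.0]}

# Same keys and same list lengths: passes silently.
check_same_structure(state, {"pos": [0.0, 0.0], "vel": [0.0, 0.0]})

# The second tree is missing "vel", so the check fails.
message = None
try:
    check_same_structure(state, {"pos": [0.0, 0.0]})
except ValueError as err:
    message = str(err)
```

Comparing layouts before applying `f` means a mismatch is reported up front rather than surfacing as a confusing `KeyError` or `IndexError` partway through the traversal.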
Examples
>>> import archimedes as arc
>>> import numpy as np
>>>
>>> state = {"pos": np.array([1.0, 2.0]), "vel": np.array([3.0, 4.0])}
>>> doubled = arc.tree.map(lambda x: x * 2, state)
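The multi-tree and `is_leaf` behavior described under Parameters can be sketched in plain Python without the library. The function `map_sketch` below is a hypothetical stand-in written for illustration, not the Archimedes code. It assumes that only dicts, lists, and tuples are containers, and it skips the structure check that the real function performs.

```python
from typing import Any, Callable, Optional


def map_sketch(
    f: Callable,
    tree: Any,
    *rest: Any,
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Any:
    """Apply f to corresponding leaves of tree and rest (illustrative only)."""
    # A user-supplied predicate can stop the traversal early at any node.
    if is_leaf is not None and is_leaf(tree):
        return f(tree, *rest)
    if isinstance(tree, dict):
        return {
            key: map_sketch(f, value, *(r[key] for r in rest), is_leaf=is_leaf)
            for key, value in tree.items()
        }
    if isinstance(tree, (list, tuple)):
        mapped = [
            map_sketch(f, value, *(r[i] for r in rest), is_leaf=is_leaf)
            for i, value in enumerate(tree)
        ]
        return type(tree)(mapped)
    # Anything that is not a container is a leaf.
    return f(tree, *rest)


state = {"pos": [1.0, 2.0], "vel": [3.0, 4.0]}
rates = {"pos": [0.5, 0.5], "vel": [-1.0, 1.0]}

# Two trees: f receives one leaf from each tree.
stepped = map_sketch(lambda x, dx: x + dx, state, rates)
# {'pos': [1.5, 2.5], 'vel': [2.0, 5.0]}

# is_leaf: treat each list as a single leaf instead of descending into it.
lengths = map_sketch(len, state, is_leaf=lambda node: isinstance(node, list))
# {'pos': 2, 'vel': 2}
```

In both calls the result keeps the keys and nesting of `state`, and `state` itself is left unmodified, which matches the "new tree with the same structure" wording in the Returns section.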
See also
tree_flatten – Flatten a tree into a list of leaves and a treedef
tree_leaves – Extract just the leaf values from a tree