archimedes.tree.reduce
- archimedes.tree.reduce(function: Callable[[V, ArrayLike], V], tree: PyTree, initializer: V, is_leaf: Callable[[Any], bool] | None = None) -> V
Reduce a pytree to a single value using a function and initializer.
This function traverses the pytree, applying the reduction function to an accumulator and each leaf in turn. It is similar to Python's functools.reduce(), but operates on all leaves of a pytree.
- Parameters:
function (callable) – A function of two arguments, (accumulated_result, leaf_value), that returns a new accumulated result. The function should be commutative and associative so that the result does not depend on the order in which the leaves are traversed.
tree (PyTree) – A pytree to reduce. A pytree is a nested structure of containers (lists, tuples, dicts) and leaves (arrays or scalars).
initializer (Any) – The initial value for the accumulator.
is_leaf (callable, optional) – A function that takes a pytree node as input and returns a boolean indicating whether it should be treated as a leaf. If not provided, the default leaf types (arrays and scalars) are used.
- Returns:
result – The final accumulated value after applying the function to all leaves.
- Return type:
V
Notes
When to use:
To compute aggregate values (sum, product, etc.) across all leaf values
To collect statistics from a structured model
To implement custom reduction operations on complex data structures
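As a sketch of the "collect statistics" use case, the block below folds every leaf of a parameter tree into a (count, total, sum of squares) tuple and then derives the mean and the global L2 norm from it. So that it runs without archimedes installed, it uses hypothetical helpers tree_leaves and tree_reduce. These helpers assume that reduce behaves like a left fold over the flattened leaves and recurse only into dicts (in sorted key order), lists, and tuples. The library's actual traversal may differ. Because each component of the accumulator is a plain sum, the result does not depend on traversal order, as the description of the function parameter requires.

```python
import functools
import math

import numpy as np


def tree_leaves(tree):
    """Hypothetical flattener: recurses into dicts (sorted keys), lists, and tuples."""
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in tree_leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in tree_leaves(child)]
    return [tree]  # anything else (arrays, scalars) is treated as a leaf


def tree_reduce(function, tree, initializer):
    """Hypothetical stand-in for arc.tree.reduce: a left fold over the leaves."""
    return functools.reduce(function, tree_leaves(tree), initializer)


def accumulate_stats(acc, leaf):
    """Fold one leaf into a (count, total, sum_of_squares) accumulator."""
    count, total, sumsq = acc
    leaf = np.asarray(leaf)
    return (
        count + leaf.size,
        total + float(np.sum(leaf)),
        sumsq + float(np.sum(leaf**2)),
    )


# Hypothetical parameter tree for a small two-layer model
params = {
    "layer1": {"W": np.ones((2, 3)), "b": np.zeros(3)},
    "layer2": {"W": np.full((3, 1), 2.0), "b": np.array([0.5])},
}

count, total, sumsq = tree_reduce(accumulate_stats, params, (0, 0.0, 0.0))
mean = total / count     # mean over every parameter value
norm = math.sqrt(sumsq)  # global L2 norm across all leaves
print(count, total, sumsq)  # 13 12.5 18.25
```

With archimedes itself, the reduction call would presumably be arc.tree.reduce(accumulate_stats, params, (0, 0.0, 0.0)), since the initializer may be any value of the accumulator type.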
Examples
>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Sum all values in a nested structure
>>> data = {"a": np.array([1.0, 2.0]), "b": {"c": np.array([3.0])}}
>>>
>>> # Compute the sum
>>> def sum_leaf(acc, leaf):
...     return acc + np.sum(leaf)
>>>
>>> total = arc.tree.reduce(sum_leaf, data, 0.0)
>>> print(total)
6.0
>>>
>>> # Find the maximum value
>>> def max_leaf(acc, leaf):
...     return np.fmax(acc, np.max(leaf))
>>>
>>> maximum = arc.tree.reduce(max_leaf, data, -np.inf)
>>> print(maximum)
3.0
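The same examples can be mirrored in plain Python, which also makes it possible to illustrate the is_leaf parameter. This is a sketch that assumes reduce is equivalent to flattening the tree into its leaves and applying functools.reduce. The names tree_leaves, tree_reduce, and Measurement are hypothetical and are not part of archimedes. The flattener handles only dicts (sorted keys), lists, and tuples. Without is_leaf, each Measurement (a NamedTuple, and therefore a tuple) would be split into its two float fields. With is_leaf, each record reaches the reduction function intact.

```python
import functools
from typing import Any, Callable, NamedTuple, Optional

import numpy as np


def tree_leaves(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> list:
    """Hypothetical flattener: recurses into dicts (sorted keys), lists, and tuples."""
    if is_leaf is not None and is_leaf(tree):
        return [tree]  # the caller asked to keep this node whole
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in tree_leaves(tree[key], is_leaf)]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in tree_leaves(child, is_leaf)]
    return [tree]  # arrays and scalars are leaves by default


def tree_reduce(function, tree, initializer, is_leaf=None):
    """Hypothetical stand-in for arc.tree.reduce: a left fold over the leaves."""
    return functools.reduce(function, tree_leaves(tree, is_leaf), initializer)


# Mirror of the documented examples
data = {"a": np.array([1.0, 2.0]), "b": {"c": np.array([3.0])}}
total = tree_reduce(lambda acc, leaf: acc + np.sum(leaf), data, 0.0)
maximum = tree_reduce(lambda acc, leaf: np.fmax(acc, np.max(leaf)), data, -np.inf)
print(total, maximum)  # 6.0 3.0


class Measurement(NamedTuple):
    """Hypothetical record that should be reduced as a single leaf."""
    value: float
    weight: float


readings = {
    "x": [Measurement(2.0, 0.5), Measurement(4.0, 0.25)],
    "y": Measurement(8.0, 0.125),
}

# Weighted sum: is_leaf stops traversal at each Measurement
weighted = tree_reduce(
    lambda acc, m: acc + m.value * m.weight,
    readings,
    0.0,
    is_leaf=lambda node: isinstance(node, Measurement),
)
print(weighted)  # 3.0
```

Whether archimedes treats NamedTuple instances as containers by default is an assumption here. The sketch only shows the general role of is_leaf.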