archimedes.tree.reduce

archimedes.tree.reduce(function: Callable[[V, ArrayLike], V], tree: PyTree, initializer: V, is_leaf: Callable[[Any], bool] | None = None) → V

Reduce a pytree to a single value using a function and initializer.

This function traverses the pytree and applies the reduction function to an accumulator and each leaf in turn. It works like Python's functools.reduce(), except that it iterates over all leaves of a pytree rather than over a flat sequence.
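To make the traversal concrete, here is a minimal pure-Python sketch of the same idea: flatten the leaves, then fold them with functools.reduce. The helpers iter_leaves and tree_reduce_sketch are hypothetical illustrations, not part of archimedes. The sketch only handles dicts, lists, and tuples as containers, and it assumes dict children are visited in sorted-key order (modeled on JAX-style pytrees); the real archimedes implementation may differ.

```python
from functools import reduce
from typing import Any, Callable, Iterator, Optional

import numpy as np


def iter_leaves(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> Iterator[Any]:
    """Hypothetical helper: yield the leaves of a nested dict/list/tuple, depth first."""
    if is_leaf is not None and is_leaf(tree):
        # A user-supplied predicate can stop recursion at any node.
        yield tree
    elif isinstance(tree, dict):
        # Assumption: dict children are visited in sorted-key order.
        for key in sorted(tree):
            yield from iter_leaves(tree[key], is_leaf)
    elif isinstance(tree, (list, tuple)):
        for child in tree:
            yield from iter_leaves(child, is_leaf)
    else:
        # Anything that is not a container (array, scalar) is treated as a leaf.
        yield tree


def tree_reduce_sketch(function, tree, initializer, is_leaf=None):
    """Hypothetical stand-in: fold `function` over the leaves, starting from `initializer`."""
    return reduce(function, iter_leaves(tree, is_leaf), initializer)


data = {"a": np.array([1.0, 2.0]), "b": {"c": np.array([3.0])}}


def sum_leaf(acc, leaf):
    return acc + float(np.sum(leaf))


total = tree_reduce_sketch(sum_leaf, data, 0.0)
print(total)  # 6.0


def count_leaf(acc, leaf):
    return acc + 1


# By default the tuple is a container, so its two elements are separate leaves.
pair = {"x": (1.0, 2.0), "y": 3.0}
print(tree_reduce_sketch(count_leaf, pair, 0))  # 3
# With an is_leaf predicate, the whole tuple is treated as one leaf.
print(tree_reduce_sketch(count_leaf, pair, 0, is_leaf=lambda node: isinstance(node, tuple)))  # 2
```

The last two calls show why is_leaf exists: it changes which nodes the reduction function sees, without changing the pytree itself.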

Parameters:
  • function (callable) – A function of two arguments, (accumulated_result, leaf_value), that returns a new accumulated result. The function should be commutative and associative so that the result does not depend on the order in which leaves are visited.

  • tree (PyTree) – A pytree to reduce. A pytree is a nested structure of containers (lists, tuples, dicts) and leaves (arrays or scalars).

  • initializer (Any) – The initial value for the accumulator.

  • is_leaf (callable, optional) – A function that takes a pytree node and returns True if that node should be treated as a leaf. If not provided, the default leaf types (arrays and scalars) are used.
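The commutativity and associativity requirement on function can be illustrated with plain functools.reduce over a flat list standing in for the leaves. The sketch below folds the same values in two different orders: addition gives the same answer both ways, while the hypothetical shift_digit function (which appends each value as a decimal digit) does not. Note that floating-point addition is only approximately associative, so sums over many leaves visited in different orders can still differ in the last bits.

```python
from functools import reduce

leaves = [1.0, 2.0, 3.0]
reordered = leaves[::-1]  # the same leaves visited in a different order


def add(acc, leaf):
    # Commutative and associative: order does not affect the result.
    return acc + leaf


def shift_digit(acc, leaf):
    # Neither commutative nor associative: order changes the result.
    return acc * 10 + leaf


print(reduce(add, leaves, 0.0), reduce(add, reordered, 0.0))  # 6.0 6.0
print(reduce(shift_digit, leaves, 0.0), reduce(shift_digit, reordered, 0.0))  # 123.0 321.0
```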

Returns:

result – The final accumulated value after applying the function to all leaves.

Return type:

Any

Notes

When to use:

  • To compute aggregate values (sum, product, etc.) across all leaf values

  • To collect statistics from a structured model

  • To implement custom reduction operations on complex data structures
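For collecting statistics, the accumulator does not have to be a scalar: any value of the type V in the signature works, such as a tuple. The sketch below keeps a running (element count, total) pair and derives the mean from it. To stay runnable without archimedes installed, it folds over the two leaves of the example data listed by hand with functools.reduce; passing the same count_and_sum function to arc.tree.reduce(count_and_sum, data, (0, 0.0)) should visit those same arrays. The name count_and_sum is illustrative only.

```python
from functools import reduce

import numpy as np

data = {"a": np.array([1.0, 2.0]), "b": {"c": np.array([3.0])}}


def count_and_sum(acc, leaf):
    """Accumulate (number of elements, running total) across leaves."""
    count, total = acc
    leaf = np.asarray(leaf)  # also accepts scalar leaves
    return count + leaf.size, total + float(leaf.sum())


# Leaves of `data` listed by hand, standing in for the pytree traversal.
leaves = [data["a"], data["b"]["c"]]
count, total = reduce(count_and_sum, leaves, (0, 0.0))
mean = total / count
print(count, total, mean)  # 3 6.0 2.0
```

Because both components are updated by addition, this accumulator is insensitive to the order in which leaves are visited, which satisfies the requirement on function above.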

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Sum all values in a nested structure
>>> data = {"a": np.array([1.0, 2.0]), "b": {"c": np.array([3.0])}}
>>>
>>> # Compute the sum
>>> def sum_leaf(acc, leaf):
...     return acc + np.sum(leaf)
>>>
>>> total = arc.tree.reduce(sum_leaf, data, 0.0)
>>> print(total)
6.0
>>>
>>> # Find the maximum value
>>> def max_leaf(acc, leaf):
...     return np.fmax(acc, np.max(leaf))
>>>
>>> maximum = arc.tree.reduce(max_leaf, data, -np.inf)
>>> print(maximum)
3.0

See also

map

Apply a function to each leaf in a pytree

leaves

Extract just the leaf values from a pytree