archimedes.tree.flatten

archimedes.tree.flatten(x: PyTree, is_leaf: Callable[[Any], bool] | None = None) -> tuple[list[ArrayLike], PyTreeDef]

Flatten a pytree into a list of leaves and a treedef.

This function recursively traverses the pytree and extracts all leaf values while recording the structure. This is useful when you need to apply operations to all leaf values or convert structured data to a flat representation.

Parameters:
  • x (PyTree) – The pytree to flatten. A pytree is a nested structure of containers (lists, tuples, dicts, etc.) and leaves (arrays, scalars, and objects not registered as pytrees).

  • is_leaf (callable, optional) – A function that takes a pytree node as input and returns a boolean indicating whether it should be considered a leaf. If not provided, the default leaf types are used.

Returns:

  • leaves (list) – A list of all leaf values from the pytree.

  • treedef (PyTreeDef) – A structure definition that can be used to reconstruct the original pytree using unflatten.

Notes

A pytree is defined as a nested structure of:

  • Containers: recognized container types like lists, tuples, and dictionaries

  • Leaves: arrays, scalars, or custom objects not recognized as containers
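
The container/leaf distinction can be illustrated with a small pure-Python sketch. This is not the Archimedes implementation: the names flatten_sketch and unflatten_sketch are hypothetical, the "skeleton" here is a plain nested container rather than a PyTreeDef, only lists, tuples, and dicts are treated as containers, and the leaf ordering used by the library may differ.

```python
import numpy as np


def flatten_sketch(tree, is_leaf=None):
    """Hypothetical helper: return (leaves, skeleton) for lists, tuples, dicts."""
    leaves = []

    def visit(node):
        # A user-supplied predicate can stop the recursion early.
        if is_leaf is not None and is_leaf(node):
            leaves.append(node)
            return None  # None marks a leaf position in the skeleton
        if isinstance(node, dict):
            return {key: visit(value) for key, value in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(visit(item) for item in node)
        # Anything else (arrays, scalars, unknown objects) is a leaf.
        leaves.append(node)
        return None

    skeleton = visit(tree)
    return leaves, skeleton


def unflatten_sketch(skeleton, leaves):
    """Hypothetical helper: rebuild a tree by filling leaf slots in order."""
    remaining = iter(leaves)

    def build(node):
        if node is None:  # simplification: a real None leaf is not supported
            return next(remaining)
        if isinstance(node, dict):
            return {key: build(value) for key, value in node.items()}
        return type(node)(build(item) for item in node)

    return build(skeleton)


data = {"a": np.array([1.0, 2.0]), "b": {"c": np.array([3.0])}, "pair": (4.0, 5.0)}

# Default behaviour: the tuple is a container, so its entries become leaves.
leaves, skeleton = flatten_sketch(data)

# With is_leaf, the tuple is kept whole as a single leaf.
pair_leaves, _ = flatten_sketch(data, is_leaf=lambda node: isinstance(node, tuple))

# Transform every leaf, then restore the original nesting.
doubled = unflatten_sketch(skeleton, [2 * leaf for leaf in leaves])
```

In this sketch the default call yields four leaves, while the is_leaf variant yields three because the tuple is no longer traversed. The same pattern (flatten, transform the list, unflatten) is the typical reason to keep the structure definition around.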

When to use:

  • When you need to extract all leaf values from a nested structure

  • When you need to convert between structured and flat representations

When converting to or from a flat vector, ravel() is typically more convenient than this function.
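
To see why, the following sketch shows the bookkeeping needed to pack a list of leaves into one 1D vector and split it back by hand, using only NumPy. It is an illustration of the idea, not how ravel() is implemented, and it does not show ravel()'s actual signature or return values.

```python
import numpy as np

# Leaves as they might come out of a flatten call (example values).
leaves = [np.array([1.0, 2.0]), np.array([[3.0, 4.0], [5.0, 6.0]])]

# Pack: remember each leaf's shape, then concatenate the flattened leaves.
shapes = [leaf.shape for leaf in leaves]
flat = np.concatenate([leaf.ravel() for leaf in leaves])

# Unpack: split the vector at the cumulative leaf sizes and restore shapes.
sizes = [int(np.prod(shape)) for shape in shapes]
split_points = np.cumsum(sizes)[:-1]
chunks = np.split(flat, split_points)
restored = [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
```

A flat-vector helper saves you from tracking shapes and split points yourself, which is the convenience the note above refers to.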

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Simple pytree with nested containers
>>> data = {"a": np.array([1.0, 2.0]), "b": {"c": np.array([3.0])}}
>>>
>>> # Flatten the pytree
>>> leaves, treedef = arc.tree.flatten(data)
>>> print(leaves)
[array([1., 2.]), array([3.])]
>>>
>>> # Use treedef to reconstruct the pytree
>>> reconstructed = arc.tree.unflatten(treedef, leaves)
>>> print(reconstructed)
{'a': array([1., 2.]), 'b': {'c': array([3.])}}

See also

unflatten

Reconstruct a pytree from leaves and a treedef

leaves

Extract just the leaf values from a pytree

structure

Extract just the structure from a pytree

ravel

Flatten a pytree into a single 1D array