archimedes.tree.flatten
archimedes.tree.flatten(x: Tree, is_leaf: Callable[[Any], bool] | None = None)
Flatten a tree into a list of leaves and a treedef.
This function recursively traverses the tree and extracts all leaf values while recording the structure. This is useful when you need to apply operations to all leaf values or convert structured data to a flat representation.
Parameters:
x (Tree): A tree to be flattened. Here, a tree is a nested structure of containers (lists, tuples, dicts, etc.) and leaves (arrays, scalars, and objects not registered as trees).
is_leaf (callable, optional): A function that takes a node of the tree and returns True if that node should be treated as a leaf, in which case it is not traversed further. If None (the default), the default rule applies: any object that is not a recognized container type is a leaf.
Returns:
leaves (list): A list of all leaf values in the tree.
treedef (TreeDef): A structure definition that can be passed to unflatten(), together with the leaves, to reconstruct the original tree.
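To make the role of is_leaf concrete, here is a hypothetical helper, flatten_leaves, that returns only the leaves of a nested structure and checks the predicate at every node before descending into it. That calling convention is an assumption modeled on the parameter description above and on how JAX's tree_flatten treats its is_leaf argument; it is not taken from the archimedes source, and the sorted dict-key order is also an assumption of this sketch.

```python
from typing import Any, Callable, Optional


def flatten_leaves(
    tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> list[Any]:
    """Hypothetical helper (not archimedes API): return only the leaves."""
    # The predicate is consulted first, so it can stop descent into a container.
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if isinstance(tree, dict):
        # Assumption: dict keys are visited in sorted order.
        return [leaf for k in sorted(tree) for leaf in flatten_leaves(tree[k], is_leaf)]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in flatten_leaves(child, is_leaf)]
    return [tree]  # default rule: anything that is not a container is a leaf


data = {"pos": (1.0, 2.0), "vel": [3.0, 4.0]}
print(flatten_leaves(data))  # [1.0, 2.0, 3.0, 4.0]

# Treat every tuple as an opaque leaf instead of a container.
print(flatten_leaves(data, is_leaf=lambda node: isinstance(node, tuple)))
# [(1.0, 2.0), 3.0, 4.0]
```

With the predicate in place, the tuple under "pos" comes back as a single leaf, while the list under "vel" is still traversed because the predicate returns False for it.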
Notes
In this context, a tree is defined as a nested structure of:
Containers: recognized container types like lists, tuples, and dictionaries
Leaves: arrays, scalars, or custom objects not recognized as containers
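The recursive traversal described at the top of this page, over the container/leaf model above, can be sketched in plain Python. The flatten and unflatten below are hypothetical stand-ins written for illustration only: they handle just dict, list, and tuple, and their "treedef" is a nested tuple rather than the archimedes TreeDef class. Visiting dict keys in sorted order is an assumption that makes the leaf order independent of insertion order; the example on this page ("a" before "b") is consistent with either convention.

```python
from typing import Any

# Hypothetical sketch, NOT the archimedes implementation. The "treedef" here is
# a nested tuple (kind, dict_keys, children), with None marking a leaf slot.


def flatten(tree: Any) -> tuple[list[Any], Any]:
    """Collect leaves depth-first and record the container structure."""
    leaves: list[Any] = []

    def walk(node: Any) -> Any:
        if isinstance(node, dict):
            keys = sorted(node)  # assumption: dict keys visited in sorted order
            return ("dict", keys, [walk(node[k]) for k in keys])
        if isinstance(node, (list, tuple)):
            kind = "tuple" if isinstance(node, tuple) else "list"
            return (kind, None, [walk(child) for child in node])
        leaves.append(node)  # anything that is not a container is a leaf
        return None

    treedef = walk(tree)
    return leaves, treedef


def unflatten(treedef: Any, leaves: list[Any]) -> Any:
    """Rebuild the tree, consuming leaves in the order flatten produced them."""
    remaining = iter(leaves)

    def build(node: Any) -> Any:
        if node is None:  # leaf slot
            return next(remaining)
        kind, keys, children = node
        built = [build(child) for child in children]
        if kind == "dict":
            return dict(zip(keys, built))
        return tuple(built) if kind == "tuple" else built

    return build(treedef)


data = {"b": [1, (2, 3)], "a": 0}
leaves, treedef = flatten(data)
print(leaves)  # [0, 1, 2, 3]
print(unflatten(treedef, leaves) == data)  # True
```

Keeping the structure separate from the leaves is what lets unflatten accept a modified list of leaves (for example, each leaf transformed by some function) and put them back in the same positions.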
When to use:
When you need to extract all leaf values from a nested structure
When you need to convert between structured and flat representations
When converting to or from a flat vector, it is typically more convenient to use ravel() instead of this function.
Examples
>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Simple tree with nested containers
>>> data = {"a": np.array([1.0, 2.0]), "b": {"c": np.array([3.0])}}
>>>
>>> # Flatten the tree
>>> leaves, treedef = arc.tree.flatten(data)
>>> print(leaves)
[array([1., 2.]), array([3.])]
>>>
>>> # Use treedef to reconstruct the tree
>>> reconstructed = arc.tree.unflatten(treedef, leaves)
>>> print(reconstructed)
{'a': array([1., 2.]), 'b': {'c': array([3.])}}
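The note about ravel() refers to packing all leaves into a single flat vector. As a rough illustration of the extra bookkeeping that involves on top of flatten, the sketch below does it by hand with NumPy, starting from the same leaves as the example above: it records each leaf's shape, concatenates the raveled leaves, and later splits and reshapes to recover them. This is not how ravel() is implemented or what it returns; consult its own documentation for that.

```python
import numpy as np

# Leaves as produced by flattening the example tree above.
leaves = [np.array([1.0, 2.0]), np.array([3.0])]

# Pack: remember each leaf's shape, then concatenate the raveled leaves.
shapes = [leaf.shape for leaf in leaves]
flat = np.concatenate([np.ravel(leaf) for leaf in leaves])
print(flat)  # [1. 2. 3.]

# Unpack: split at the cumulative sizes and restore each leaf's shape.
sizes = [int(np.prod(shape)) for shape in shapes]
pieces = np.split(flat, np.cumsum(sizes)[:-1])
restored = [piece.reshape(shape) for piece, shape in zip(pieces, shapes)]
```

Note that np.concatenate promotes leaves of mixed dtypes to a common dtype, so a hand-rolled round trip like this one may not preserve integer leaves exactly.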