archimedes.tree.flatten

archimedes.tree.flatten(
x: Tree,
is_leaf: Callable[[Any], bool] | None = None,
) → tuple[list[ArrayLike], TreeDef]

Flatten a tree into a list of leaves and a treedef.

This function recursively traverses the tree, collecting the leaf values into a list and recording how the containers around them are arranged in the treedef. This is useful when you need to apply an operation to every leaf value or convert structured data to a flat representation.
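The traversal can be pictured with the following pure-Python sketch. It is not the archimedes implementation: `flatten_sketch` and the `LEAF` marker are hypothetical names, only lists, tuples and dicts are treated as containers, the treedef is a plain nested tuple rather than a `TreeDef` object, and dict keys are visited in sorted order (the convention JAX uses; the order archimedes uses is not stated on this page).

```python
from typing import Any, Callable, Optional

LEAF = "*"  # placeholder recorded in the treedef wherever a leaf was found


def flatten_sketch(x: Any, is_leaf: Optional[Callable[[Any], bool]] = None):
    """Hypothetical flatten: return (leaves, treedef) for lists, tuples and dicts."""
    leaves: list = []

    def walk(node: Any) -> Any:
        # A user-supplied predicate is checked before any container test,
        # so it can stop the recursion at any level.
        if is_leaf is not None and is_leaf(node):
            leaves.append(node)
            return LEAF
        if isinstance(node, (list, tuple)):
            # Record the container kind and recurse into the children in order.
            kind = "tuple" if isinstance(node, tuple) else "list"
            return (kind, [walk(child) for child in node])
        if isinstance(node, dict):
            keys = sorted(node)  # deterministic key order (assumption)
            return ("dict", keys, [walk(node[k]) for k in keys])
        # Anything that is not a recognized container is a leaf.
        leaves.append(node)
        return LEAF

    treedef = walk(x)
    return leaves, treedef


data = {"b": {"c": 3.0}, "a": [1.0, 2.0]}
leaves, treedef = flatten_sketch(data)
print(leaves)   # [1.0, 2.0, 3.0]
print(treedef)  # ('dict', ['a', 'b'], [('list', ['*', '*']), ('dict', ['c'], ['*'])])
```

The leaves come out as a flat list in traversal order, while the treedef keeps only the shape, with `LEAF` marking each slot a leaf came from.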

Parameters:
  • x (Tree) – A tree to be flattened. Here, a tree is a nested structure of containers (lists, tuples, dicts, etc.) and leaves (arrays, scalars, and objects not registered as trees).

  • is_leaf (callable, optional) – A function that takes a node of the tree (a subtree) as input and returns a boolean indicating whether it should be treated as a leaf rather than traversed further. If not provided, the default leaf types are used.
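To illustrate what `is_leaf` changes, here is a hypothetical helper, `leaves_of`, that collects leaves depth-first from lists, tuples and dicts and stops descending wherever the predicate returns True. Both `leaves_of` and the `is_point` predicate are illustrative assumptions, not part of archimedes.

```python
from typing import Any, Callable, Optional


def leaves_of(x: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> list:
    """Hypothetical helper: collect leaves depth-first, stopping where is_leaf is True."""
    if is_leaf is not None and is_leaf(x):
        return [x]
    if isinstance(x, (list, tuple)):
        return [leaf for child in x for leaf in leaves_of(child, is_leaf)]
    if isinstance(x, dict):
        return [leaf for k in sorted(x) for leaf in leaves_of(x[k], is_leaf)]
    return [x]  # not a container: a leaf under the default rules


def is_point(node: Any) -> bool:
    # Hypothetical predicate: treat (x, y) pairs of floats as single leaves.
    return (
        isinstance(node, tuple)
        and len(node) == 2
        and all(isinstance(v, float) for v in node)
    )


path = {"start": (0.0, 1.0), "stops": [(2.0, 3.0), (4.0, 5.0)]}
print(leaves_of(path))                    # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
print(leaves_of(path, is_leaf=is_point))  # [(0.0, 1.0), (2.0, 3.0), (4.0, 5.0)]
```

Under the default rules each coordinate pair is split into two scalars; with `is_point`, each pair survives as a single leaf.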

Returns:

  • leaves (list) – A list of all leaf values from the tree.

  • treedef (TreeDef) – A structure definition that can be used to reconstruct the original tree using unflatten.
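The two return values are meant to be used together: as long as the number and order of the leaves are preserved, the treedef can rebuild the original container structure around new leaf values. The sketch below assumes a plain nested-tuple treedef and hypothetical `flatten_sketch`/`unflatten_sketch` helpers; with the library itself you would call `flatten` and `unflatten` instead.

```python
from typing import Any

LEAF = "*"  # placeholder for a leaf position in the treedef


def flatten_sketch(x: Any):
    """Hypothetical flatten: return (leaves, nested-tuple treedef)."""
    leaves: list = []

    def walk(node: Any) -> Any:
        if isinstance(node, (list, tuple)):
            kind = "tuple" if isinstance(node, tuple) else "list"
            return (kind, [walk(child) for child in node])
        if isinstance(node, dict):
            keys = sorted(node)  # deterministic key order (assumption)
            return ("dict", keys, [walk(node[k]) for k in keys])
        leaves.append(node)
        return LEAF

    return leaves, walk(x)


def unflatten_sketch(treedef: Any, leaves: list) -> Any:
    """Hypothetical unflatten: refill the recorded structure with `leaves`."""
    it = iter(leaves)  # leaves are consumed in the order flatten produced them

    def build(d: Any) -> Any:
        if d == LEAF:
            return next(it)
        if d[0] == "dict":
            _, keys, children = d
            return {k: build(c) for k, c in zip(keys, children)}
        kind, children = d
        items = [build(c) for c in children]
        return tuple(items) if kind == "tuple" else items

    return build(treedef)


data = {"a": [1.0, 2.0], "b": {"c": (3.0, 4.0)}}
leaves, treedef = flatten_sketch(data)
assert unflatten_sketch(treedef, leaves) == data  # lossless round trip

# Apply an operation to every leaf, then restore the original structure.
scaled = unflatten_sketch(treedef, [2 * v for v in leaves])
print(scaled)  # {'a': [2.0, 4.0], 'b': {'c': (6.0, 8.0)}}
```

Because the treedef records container kinds, the tuple under `"c"` comes back as a tuple rather than a list.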

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> data = {"a": np.array([1.0, 2.0]), "b": {"c": np.array([3.0])}}
>>> leaves, treedef = arc.tree.flatten(data)
>>> reconstructed = arc.tree.unflatten(treedef, leaves)

See also

  • tree_unflatten – Reconstruct a tree from leaves and a treedef.

  • tree_leaves – Extract just the leaf values from a tree.

  • tree_structure – Extract just the structure from a tree.