archimedes.tree.ravel

archimedes.tree.ravel(tree: Tree) → tuple[ArrayLike, HashablePartial]

Flatten a tree to a single 1D array.

This function ravels each leaf value (which must be an array or a scalar) to 1D and concatenates the results into a single 1D array. It also returns a function that reconstructs the original tree structure from such an array.
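The flatten-and-concatenate mechanism described above can be sketched in plain NumPy. This is a minimal sketch, not the archimedes implementation: the names `ravel_sketch`, `_leaves` and `_rebuild` are hypothetical, only plain dict, list and tuple containers are handled, and visiting dict keys in sorted order is an assumption borrowed from the JAX pytree convention.

```python
import numpy as np


def _leaves(tree):
    """Collect leaves depth-first as arrays (hypothetical helper)."""
    if isinstance(tree, dict):
        # Assumption: dict keys are visited in sorted order, as in JAX pytrees.
        return [leaf for key in sorted(tree) for leaf in _leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in _leaves(child)]
    return [np.asarray(tree)]


def _rebuild(tree, pieces):
    """Recreate the containers of `tree`, drawing leaves from the iterator `pieces`."""
    if isinstance(tree, dict):
        return {key: _rebuild(tree[key], pieces) for key in sorted(tree)}
    if isinstance(tree, (list, tuple)):
        children = [_rebuild(child, pieces) for child in tree]
        return tuple(children) if isinstance(tree, tuple) else children
    return next(pieces)


def ravel_sketch(tree):
    """Hypothetical stand-in for ravel: returns (flat_array, unravel)."""
    leaves = _leaves(tree)
    shapes = [leaf.shape for leaf in leaves]
    sizes = [leaf.size for leaf in leaves]
    if leaves:
        # np.concatenate promotes the leaf dtypes to a common dtype.
        flat = np.concatenate([leaf.ravel() for leaf in leaves])
    else:
        # Empty tree: fall back to a 0-length float32 vector.
        flat = np.zeros(0, dtype=np.float32)

    def unravel(vec):
        vec = np.asarray(vec)
        if vec.shape != flat.shape:
            raise ValueError(f"expected shape {flat.shape}, got {vec.shape}")
        out, offset = [], 0
        for shape, size in zip(shapes, sizes):
            # Take the next `size` entries and restore the leaf's shape.
            out.append(vec[offset:offset + size].reshape(shape))
            offset += size
        return _rebuild(tree, iter(out))

    return flat, unravel
```

Applied to the `state` dictionary from the example, this sketch yields the vector `[0., 1., 2., 3., 4., 5.]`, and `unravel(flat * 2)` returns a dict whose `"vel"` entry is `[6., 8., 10.]`.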

Parameters:

tree (Any) – A tree of arrays and scalars to flatten. A tree is a nested structure of containers (lists, tuples, dicts) and leaves (arrays or scalars).

Returns:

  • flat_array (ndarray) – A 1D array containing all flattened leaf values concatenated together. The dtype is determined by promoting the dtypes of all leaf values. If the input tree is empty, an empty 1D array of dtype np.float32 is returned. The array type follows the leaves: a tree of symbolic leaves produces a symbolic flat array.
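To illustrate the dtype rule for numeric leaves, here is a sketch using NumPy's own promotion in `np.concatenate`. It assumes archimedes follows NumPy's promotion rules for numeric arrays, which is not confirmed here; `concat_leaves` is a hypothetical helper, and the float32 fallback for an empty tree comes from the description above.

```python
import numpy as np


def concat_leaves(leaves):
    """Hypothetical helper: concatenate raveled leaves into one 1D array."""
    if not leaves:
        # Empty tree: the documented fallback is a 0-length float32 vector.
        return np.zeros(0, dtype=np.float32)
    # np.concatenate promotes to a common dtype (assumed to mirror archimedes).
    return np.concatenate([np.asarray(leaf).ravel() for leaf in leaves])


mixed = concat_leaves([np.array([1, 2], dtype=np.int32), np.array([0.5], dtype=np.float32)])
same = concat_leaves([np.ones(2, dtype=np.float32), np.zeros(3, dtype=np.float32)])
empty = concat_leaves([])
print(mixed.dtype, same.dtype, empty.dtype)  # float64 float32 float32
```

Mixing int32 and float32 promotes to float64 because float32 cannot represent every int32 value exactly, while leaves that already share a dtype keep it.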

  • unravel (callable) – A function that takes a 1D array of the same length as flat_array and returns a tree with the same structure as the input tree, with the values from the 1D array reshaped to match the original leaf shapes.
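The reshaping step inside unravel reduces to offset arithmetic over the leaf sizes. The sketch below shows that step on its own, including the length check implied by "the same length as flat_array". `split_and_reshape` is a hypothetical helper, and raising `ValueError` on a length mismatch is an assumption about how such a check might behave, not documented archimedes behaviour.

```python
import math

import numpy as np


def split_and_reshape(flat, shapes):
    """Hypothetical helper: cut a 1D vector into pieces with the given leaf shapes."""
    flat = np.asarray(flat)
    sizes = [math.prod(shape) for shape in shapes]  # math.prod(()) == 1 for scalars
    if flat.ndim != 1 or flat.size != sum(sizes):
        # Assumed behaviour: reject vectors whose length differs from flat_array.
        raise ValueError(f"expected a 1D array of length {sum(sizes)}, got shape {flat.shape}")
    # Boundaries between consecutive leaves, e.g. sizes [3, 4, 1] -> split points [3, 7].
    split_points = np.cumsum(sizes)[:-1]
    pieces = np.split(flat, split_points)
    return [piece.reshape(shape) for piece, shape in zip(pieces, shapes)]


parts = split_and_reshape(np.arange(8.0), [(3,), (2, 2), ()])
print([p.shape for p in parts])  # [(3,), (2, 2), ()]
```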

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> state = {"pos": np.array([0.0, 1.0, 2.0]), "vel": np.array([3.0, 4.0, 5.0])}
>>> flat_state, unravel = arc.tree.ravel(state)
>>> new_state = unravel(flat_state * 2)

See also

tree_flatten

Flatten a tree into a list of leaves and a treedef