archimedes.tree.ravel
- archimedes.tree.ravel(pytree: PyTree) → tuple[ArrayLike, HashablePartial]
Flatten a pytree to a single 1D array.
This function flattens each leaf value of a pytree (leaves must be arrays or scalars) and concatenates the results into a single 1D array. It also returns a function that reconstructs the original structure from such an array.
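The mechanism can be sketched in plain NumPy. This is a minimal sketch, not the archimedes implementation: the helper names `ravel_sketch`, `_leaves`, and `_rebuild` are hypothetical, only dicts, lists, and tuples are treated as containers, and dict keys are visited in sorted order (the convention JAX uses). Whether archimedes uses the same traversal order is an assumption.

```python
import numpy as np


def _leaves(tree):
    """Collect leaves depth-first as arrays; dict keys are visited in sorted order (assumed)."""
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in _leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in _leaves(item)]
    return [np.asarray(tree)]  # anything that is not a container is a leaf


def _rebuild(tree, new_leaves):
    """Rebuild the containers of `tree`, taking leaves in order from the iterator `new_leaves`."""
    if isinstance(tree, dict):
        # Consume leaves in sorted-key order, then restore the original key order.
        values = {key: _rebuild(tree[key], new_leaves) for key in sorted(tree)}
        return {key: values[key] for key in tree}
    if isinstance(tree, (list, tuple)):
        return type(tree)(_rebuild(item, new_leaves) for item in tree)
    return next(new_leaves)


def ravel_sketch(tree):
    """Hypothetical NumPy-only analogue of arc.tree.ravel."""
    leaves = _leaves(tree)
    shapes = [leaf.shape for leaf in leaves]
    sizes = [leaf.size for leaf in leaves]
    if leaves:
        flat = np.concatenate([leaf.ravel() for leaf in leaves])  # dtypes are promoted
    else:
        flat = np.zeros(0, dtype=np.float32)  # documented empty-pytree result

    def unravel(vec):
        vec = np.asarray(vec)
        if vec.shape != flat.shape:  # unravel is tied to one structure and one length
            raise ValueError(f"expected an array of shape {flat.shape}, got {vec.shape}")
        chunks = np.split(vec, np.cumsum(sizes)[:-1])  # one chunk per leaf
        return _rebuild(tree, (c.reshape(s) for c, s in zip(chunks, shapes)))

    return flat, unravel


state = {"pos": np.array([0.0, 1.0, 2.0]), "vel": np.array([3.0, 4.0, 5.0])}
flat_state, unravel = ravel_sketch(state)
print(flat_state)  # [0. 1. 2. 3. 4. 5.]
new_state = unravel(flat_state * 2)
```

The closure captures the leaf shapes and sizes, which is why the returned unravel function only works for arrays matching the structure it was created from.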
- Parameters:
pytree (Any) – A pytree of arrays and scalars to flatten. A pytree is a nested structure of containers (lists, tuples, dicts) and leaves (arrays or scalars).
- Returns:
flat_array (ndarray) – A 1D array containing all flattened leaf values concatenated together. The dtype is determined by promoting the dtypes of all leaf values. If the input pytree is empty, a 1D empty array of dtype np.float32 is returned.
unravel (callable) – A function that takes a 1D array of the same length as flat_array and returns a pytree with the same structure as the input pytree, with the values from the 1D array reshaped to match the original leaf shapes.
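Assuming the promotion rule above follows NumPy's type promotion (an assumption; archimedes may apply its own rules), the resulting dtype can be predicted with `np.result_type`. The helper `flat_dtype` below is hypothetical and exists only to illustrate the rule, including the documented float32 result for an empty pytree.

```python
import numpy as np


def flat_dtype(leaves):
    """Predict the dtype of the concatenated vector (hypothetical helper).

    Leaves are raveled first so 0-d leaves promote the same way they would
    inside np.concatenate; an empty pytree gives float32, as documented.
    """
    arrays = [np.asarray(leaf).ravel() for leaf in leaves]
    if not arrays:
        return np.dtype(np.float32)
    return np.result_type(*arrays)


leaves = [np.array([1, 2], dtype=np.int32), np.array([0.5], dtype=np.float32)]
print(flat_dtype(leaves))  # float64
```

int32 and float32 promote to float64 because float32 cannot represent every int32 value exactly.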
Notes
When to use:
- When you need to convert structured data to a single flat vector for optimization, ODE solving, or other algorithms that work with flat arrays
- As a more powerful alternative to flatten() when the leaf values themselves need to be flattened
- When interfacing with external libraries that require flat arrays
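The difference between leaf-level flattening and full raveling can be seen with plain NumPy. In this sketch a sorted-key list comprehension is a hypothetical stand-in for flatten(); it does not reproduce flatten()'s actual leaf ordering or its treedef return value.

```python
import numpy as np

params = {"w": np.ones((2, 3)), "b": np.zeros(3)}

# Leaf-level flattening (stand-in for flatten()): the containers are removed,
# but each leaf keeps its own shape.
leaves = [params[key] for key in sorted(params)]
print([leaf.shape for leaf in leaves])  # [(3,), (2, 3)]

# Full raveling: each leaf is also flattened and the pieces are concatenated,
# giving the single vector that generic optimizers and ODE solvers expect.
vector = np.concatenate([leaf.ravel() for leaf in leaves])
print(vector.shape)  # (9,)

# Whole-state vector operations now apply directly, e.g. a norm over all parameters.
norm = np.linalg.norm(vector)
```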
The resulting unravel function is specific to the structure of the input pytree and expects a 1D array with exactly as many elements as flat_array.
Examples
>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Create a structured state
>>> state = {"pos": np.array([0.0, 1.0, 2.0]), "vel": np.array([3.0, 4.0, 5.0])}
>>>
>>> # Flatten to a single vector
>>> flat_state, unravel = arc.tree.ravel(state)
>>> print(flat_state)
[0. 1. 2. 3. 4. 5.]
>>>
>>> # Modify the flat array
>>> flat_state = flat_state * 2
>>>
>>> # Reconstruct the original structure with modified values
>>> new_state = unravel(flat_state)
>>> print(new_state)
{'pos': array([0., 2., 4.]), 'vel': array([ 6.,  8., 10.])}
>>>
>>> # Use with ODE solvers that expect flat vectors
>>> @arc.compile
... def ode_rhs(t, state_flat):
...     # Unflatten the state vector to our structured state
...     state = unravel(state_flat)
...
...     # Compute state derivatives using structured data
...     pos_dot = state["vel"]
...     vel_dot = -state["pos"]  # Simple harmonic oscillator
...
...     # Return flattened derivatives
...     state_deriv = {"pos": pos_dot, "vel": vel_dot}
...     state_deriv_flat, _ = arc.tree.ravel(state_deriv)
...     return state_deriv_flat
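The pattern in the ODE example can be exercised with NumPy alone. In this sketch, `pack` and `unpack` are hypothetical hand-written stand-ins for ravel and unravel on this fixed two-field state, and the forward-Euler loop is a simple stand-in for a real ODE solver; none of these are part of archimedes.

```python
import numpy as np


def pack(state):
    """Stand-in for arc.tree.ravel on this fixed {"pos", "vel"} structure."""
    return np.concatenate([state["pos"], state["vel"]])


def unpack(flat):
    """Stand-in for the unravel function: split the 6-vector back into fields."""
    return {"pos": flat[:3], "vel": flat[3:]}


def ode_rhs(t, state_flat):
    state = unpack(state_flat)
    deriv = {"pos": state["vel"], "vel": -state["pos"]}  # simple harmonic oscillator
    return pack(deriv)


# Forward Euler with a small step; the integrator only ever sees flat vectors.
y = pack({"pos": np.array([0.0, 1.0, 2.0]), "vel": np.array([3.0, 4.0, 5.0])})
dt, n_steps = 1e-3, 1000
for k in range(n_steps):
    y = y + dt * ode_rhs(k * dt, y)

final = unpack(y)  # structured view of the state at t = 1.0
```

The integrator never needs to know about the structured state; only the right-hand side converts between the flat and structured forms.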
See also
flatten
Flatten a pytree into a list of leaves and a treedef