archimedes.tree.unflatten

archimedes.tree.unflatten(treedef: PyTreeDef, xs: list[ArrayLike]) → PyTree

Reconstruct a pytree from a list of leaves and a treedef.

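This function is the inverse of flatten(). It takes a list of leaf values and a tree definition, and builds a pytree with the structure described by the tree definition, filling its leaf positions from the list in the same order that flatten() produces leaves.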
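To make the inverse relationship concrete, here is a minimal pure-Python sketch of a flatten/unflatten pair for nested dicts, lists, and tuples. It is an illustration under simplifying assumptions, not the archimedes implementation: toy_flatten, toy_unflatten, and the placeholder-based treedef are hypothetical names, dict keys are assumed to be sortable, and every value that is not a dict, list, or tuple is treated as a leaf.

```python
from typing import Any

# Hypothetical placeholder marking a leaf slot inside the toy treedef.
_LEAF = object()


def toy_flatten(tree: Any) -> tuple[list[Any], Any]:
    """Split a nested dict/list/tuple into (leaves, treedef)."""
    leaves: list[Any] = []

    def walk(node: Any) -> Any:
        if isinstance(node, dict):
            # Visit keys in sorted order so the leaf order is deterministic.
            return {k: walk(node[k]) for k in sorted(node)}
        if isinstance(node, list):
            return [walk(child) for child in node]
        if isinstance(node, tuple):
            return tuple(walk(child) for child in node)
        leaves.append(node)  # anything else is a leaf
        return _LEAF

    treedef = walk(tree)
    return leaves, treedef


def _count_leaves(treedef: Any) -> int:
    """Count the leaf placeholders in a toy treedef."""
    if isinstance(treedef, dict):
        return sum(_count_leaves(v) for v in treedef.values())
    if isinstance(treedef, (list, tuple)):
        return sum(_count_leaves(child) for child in treedef)
    return 1


def toy_unflatten(treedef: Any, xs: list[Any]) -> Any:
    """Rebuild a tree by filling the treedef's leaf slots from xs, in order."""
    expected = _count_leaves(treedef)
    if len(xs) != expected:
        raise ValueError(f"expected {expected} leaves, got {len(xs)}")
    it = iter(xs)

    def build(node: Any) -> Any:
        if isinstance(node, dict):
            return {k: build(v) for k, v in node.items()}
        if isinstance(node, list):
            return [build(child) for child in node]
        if isinstance(node, tuple):
            return tuple(build(child) for child in node)
        return next(it)  # leaf slot: take the next value from xs

    return build(treedef)


data = {"positions": [1.0, 2.0], "velocities": (3.0, 4.0)}
leaves, treedef = toy_flatten(data)
print(leaves)  # [1.0, 2.0, 3.0, 4.0]
print(toy_unflatten(treedef, [2 * x for x in leaves]))
# {'positions': [2.0, 4.0], 'velocities': (6.0, 8.0)}
```

The leaf-count check plays the same role as the ValueError listed under Raises; the real function may differ in how it orders dictionary keys and which container types it recognizes.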

Parameters:
  • treedef (PyTreeDef) – A tree definition, typically produced by flatten() or structure().

  • xs (list[ArrayLike]) – A list of leaf values to be placed in the reconstructed pytree. The length must match the number of leaves in treedef.

Returns:

pytree – The reconstructed pytree with the same structure as defined by treedef and with leaf values from xs.

Return type:

PyTree

Notes

When converting to or from a flat vector, it is typically more convenient to use ravel() than this function.
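This page describes ravel() only as flattening a pytree into a single 1D array. The following hypothetical NumPy sketch shows that idea for a flat dict of arrays: the leaves are concatenated into one vector, and a closure records the shapes needed to split the vector back apart. toy_ravel and unravel are illustrative names, not the archimedes API, whose exact return values are not documented here.

```python
import numpy as np


def toy_ravel(tree):
    """Concatenate the arrays of a flat dict into one 1D vector.

    Returns the vector plus a function that maps a vector of the same
    length back to a dict with the original keys and shapes.
    """
    keys = sorted(tree)  # deterministic leaf order
    arrays = [np.asarray(tree[k]) for k in keys]
    shapes = [a.shape for a in arrays]
    # Indices where one leaf ends and the next begins in the flat vector.
    split_points = np.cumsum([a.size for a in arrays])[:-1]
    flat = np.concatenate([a.ravel() for a in arrays])

    def unravel(vec):
        pieces = np.split(np.asarray(vec), split_points)
        return {k: p.reshape(s) for k, p, s in zip(keys, pieces, shapes)}

    return flat, unravel


data = {"positions": np.array([1.0, 2.0]), "velocities": np.array([3.0, 4.0])}
x, unravel = toy_ravel(data)
print(x)  # [1. 2. 3. 4.]
restored = unravel(2 * x)
print(restored["velocities"])  # [6. 8.]
```

With unflatten() alone, a flat vector such as x would first have to be split and reshaped into per-leaf arrays by hand, which is the bookkeeping a ravel-style helper hides.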

Raises:

ValueError – If the number of leaves in xs doesn't match the expected number in treedef.

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Original pytree
>>> data = {"positions": np.array([1.0, 2.0]), "velocities": np.array([3.0, 4.0])}
>>>
>>> # Flatten the pytree
>>> leaves, treedef = arc.tree.flatten(data)
>>>
>>> # Transform the leaves (e.g., multiply by 2)
>>> new_leaves = [leaf * 2 for leaf in leaves]
>>>
>>> # Reconstruct the pytree with the new leaves
>>> new_data = arc.tree.unflatten(treedef, new_leaves)
>>> print(new_data)
{'positions': array([2., 4.]), 'velocities': array([6., 8.])}

See also

flatten

Flatten a pytree into a list of leaves and a treedef

structure

Extract just the structure from a pytree

ravel

Flatten a pytree into a single 1D array