archimedes.tree.unflatten
- archimedes.tree.unflatten(treedef: PyTreeDef, xs: list[ArrayLike]) → PyTree
Reconstruct a pytree from a list of leaves and a treedef.
This function is the inverse of flatten(). It takes a list of leaf values and a tree definition, and reconstructs the original pytree structure.
- Parameters:
treedef (PyTreeDef) – A tree definition, typically produced by flatten() or structure().
xs (list[ArrayLike]) – A list of leaf values to be placed in the reconstructed pytree. The length must match the number of leaves in treedef.
- Returns:
pytree – The reconstructed pytree, with the structure defined by treedef and the leaf values from xs.
- Return type:
PyTree
Notes
When converting to or from a flat vector, it is typically more convenient to use ravel() instead of this function.
- Raises:
ValueError – If the number of leaves in xs doesn't match the number expected by treedef.
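The leaf-count check described under Raises can be pictured with a small toy model. This is a sketch only and is not the archimedes implementation: the helpers toy_flatten and toy_unflatten are hypothetical, they handle only a flat dict, and the "treedef" here is simply the sorted list of keys.

```python
# Toy illustration only: NOT the archimedes implementation.
# Here a "treedef" is just the sorted list of dict keys (a hypothetical stand-in).

def toy_flatten(tree: dict) -> tuple[list, list[str]]:
    """Split a flat dict into (leaves, keys), with keys sorted for a stable order."""
    keys = sorted(tree)
    return [tree[k] for k in keys], keys


def toy_unflatten(keys: list[str], xs: list) -> dict:
    """Rebuild the dict, rejecting a leaf list of the wrong length."""
    if len(xs) != len(keys):
        raise ValueError(f"expected {len(keys)} leaves, got {len(xs)}")
    return dict(zip(keys, xs))


leaves, keys = toy_flatten({"positions": [1.0, 2.0], "velocities": [3.0, 4.0]})
restored = toy_unflatten(keys, leaves)  # round trip recovers the original dict

try:
    toy_unflatten(keys, leaves[:1])  # one leaf too few
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

In this toy model the structure records how many leaves it expects, so a mismatch can be rejected before any value is placed. The wording of the error message is invented for the sketch.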
Examples
>>> import archimedes as arc
>>> import numpy as np
>>>
>>> # Original pytree
>>> data = {"positions": np.array([1.0, 2.0]), "velocities": np.array([3.0, 4.0])}
>>>
>>> # Flatten the pytree
>>> leaves, treedef = arc.tree.flatten(data)
>>>
>>> # Transform the leaves (e.g., multiply by 2)
>>> new_leaves = [leaf * 2 for leaf in leaves]
>>>
>>> # Reconstruct the pytree with the new leaves
>>> new_data = arc.tree.unflatten(treedef, new_leaves)
>>> print(new_data)
{'positions': array([2., 4.]), 'velocities': array([6., 8.])}
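The Notes recommend ravel() for flat-vector conversions. To show the bookkeeping that a flat-vector round trip involves when done by hand on a list of leaves, here is a minimal sketch using only NumPy. It does not call any archimedes function, and the variable names are illustrative assumptions.

```python
import numpy as np

# Leaves as they might come back from a flatten step (values from the example above).
leaves = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]

# Pack: remember each leaf's shape and size, then concatenate the flattened leaves.
shapes = [leaf.shape for leaf in leaves]
sizes = [leaf.size for leaf in leaves]
flat = np.concatenate([leaf.ravel() for leaf in leaves])

# Unpack: split the (doubled) vector at the cumulative sizes and restore each shape.
split_points = np.cumsum(sizes)[:-1]
new_leaves = [
    chunk.reshape(shape)
    for chunk, shape in zip(np.split(flat * 2, split_points), shapes)
]
```

The resulting new_leaves list has one entry per original leaf, so it could be passed as the xs argument to unflatten() together with the matching treedef.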