Working with PyTrees
PyTrees in Archimedes provide a way to work with structured data in numerical algorithms that typically expect flat vectors. This page explains what PyTrees are, how to use them in your code, and how to create custom structures that work seamlessly with Archimedes functions.
The Archimedes concept of a PyTree is borrowed from JAX, while the `pytree_node` decorator produces composable classes that work similarly to PyTorch Modules. You may also want to scan the corresponding JAX and PyTorch documentation pages to get ideas about what these structures are and how they are used.
What is a PyTree?
A PyTree is any nested structure of “containers” (dictionaries, lists, tuples, NamedTuples) and “leaves” (arrays or scalars). Archimedes automatically recognizes built-in Python containers as PyTree nodes, while arrays and scalars are treated as leaves.
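To make the container/leaf distinction concrete, here is a minimal pure-Python sketch that walks a nested structure and collects its leaves. It is not Archimedes' implementation: the helper name `tree_leaves` is hypothetical, and visiting dictionary keys in sorted order is an assumption borrowed from JAX's convention.

```python
from typing import Any, List


def tree_leaves(tree: Any) -> List[Any]:
    """Collect the leaves of a nested dict/list/tuple structure, depth first."""
    if isinstance(tree, dict):
        # Assumption: dictionary children are visited in sorted key order (JAX's convention)
        leaves = []
        for key in sorted(tree):
            leaves.extend(tree_leaves(tree[key]))
        return leaves
    if isinstance(tree, (list, tuple)):  # NamedTuples are tuples, so they are covered too
        leaves = []
        for child in tree:
            leaves.extend(tree_leaves(child))
        return leaves
    # Anything that is not a recognized container is treated as a leaf
    return [tree]


example = ([1.0, 2.0], [3.0, 4.0], {"value": 5.0})
print(tree_leaves(example))  # [1.0, 2.0, 3.0, 4.0, 5.0]
```

In this sketch a Python list is a container, so each of its elements is a separate leaf; a NumPy array, by contrast, would be a single leaf.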
Examples of valid PyTrees:
```python
import numpy as np

# A list of scalars
[1.0, 2.0, 3.0]
# A dictionary of arrays
{"position": np.array([0.0, 1.0, 2.0]), "velocity": np.array([3.0, 4.0, 5.0])}
# A nested structure with multiple container types
(np.array([1.0, 2.0]), [3.0, 4.0], {"value": 5.0})
```
Flattening and Unflattening
The primary operations with PyTrees are:
Flattening: Converting a structured PyTree into a flat vector
Unflattening: Restoring a flat vector to its original structure
```python
import archimedes as arc
import numpy as np

# Create a structured state
state = {"pos": np.array([0.0, 1.0, 2.0]), "vel": np.array([3.0, 4.0, 5.0])}

# Flatten to a vector
flat_state, unravel = arc.tree.ravel(state)
print(flat_state)  # array([0., 1., 2., 3., 4., 5.])

# Restore the original structure
restored_state = unravel(flat_state)
print(restored_state)  # {'pos': array([0., 1., 2.]), 'vel': array([3., 4., 5.])}
```
This pattern is essential when working with:
ODE solvers that require state vectors
Optimization algorithms operating on parameter vectors
Root-finding methods that expect flat systems
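The sketch below models the ravel/unravel round trip in pure Python for a dictionary of float lists, then uses it the way an ODE solver would: a hypothetical solver step `euler_step` only understands flat lists, so the structured dynamics are wrapped to accept and return flat vectors. The names `ravel_dict` and `euler_step` are hypothetical, keys are visited in sorted order so that flattening and unflattening agree, and none of this describes how `arc.tree.ravel` is implemented.

```python
from typing import Callable, Dict, List, Tuple

State = Dict[str, List[float]]


def ravel_dict(tree: State) -> Tuple[List[float], Callable[[List[float]], State]]:
    """Flatten a dict of float lists; return the flat list and an inverse function."""
    keys = sorted(tree)  # fixed ordering so ravel and unravel agree
    sizes = [len(tree[k]) for k in keys]
    flat = [x for k in keys for x in tree[k]]

    def unravel(vec: List[float]) -> State:
        if len(vec) != sum(sizes):
            raise ValueError("flat vector length does not match the original structure")
        out, start = {}, 0
        for k, n in zip(keys, sizes):
            out[k] = list(vec[start:start + n])
            start += n
        return out

    return flat, unravel


def euler_step(f: Callable[[List[float]], List[float]], x: List[float], dt: float) -> List[float]:
    """A 'solver' that only knows about flat vectors."""
    dx = f(x)
    return [xi + dt * dxi for xi, dxi in zip(x, dx)]


state = {"pos": [0.0, 1.0], "vel": [1.0, -1.0]}
x0, unravel = ravel_dict(state)


def flat_rhs(x: List[float]) -> List[float]:
    s = unravel(x)                                # structured view of the flat vector
    deriv = {"pos": s["vel"], "vel": [0.0, 0.0]}  # constant-velocity motion
    dx, _ = ravel_dict(deriv)                     # back to a flat vector
    return dx


x1 = euler_step(flat_rhs, x0, dt=0.5)
print(unravel(x1))  # {'pos': [0.5, 0.5], 'vel': [1.0, -1.0]}
```

The length check in `unravel` mirrors the structure-consistency requirement: a flat vector can only be restored if it has exactly as many entries as the original structure had leaves.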
Custom PyTree Types
For more complex models, create custom classes that work as PyTrees using the `struct.pytree_node` decorator:
```python
import numpy as np
import archimedes as arc
from archimedes import struct


@struct.pytree_node
class VehicleState:
    position: np.ndarray  # [x, y, z]
    velocity: np.ndarray  # [vx, vy, vz]
    attitude: np.ndarray  # quaternion [qx, qy, qz, qw]
    angular_velocity: np.ndarray  # [wx, wy, wz]


# Create an instance
state = VehicleState(
    position=np.zeros(3),
    velocity=np.zeros(3),
    attitude=np.array([0.0, 0.0, 0.0, 1.0]),  # float identity quaternion
    angular_velocity=np.zeros(3),
)

# Flatten to a vector
flat_state, unravel = arc.tree.ravel(state)


# Use in compiled functions
@arc.compile
def dynamics(state, control, dt=0.1):
    # Access fields naturally
    new_position = state.position + dt * state.velocity
    # ...other calculations (e.g. using `control`) go here...
    return VehicleState(
        position=new_position,
        # Placeholders: the remaining fields are passed through unchanged
        velocity=state.velocity,
        attitude=state.attitude,
        angular_velocity=state.angular_velocity,
    )
```
The `struct.pytree_node` decorator automatically registers your class with Archimedes' PyTree system, combining the benefits of Python dataclasses with PyTree functionality.
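Conceptually, registering a dataclass as a PyTree node means telling the tree utilities how to take an instance apart into its field values and how to rebuild it from them. The pure-Python sketch below illustrates that idea with the standard `dataclasses` module. The helpers `flatten_node` and `unflatten_node` are hypothetical and do not reflect how `struct.pytree_node` is actually implemented.

```python
import dataclasses
from typing import Any, List, Tuple, Type


@dataclasses.dataclass(frozen=True)
class Point:
    x: float
    y: float


def flatten_node(node: Any) -> Tuple[List[Any], Type]:
    """Split a dataclass instance into its field values (children) and its type."""
    children = [getattr(node, f.name) for f in dataclasses.fields(node)]
    return children, type(node)


def unflatten_node(cls: Type, children: List[Any]) -> Any:
    """Rebuild an instance from children given in field-declaration order."""
    names = [f.name for f in dataclasses.fields(cls)]
    return cls(**dict(zip(names, children)))


children, cls = flatten_node(Point(x=1.0, y=2.0))
print(children)                         # [1.0, 2.0]
print(unflatten_node(cls, [3.0, 4.0]))  # Point(x=3.0, y=4.0)
```

Because the field order is fixed by the class definition, the flatten/unflatten pair always agrees on which value belongs to which field.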
Custom nodes: advanced usage
Since the `pytree_node` decorator converts your class into a standard (frozen) Python dataclass, many typical dataclass considerations carry over directly to custom nodes. For instance, you should typically avoid implementing `__init__` yourself (the dataclass generates it automatically), but you can implement `__post_init__` for custom initialization. In addition, you can apply the usual `field` function to any field defined for the node.
The `struct` module provides its own wrapper of `field`, which extends it with the ability to label a field as `static`. Among other things, this means that a static field is excluded when translating to and from a flat vector.
These custom nodes are otherwise normal Python classes, so you can define methods on them as usual.
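One way to picture a static field is as a field carrying a marker in its metadata: flattening skips it, and unflattening copies it over from the original instance. The sketch below models this with the standard `dataclasses.field(metadata=...)` argument and a `__post_init__` validation hook. The `"static"` metadata key, the `SimpleRocket` class, and the helpers are hypothetical stand-ins for what `struct.field(static=True)` achieves, not its actual implementation.

```python
import dataclasses
from typing import Any, Callable, List, Tuple


@dataclasses.dataclass(frozen=True)
class SimpleRocket:
    h: float  # height, m
    v: float  # velocity, m/s
    m: float  # mass, kg
    # Hypothetical convention: metadata={"static": True} marks a field as static
    thrust: float = dataclasses.field(default=10000.0, metadata={"static": True})

    def __post_init__(self):
        # Custom validation, as in any regular dataclass
        if self.m <= 0:
            raise ValueError("Mass must be positive")


def is_static(f: dataclasses.Field) -> bool:
    return f.metadata.get("static", False)


def ravel_node(node: Any) -> Tuple[List[float], Callable[[List[float]], Any]]:
    """Flatten only the dynamic fields; static fields are restored from `node`."""
    dynamic = [f.name for f in dataclasses.fields(node) if not is_static(f)]
    flat = [getattr(node, name) for name in dynamic]

    def unravel(vec: List[float]) -> Any:
        # dataclasses.replace keeps the static fields and reruns __post_init__
        return dataclasses.replace(node, **dict(zip(dynamic, vec)))

    return flat, unravel


rocket = SimpleRocket(h=0.0, v=0.0, m=1000.0, thrust=15000.0)
flat, unravel = ravel_node(rocket)
print(flat)                          # [0.0, 0.0, 1000.0]
print(unravel([10.0, 0.0, 1000.0]))  # SimpleRocket(h=10.0, v=0.0, m=1000.0, thrust=15000.0)
```

Note that in this sketch `__post_init__` also runs on every unflatten, so an invalid flat vector (for example a negative mass) is rejected at reconstruction time.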
Here is an expanded example with some advanced features:
```python
import numpy as np
import archimedes as arc
from archimedes import struct


@struct.pytree_node
class Rocket:
    # Dynamic variables (included in flattening)
    h: float  # height in meters
    v: float  # velocity in m/s
    m: float  # Current mass in kg

    # Static parameters (excluded from flattening)
    thrust: float = struct.field(static=True, default=10000.0)  # Thrust in Newtons
    isp: float = struct.field(static=True, default=300.0)  # Specific impulse in seconds

    def __post_init__(self):
        # Validate inputs
        if self.m <= 0:
            raise ValueError("Mass must be positive")


# Create a rocket state
rocket = Rocket(
    h=0.0,
    v=0.0,
    m=1000.0,
    thrust=15000.0,  # Override the default
)
print(rocket)  # Rocket(h=0.0, v=0.0, m=1000.0, thrust=15000.0, isp=300.0)

# Flatten to vector - note that static fields are excluded
flat_state, unravel = arc.tree.ravel(rocket)
print(f"Flat state shape: {flat_state.shape}")  # (3,) for height + velocity + mass

# Modify the flat state
flat_state[0] += 10  # Increase height by 10 meters

# Unravel back to object - static fields are restored
new_rocket = unravel(flat_state)
print(new_rocket)  # Rocket(h=10.0, v=0.0, m=1000.0, thrust=15000.0, isp=300.0)
```
You can also nest custom PyTree nodes within each other and define special methods like `__call__`, giving you the ability to create modular and reusable model components.
For example, to simulate a rendezvous between our rocket and the ISS, we could create another PyTree node, `Satellite`, and define the combined state of the system as a `NamedTuple`, making the entire composite state a valid PyTree:
```python
from typing import NamedTuple

import numpy as np
import archimedes as arc


class Satellite(NamedTuple):
    pos: np.ndarray
    vel: np.ndarray


class RendezvousState(NamedTuple):
    rocket: Rocket  # the Rocket node defined in the previous example
    satellite: Satellite


satellite = Satellite(pos=np.zeros(3), vel=np.ones(3))
state = RendezvousState(rocket, satellite)
flat_state, unravel = arc.tree.ravel(state)
print(flat_state.shape)  # (9,): three from rocket and six from satellite
```
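The text above mentions defining special methods such as `__call__` on a node to build reusable components. A minimal sketch of that pattern follows, modeled with plain frozen dataclasses rather than `struct.pytree_node` so that it runs without Archimedes. The `Drag` and `Vehicle` classes and their parameter values are hypothetical: each component stores its parameters as fields and is called like a function, and composing components is just nesting one instance inside another.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Drag:
    """Hypothetical quadratic drag component: F = -0.5 * rho * cd_area * v * |v|."""
    rho: float = 1.225    # air density, kg/m^3
    cd_area: float = 0.5  # drag coefficient times reference area, m^2

    def __call__(self, v: float) -> float:
        return -0.5 * self.rho * self.cd_area * v * abs(v)


@dataclasses.dataclass(frozen=True)
class Vehicle:
    """A composite component that nests a Drag instance."""
    mass: float
    drag: Drag = dataclasses.field(default_factory=Drag)

    def __call__(self, v: float, thrust: float) -> float:
        # Acceleration from thrust plus the nested drag component
        return (thrust + self.drag(v)) / self.mass


vehicle = Vehicle(mass=2.0, drag=Drag(rho=1.0, cd_area=0.4))
print(vehicle(v=10.0, thrust=30.0))  # 5.0
```

With `pytree_node` in place of the plain dataclass decorator, the nested `drag` component would itself be a PyTree node, so its numeric fields could take part in flattening unless marked static.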
Example: Pendulum Simulation Using PyTrees
Here’s how PyTrees simplify ODE solving with structured data:
```python
import numpy as np
import archimedes as arc


# Define a custom PyTree node with dynamics method
@arc.struct.pytree_node
class PendulumState:
    theta: float  # angle
    omega: float  # angular velocity

    @classmethod
    def dynamics(cls, t, state):
        g, L = 9.81, 1.0
        # Calculate derivatives
        theta_t = state.omega
        omega_t = -(g / L) * np.sin(state.theta)
        # Return in the same structure
        return cls(theta=theta_t, omega=omega_t)


# Initial state and simulation parameters
initial_state = PendulumState(theta=np.pi / 4, omega=0.0)
t_span = (0.0, 10.0)
t_eval = np.linspace(*t_span, 100)

# Convert to flat vector for solver
x0, unravel = arc.tree.ravel(initial_state)


# Create flat dynamics wrapper
@arc.compile
def flat_dynamics(t, x):
    state = unravel(x)  # Unflatten to a PendulumState
    derivatives = state.dynamics(t, state)
    dx, _ = arc.tree.ravel(derivatives)
    return dx  # Return the flattened derivatives


# Solve the ODE
solution = arc.odeint(flat_dynamics, t_span=t_span, x0=x0, t_eval=t_eval)

# Convert results back to structured form and unpack
states = [unravel(x) for x in solution.T]
theta = np.array([state.theta for state in states])
omega = np.array([state.omega for state in states])
```
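Any integrator that accepts a flat right-hand-side function can consume the `flat_dynamics` pattern above. As a dependency-free illustration, the sketch below swaps `arc.odeint` for a hypothetical fixed-step RK4 routine (`rk4`) and uses a `NamedTuple` in place of the `pytree_node` class, so `PendulumState(*x)` plays the role of `unravel` and `list(...)` plays the role of flattening. It is a stand-in for the pattern, not a model of what `arc.odeint` does internally.

```python
import math
from typing import Callable, List, NamedTuple


class PendulumState(NamedTuple):
    theta: float  # angle, rad
    omega: float  # angular velocity, rad/s


def dynamics(t: float, state: PendulumState, g: float = 9.81, L: float = 1.0) -> PendulumState:
    return PendulumState(theta=state.omega, omega=-(g / L) * math.sin(state.theta))


def flat_dynamics(t: float, x: List[float]) -> List[float]:
    state = PendulumState(*x)        # unflatten: NamedTuple fields in declaration order
    return list(dynamics(t, state))  # flatten the derivatives


def rk4(f: Callable[[float, List[float]], List[float]],
        x0: List[float], t0: float, t1: float, n: int) -> List[List[float]]:
    """Classic fixed-step Runge-Kutta 4; returns the state at each of the n + 1 time points."""
    h = (t1 - t0) / n
    xs, x, t = [list(x0)], list(x0), t0
    for _ in range(n):
        k1 = f(t, x)
        k2 = f(t + h / 2, [xi + h / 2 * ki for xi, ki in zip(x, k1)])
        k3 = f(t + h / 2, [xi + h / 2 * ki for xi, ki in zip(x, k2)])
        k4 = f(t + h, [xi + h * ki for xi, ki in zip(x, k3)])
        x = [xi + h / 6 * (a + 2 * b + 2 * c + d) for xi, a, b, c, d in zip(x, k1, k2, k3, k4)]
        t += h
        xs.append(x)
    return xs


xs = rk4(flat_dynamics, [math.pi / 4, 0.0], 0.0, 10.0, 1000)
states = [PendulumState(*x) for x in xs]  # back to structured form
theta = [s.theta for s in states]
omega = [s.omega for s in states]
```

The solver never sees field names; only the wrapper knows how the flat vector maps onto `theta` and `omega`.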
Current Limitations
The following limitations will be resolved with further development:
Postprocessing: The `unravel` functions assume that their arguments have the same shape as the original PyTree, leading to somewhat complicated unpacking operations, particularly for ODE solutions like the previous example.
No automatic conversion: Functions like `odeint` and `minimize` don't automatically convert all their arguments to PyTrees, meaning you have to manually construct wrapper functions like `flat_dynamics` above.
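Until the postprocessing limitation is addressed, a small helper can turn a solution array with one column per time point into a single structure whose fields are whole trajectories. The sketch below uses plain lists and a hypothetical `unravel_trajectory` helper that works for `NamedTuple` states only; with Archimedes you would pass the `unravel` function returned by `arc.tree.ravel` as the `unravel` argument, but that combination is not tested here.

```python
from typing import Callable, List, NamedTuple, Sequence


class PendulumState(NamedTuple):
    theta: float
    omega: float


def unravel_trajectory(solution: Sequence[Sequence[float]],
                       unravel: Callable[[List[float]], NamedTuple]) -> NamedTuple:
    """Hypothetical helper: rows are state components, columns are time points.

    Returns one NamedTuple whose fields hold full trajectories (lists),
    instead of a list of per-time-point states.
    """
    states = [unravel(list(col)) for col in zip(*solution)]  # one state per time point
    cls = type(states[0])
    # Transpose: gather each field across all time points
    return cls(*(list(values) for values in zip(*states)))


solution = [[0.1, 0.2, 0.3],  # theta at three time points
            [1.0, 0.5, 0.0]]  # omega at three time points
traj = unravel_trajectory(solution, lambda x: PendulumState(*x))
print(traj.theta)  # [0.1, 0.2, 0.3]
print(traj.omega)  # [1.0, 0.5, 0.0]
```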
If you encounter what looks like a bug with PyTrees or another limitation, please file an issue!
Best Practices and Tips
Structure Consistency: When unraveling, the flat array length must match the original PyTree structure
Immutability: Treat PyTrees as immutable, creating new instances rather than modifying in place
Method Support: Custom PyTree classes can include methods for operations on your data
Performance: PyTree flattening/unflattening does reshaping at the tracing stage, meaning that it has minimal runtime overhead compared to typical numerical operations
Type Support: PyTrees work with both NumPy arrays and Archimedes symbolic arrays
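Since `pytree_node` classes are frozen dataclasses, assigning to a field raises an error, and updates are made by constructing a new instance. The sketch below shows this with a plain frozen dataclass and the standard `dataclasses.replace`, which should apply to any frozen dataclass; whether Archimedes nodes also provide their own convenience method for this is not assumed here.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class PendulumState:
    theta: float
    omega: float


state = PendulumState(theta=0.5, omega=0.0)

try:
    state.theta = 1.0  # in-place mutation is rejected on frozen dataclasses
except dataclasses.FrozenInstanceError as err:
    print(type(err).__name__)  # FrozenInstanceError

new_state = dataclasses.replace(state, theta=1.0)  # build a new instance instead
print(state.theta, new_state.theta)  # 0.5 1.0
```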
For more advanced PyTree operations, explore the `archimedes.tree` module.