Hierarchical Design Patterns
This page covers best practices and design patterns for creating composable dynamical systems using Archimedes.
By using the pytree_node decorator, you can create modular components that combine into complex hierarchical models while keeping the code clean and organized.
However, the recommendations in this guide are only suggestions; you can design your models and workflows however you wish.
Core Concepts
Pytree Nodes for Structured States
Dynamical systems often have state variables that benefit from logical grouping. Using tree-structured representations allows you to:
Group related state variables together
Create nested hierarchies that mirror the physical structure of your system
Maintain clean interfaces between subsystems
Flatten and unflatten states automatically for ODE solvers
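The flatten/unflatten idea can be illustrated without Archimedes. The following is a minimal sketch that uses plain frozen dataclasses as a stand-in for pytree_node classes. The flatten and unflatten helpers are hypothetical illustrations and are not Archimedes functions. Leaves are collected depth-first in field order, and rebuilding consumes them in the same order.

```python
from dataclasses import dataclass, fields, is_dataclass

# Frozen dataclasses stand in for pytree_node classes (illustrative only)
@dataclass(frozen=True)
class OscState:
    x: float
    v: float

@dataclass(frozen=True)
class PairState:
    osc1: OscState
    osc2: OscState

def flatten(tree):
    """Hypothetical helper: depth-first walk over dataclass fields, collecting leaves."""
    if is_dataclass(tree):
        leaves = []
        for f in fields(tree):
            leaves.extend(flatten(getattr(tree, f.name)))
        return leaves
    return [float(tree)]

def unflatten(template, leaves):
    """Hypothetical helper: rebuild a tree shaped like `template` from a flat list."""
    it = iter(leaves)

    def build(node):
        if is_dataclass(node):
            return type(node)(**{f.name: build(getattr(node, f.name)) for f in fields(node)})
        return next(it)  # leaves are consumed in the same depth-first order

    return build(template)

state = PairState(osc1=OscState(x=1.0, v=0.0), osc2=OscState(x=-0.5, v=0.0))
flat = flatten(state)
print(flat)                             # [1.0, 0.0, -0.5, 0.0]
print(unflatten(state, flat) == state)  # True
```

A solver only ever sees the flat list, while model code keeps working with named, nested fields.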
Design Patterns
Some recommended patterns for building modular dynamics components in Archimedes are:
Modular Components: Create a pytree_node for each logical system component
Hierarchical Parameters: Add model parameters as fields in the pytree nodes
Nested State Classes: Define a State class inside each model component
Dynamics Methods: Implement dynamics(self, t, state) methods that return state derivatives
Compositional Models: Build larger models by combining smaller components
Basic Component Pattern
Here’s a basic example of using these patterns to create a modular dynamical system component:
import os
import matplotlib.pyplot as plt
import numpy as np
import archimedes as arc
from archimedes import struct
THEME = os.environ.get("ARCHIMEDES_THEME", "dark")
arc.theme.set_theme(THEME)
@struct.pytree_node
class Oscillator:
    """A basic mass-spring-damper component."""

    # Define model parameters as fields in the PyTree node
    m: float  # Mass
    k: float  # Spring constant
    b: float  # Damping constant

    # Define a nested State class as another PyTree node
    @struct.pytree_node
    class State:
        """State variables for the mass-spring-damper system."""

        x: np.ndarray
        v: np.ndarray

    def dynamics(self, t, state: State, f_ext=0.0):
        """Compute the time derivatives of the state variables."""
        # Compute derivatives
        f_net = f_ext - self.k * state.x - self.b * state.v
        # Return state derivatives in the same structure
        return self.State(
            x=state.v,
            v=f_net / self.m,
        )
system = Oscillator(m=1.0, k=1.0, b=0.1)
x0 = system.State(x=1.0, v=0.0)
x0
Oscillator.State(x=1.0, v=0.0)
For such a simple system, the advantages of this design are limited. Because these nodes can be nested within each other, however, the same pattern becomes a useful way to organize the states, parameters, and functions of more complex models.
Working with PyTree Models
Many functions, such as ODE solvers, expect to work with flat vectors. The PyTree utilities in Archimedes make conversion to and from flat vectors easy. For example, we can “ravel” a PyTree-structured state into a vector and “unravel” it back to the original structure:
x0_flat, unravel = arc.tree.ravel(x0)
print(x0_flat)
print(unravel(x0_flat))
[1. 0.]
Oscillator.State(x=array(1.), v=array(0.))
The unravel function returned by tree.ravel is specific to the structure of the original PyTree argument, so it can be reused to reconstruct structured states inside ODE functions. For example:
@arc.compile
def ode_rhs(t, state_flat, system):
    # Unflatten the state vector to our structured state
    state = unravel(state_flat)
    # Compute state derivatives using model dynamics
    state_deriv = system.dynamics(t, state)
    # Flatten derivatives back to a vector
    state_deriv_flat, _ = arc.tree.ravel(state_deriv)
    return state_deriv_flat
# Solve the ODE
t_span = (0.0, 10.0)
t_eval = np.linspace(*t_span, 100)
solution_flat = arc.odeint(
    ode_rhs,
    t_span=t_span,
    x0=x0_flat,
    t_eval=t_eval,
    args=(system,),
)
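To show why the flat-vector interface matters to an integrator, here is a minimal pure-Python sketch of the same unravel, dynamics, ravel round trip. It makes several assumptions. Plain dataclasses replace pytree nodes. The ravel and unravel helpers are hypothetical and hard-coded for a single State structure. A hand-written fixed-step RK4 integrator stands in for arc.odeint and is not a description of how arc.odeint works internally.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class State:
    x: float  # Position
    v: float  # Velocity

@dataclass(frozen=True)
class Oscillator:
    m: float
    k: float
    b: float

    def dynamics(self, t, state, f_ext=0.0):
        f_net = f_ext - self.k * state.x - self.b * state.v
        return State(x=state.v, v=f_net / self.m)

# Hypothetical ravel/unravel, hard-coded for this one State structure
def ravel(state):
    return [state.x, state.v]

def unravel(flat):
    return State(x=flat[0], v=flat[1])

def ode_rhs(t, y, system):
    """Flat-vector right-hand side: unravel, evaluate dynamics, ravel."""
    return ravel(system.dynamics(t, unravel(y)))

def rk4(rhs, y0, t0, t1, n_steps, args=()):
    """Classic fixed-step 4th-order Runge-Kutta; the integrator only sees flat lists."""
    h = (t1 - t0) / n_steps
    t, y = t0, list(y0)

    def shift(base, k, c):
        # Elementwise base + c * k
        return [bi + c * ki for bi, ki in zip(base, k)]

    for _ in range(n_steps):
        k1 = rhs(t, y, *args)
        k2 = rhs(t + h / 2, shift(y, k1, h / 2), *args)
        k3 = rhs(t + h / 2, shift(y, k2, h / 2), *args)
        k4 = rhs(t + h, shift(y, k3, h), *args)
        y = [yi + h / 6 * (a + 2 * b + 2 * c + d)
             for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
        t += h
    return y

# Undamped unit oscillator: x(t) = cos(t), so x(pi) = -1 and v(pi) = 0
system = Oscillator(m=1.0, k=1.0, b=0.0)
y_end = rk4(ode_rhs, ravel(State(x=1.0, v=0.0)), 0.0, math.pi, 1000, args=(system,))
print(unravel(y_end))  # approximately State(x=-1.0, v=0.0)
```

The integrator never needs to know the state's field names. Only ode_rhs crosses between the flat and structured representations.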
Since the model itself is also a PyTree, we can apply ravel directly to it as well, obtaining a flat vector of the parameters defined as fields:
p_flat, unravel_system = arc.tree.ravel(system)
print(p_flat) # [1. 1. 0.1]
[1. 1. 0.1]
This is useful in optimization and parameter estimation, where algorithms typically operate on a flat vector of decision variables.
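As a sketch of how a flat parameter vector plugs into estimation, the code below writes an objective function that accepts only a flat vector, rebuilds the model, and compares predicted accelerations with data. It makes the following assumptions. The helpers ravel_params and unravel_params are hypothetical stand-ins for arc.tree.ravel and the unravel function it returns. The "measurements" are synthetic values generated from known parameters. The central-difference gradient is a generic technique, not an Archimedes feature.

```python
from dataclasses import dataclass, astuple

@dataclass(frozen=True)
class Oscillator:
    m: float
    k: float
    b: float

    def accel(self, x, v, f_ext):
        return (f_ext - self.k * x - self.b * v) / self.m

# Hypothetical stand-ins for arc.tree.ravel and its unravel function
def ravel_params(model):
    return list(astuple(model))

def unravel_params(p_flat):
    return Oscillator(*p_flat)

# Synthetic "measurements" generated from known parameters (illustrative data only)
true_model = Oscillator(m=1.0, k=1.0, b=0.1)
samples = [(x, v, f) for x in (-1.0, 0.5, 1.0) for v in (-0.5, 0.0, 0.5) for f in (0.0, 1.0)]
data = [(x, v, f, true_model.accel(x, v, f)) for x, v, f in samples]

def loss(p_flat):
    """Mean squared acceleration error; an optimizer only ever sees p_flat."""
    model = unravel_params(p_flat)
    return sum((model.accel(x, v, f) - a) ** 2 for x, v, f, a in data) / len(data)

def fd_grad(fun, p, eps=1e-6):
    """Central-difference gradient of a scalar function of a flat vector."""
    grad = []
    for i in range(len(p)):
        p_plus = list(p)
        p_plus[i] += eps
        p_minus = list(p)
        p_minus[i] -= eps
        grad.append((fun(p_plus) - fun(p_minus)) / (2 * eps))
    return grad

p_true = ravel_params(true_model)  # [1.0, 1.0, 0.1]
p_guess = [1.2, 0.8, 0.3]
print(loss(p_true), loss(p_guess))  # zero at the true parameters, positive elsewhere
print(fd_grad(loss, p_guess))
```

Any optimizer that accepts a function of a flat vector, together with its gradient, can work with this objective unchanged.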
Complete Example: Coupled Oscillators
Larger systems can be built by composing multiple components together. Let’s build a system of coupled oscillators to demonstrate these patterns.
@struct.pytree_node
class CoupledOscillators:
    """A system of two coupled oscillators."""

    osc1: Oscillator
    osc2: Oscillator
    coupling_constant: float

    @struct.pytree_node
    class State:
        """Combined state of both oscillators."""

        osc1: Oscillator.State
        osc2: Oscillator.State

    def dynamics(self, t, state):
        """Compute dynamics of the coupled system."""
        # Extract states
        x1 = state.osc1.x
        x2 = state.osc2.x
        # Compute equal and opposite coupling force
        f_ext = self.coupling_constant * (x2 - x1)
        return self.State(
            osc1=self.osc1.dynamics(t, state.osc1, f_ext),
            osc2=self.osc2.dynamics(t, state.osc2, -f_ext),
        )
# Create a coupled oscillator system
system = CoupledOscillators(
    osc1=Oscillator(m=1.0, k=4.0, b=0.1),
    osc2=Oscillator(m=1.5, k=2.0, b=0.2),
    coupling_constant=0.5,
)
# Create initial state
x0 = system.State(
    osc1=Oscillator.State(x=1.0, v=0.0),
    osc2=Oscillator.State(x=-0.5, v=0.0),
)
# Flatten the state for ODE solver
x0_flat, state_unravel = arc.tree.ravel(x0)
# ODE function that works with flat arrays
@arc.compile
def ode_rhs(t, state_flat, system):
    state = state_unravel(state_flat)
    state_deriv = system.dynamics(t, state)
    state_deriv_flat, _ = arc.tree.ravel(state_deriv)
    return state_deriv_flat
# Solve the system
t_span = (0.0, 20.0)
t_eval = np.linspace(*t_span, 200)
sol_flat = arc.odeint(
    ode_rhs,
    t_span=t_span,
    x0=x0_flat,
    t_eval=t_eval,
    args=(system,),
)
# Postprocessing: create a "vectorized map" of the unravel
# function to map back to the original tree-structured state
sol = arc.vmap(state_unravel, in_axes=1)(sol_flat)
# Plot the results
plt.figure(figsize=(7, 2))
plt.plot(t_eval, sol.osc1.x, label="Oscillator 1")
plt.plot(t_eval, sol.osc2.x, label="Oscillator 2")
plt.xlabel("Time")
plt.ylabel("Position")
plt.title("Coupled Oscillators")
plt.legend()
plt.grid(True)
plt.show()
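In the postprocessing step above, arc.vmap with in_axes=1 maps state_unravel over the time columns of sol_flat. Each leaf of the result, such as sol.osc1.x, therefore becomes a time series. The sketch below shows the same idea in plain Python under these assumptions: dataclasses stand in for pytree nodes, unravel is a hypothetical hard-coded helper, and the numbers in sol_flat are made-up placeholders rather than solver output. The stack helper is also hypothetical. It merges per-time-point trees into one tree of lists.

```python
from dataclasses import dataclass, fields, is_dataclass

@dataclass(frozen=True)
class OscState:
    x: object
    v: object

@dataclass(frozen=True)
class PairState:
    osc1: OscState
    osc2: OscState

def unravel(flat):
    """Hypothetical unravel for PairState, assuming flat = [x1, v1, x2, v2]."""
    return PairState(OscState(flat[0], flat[1]), OscState(flat[2], flat[3]))

def stack(trees):
    """Merge a list of same-structure trees into one tree whose leaves are lists."""
    first = trees[0]
    if is_dataclass(first):
        return type(first)(**{
            f.name: stack([getattr(t, f.name) for t in trees]) for f in fields(first)
        })
    return list(trees)

# Placeholder (n_states, n_times) table: rows are states, columns are time points
sol_flat = [
    [1.0, 0.9, 0.7],     # osc1.x
    [0.0, -0.2, -0.4],   # osc1.v
    [-0.5, -0.4, -0.2],  # osc2.x
    [0.0, 0.1, 0.3],     # osc2.v
]

# Conceptually like vmap(unravel, in_axes=1): unravel each column, then stack
sol = stack([unravel(list(col)) for col in zip(*sol_flat)])
print(sol.osc1.x)  # [1.0, 0.9, 0.7]
print(sol.osc2.v)  # [0.0, 0.1, 0.3]
```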


Summary
The recommended approach to building hierarchical and modular dynamical systems in Archimedes follows these key patterns:
Use @struct.pytree_node to define structured component classes
Create nested State classes to organize state variables
Implement dynamics methods that compute state derivatives
Compose larger systems from smaller components
Add helper methods to simplify simulation and analysis
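The "helper methods" item can be sketched as follows. The sketch assumes plain dataclasses in place of pytree nodes. The energy and simulate methods are hypothetical examples, not part of any Archimedes API, and simulate uses a hand-written fixed-step RK4 loop. The point is that analysis and simulation live on the component and reuse its own dynamics method and nested State class.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Oscillator:
    """Mass-spring-damper with hypothetical analysis and simulation helpers."""
    m: float
    k: float
    b: float

    @dataclass(frozen=True)
    class State:
        x: float
        v: float

    def dynamics(self, t, state, f_ext=0.0):
        f_net = f_ext - self.k * state.x - self.b * state.v
        return self.State(x=state.v, v=f_net / self.m)

    def energy(self, state):
        """Helper: kinetic plus spring potential energy."""
        return 0.5 * self.m * state.v**2 + 0.5 * self.k * state.x**2

    def simulate(self, x0, t_end, n_steps):
        """Helper: fixed-step RK4 returning the list of visited states."""
        h = t_end / n_steps

        def step(s, d, c):
            # Field-by-field s + c * d
            return self.State(x=s.x + c * d.x, v=s.v + c * d.v)

        t, s, states = 0.0, x0, [x0]
        for _ in range(n_steps):
            k1 = self.dynamics(t, s)
            k2 = self.dynamics(t + h / 2, step(s, k1, h / 2))
            k3 = self.dynamics(t + h / 2, step(s, k2, h / 2))
            k4 = self.dynamics(t + h, step(s, k3, h))
            s = self.State(
                x=s.x + h / 6 * (k1.x + 2 * k2.x + 2 * k3.x + k4.x),
                v=s.v + h / 6 * (k1.v + 2 * k2.v + 2 * k3.v + k4.v),
            )
            t += h
            states.append(s)
        return states

osc = Oscillator(m=1.0, k=1.0, b=0.1)
traj = osc.simulate(osc.State(x=1.0, v=0.0), t_end=10.0, n_steps=500)
print(osc.energy(traj[0]), osc.energy(traj[-1]))  # damping removes energy
```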
Other best practices include:
Consistent Interfaces: Keep the dynamics(self, t, state, *args) method signature consistent across all components
Immutable States: Always return new state objects instead of modifying existing ones
Physical Units: Document physical units in comments or docstrings
Input Validation: Add validation in constructors to catch errors early
Meaningful Names: Use descriptive names that reflect physical components, or consistent pseudo-mathematical notation like the monogram convention
Domain Decomposition: Decompose complex systems into logical components (mechanical, electrical, etc.)
Structured Parameters: Define physical parameters as fields in the PyTree nodes, and use the struct.field(static=True) annotation to mark configuration variables.
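Several of these practices (immutable states, documented units, and input validation) can be sketched with plain frozen dataclasses. Treat this as an illustration under an assumption. It uses the standard-library hooks dataclasses.replace and __post_init__. Whether pytree_node classes support the same hooks is not confirmed here and should be checked against the Archimedes documentation.

```python
from dataclasses import dataclass, replace, FrozenInstanceError

@dataclass(frozen=True)
class Oscillator:
    m: float  # Mass [kg]
    k: float  # Spring constant [N/m]
    b: float  # Damping constant [N*s/m]

    def __post_init__(self):
        # Input validation: fail at construction rather than mid-simulation
        if self.m <= 0:
            raise ValueError(f"mass must be positive, got {self.m}")
        if self.k < 0 or self.b < 0:
            raise ValueError("spring and damping constants must be non-negative")

@dataclass(frozen=True)
class State:
    x: float  # Position [m]
    v: float  # Velocity [m/s]

s0 = State(x=1.0, v=0.0)
s1 = replace(s0, v=0.5)  # returns a new object; s0 is unchanged
print(s0, s1)

try:
    s0.x = 2.0  # frozen dataclasses refuse in-place mutation
except FrozenInstanceError:
    print("states are immutable")

try:
    Oscillator(m=0.0, k=1.0, b=0.1)
except ValueError as err:
    print("rejected:", err)
```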
These patterns enable clean, organized, and reusable model components while leveraging Archimedes’ PyTree functionality to handle the conversion between structured and flat representations needed by ODE solvers.