archimedes.tree.struct.pytree_node

archimedes.tree.struct.pytree_node(cls: T | None = None, **kwargs) → T | Callable

Decorator to convert a class into a frozen dataclass registered as a pytree node.

This decorator creates a dataclass that can be seamlessly used with Archimedes’ pytree functions. The class will be registered with the pytree system, allowing its instances to be flattened, mapped over, and transformed while preserving its structure.
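Registering a class as a pytree node means an instance can be split into a flat list of leaves plus a record of its structure, and then rebuilt from a (possibly transformed) list of leaves. The stdlib-only sketch below illustrates that round trip for plain frozen dataclasses. `flatten_sketch`, `unflatten_sketch`, `Point` and `Segment` are hypothetical names, not Archimedes functions. The sketch also ignores static fields and non-dataclass containers, which the real pytree system presumably handles.

```python
from __future__ import annotations

import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Point:
    x: float
    y: float


@dataclasses.dataclass(frozen=True)
class Segment:
    start: Point
    end: Point


def flatten_sketch(obj: Any) -> tuple[list[Any], Any]:
    """Split obj into a flat list of leaves and a description of its structure."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        leaves: list[Any] = []
        children = []
        for f in dataclasses.fields(obj):
            sub_leaves, sub_def = flatten_sketch(getattr(obj, f.name))
            leaves.extend(sub_leaves)
            children.append((f.name, sub_def))
        return leaves, (type(obj), children)
    # Anything that is not a dataclass instance is treated as a single leaf.
    return [obj], None


def unflatten_sketch(treedef: Any, leaves: list[Any]) -> Any:
    """Rebuild an object with the recorded structure from a flat list of leaves."""
    it = iter(leaves)

    def build(node_def: Any) -> Any:
        if node_def is None:
            return next(it)  # consume leaves in field order
        cls, children = node_def
        return cls(**{name: build(sub_def) for name, sub_def in children})

    return build(treedef)


seg = Segment(Point(0.0, 1.0), Point(2.0, 3.0))
leaves, treedef = flatten_sketch(seg)
print(leaves)  # [0.0, 1.0, 2.0, 3.0]
scaled = unflatten_sketch(treedef, [10 * v for v in leaves])
print(scaled)  # Segment(start=Point(x=0.0, y=10.0), end=Point(x=20.0, y=30.0))
```

Because the structure is recorded separately from the leaves, any function that produces a new list of leaves of the same length can be turned back into a `Segment`.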

Parameters:
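  • cls (type, optional) – The class to convert into a pytree node. May be omitted when the decorator is called with keyword arguments only, in which case a decorator is returned.

  • **kwargs (dict) – Additional keyword arguments passed to dataclasses.dataclass(). frozen=True is set by default unless explicitly overridden.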
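The signature (cls: T | None = None, **kwargs) → T | Callable suggests the common optional-argument decorator pattern, usable both as a bare decorator and with keyword arguments. The sketch below shows that pattern with the frozen=True default; `pytree_node_sketch` is a hypothetical stand-in, not the Archimedes implementation, and it omits the pytree registration step entirely.

```python
import dataclasses
from typing import Any, Callable, Optional, TypeVar, Union

T = TypeVar("T")


def pytree_node_sketch(
    cls: Optional[T] = None, **kwargs: Any
) -> Union[T, Callable[[T], T]]:
    """Hypothetical stand-in showing the optional-argument decorator pattern."""
    # frozen=True is the default, but an explicit frozen=... from the caller wins.
    kwargs.setdefault("frozen", True)

    def wrap(c: T) -> T:
        # A real implementation would also register the class as a pytree node here.
        return dataclasses.dataclass(c, **kwargs)

    if cls is None:
        # Used as @pytree_node_sketch(...): return the decorator itself.
        return wrap
    # Used as bare @pytree_node_sketch: decorate the class immediately.
    return wrap(cls)


@pytree_node_sketch
class Frozen:
    a: int


@pytree_node_sketch(frozen=False)
class Mutable:
    a: int


m = Mutable(1)
m.a = 2  # allowed, because frozen was explicitly overridden
```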

Returns:

decorated_class – The decorated class, now a frozen dataclass registered as a pytree node.

Return type:

type

Notes

When to use:

  • To create structured data objects for use in Archimedes models and simulations

  • To define state containers that work with pytree-based transformations

  • To create modular, composable model components with clear interfaces

  • To define parameter structures for optimization problems

Setting frozen=True makes instances immutable: once an instance is created, its fields cannot be reassigned, and attempting to do so raises dataclasses.FrozenInstanceError. This guarantees that an object's state cannot change unexpectedly while it is being flattened, mapped, or otherwise transformed. To change a field, use the replace() method, which returns a modified copy with new values for the specified fields and leaves the original untouched.
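Since the decorated class is built with dataclasses.dataclass(frozen=True), the standard-library immutability rules apply. The sketch below uses a plain frozen dataclass (the `State` class is hypothetical) together with dataclasses.replace. It assumes that the replace() method on decorated classes behaves like dataclasses.replace, which this page does not state explicitly.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class State:
    position: float
    velocity: float


s1 = State(position=0.0, velocity=10.0)

# Direct assignment is rejected on a frozen dataclass.
try:
    s1.position = 5.0
except dataclasses.FrozenInstanceError as err:
    print("cannot assign:", err)

# replace() builds a new instance; the original is untouched.
s2 = dataclasses.replace(s1, position=5.0)
print(s1)  # State(position=0.0, velocity=10.0)
print(s2)  # State(position=5.0, velocity=10.0)
```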

Fields are classified as either “data” (dynamic values, such as state variables, that pytree transformations act on) or “static” (configuration parameters that are carried through transformations unchanged). All fields are treated as data unless declared with field(static=True).
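One plausible way to record the data/static split is through dataclass field metadata. The sketch below tags fields with a hypothetical metadata key "static" and partitions an instance's values accordingly. `static_field` and `partition_fields` are hypothetical helpers, and the real field(static=True) may store this information differently.

```python
import dataclasses
from typing import Any, Dict, Tuple


def static_field(**kwargs: Any) -> Any:
    """Hypothetical helper: mark a field as static through dataclass metadata."""
    return dataclasses.field(metadata={"static": True}, **kwargs)


@dataclasses.dataclass(frozen=True)
class Vehicle:
    position: float
    velocity: float
    mass: float = static_field(default=1000.0)


def partition_fields(obj: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split an instance's field values into (data, static) dictionaries."""
    data: Dict[str, Any] = {}
    static: Dict[str, Any] = {}
    for f in dataclasses.fields(obj):
        # Fields without the "static" tag default to data.
        target = static if f.metadata.get("static", False) else data
        target[f.name] = getattr(obj, f.name)
    return data, static


data, static = partition_fields(Vehicle(position=0.0, velocity=10.0))
print(data)    # {'position': 0.0, 'velocity': 10.0}
print(static)  # {'mass': 1000.0}
```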

The decorated class:

  • Is frozen (immutable) by default

  • Has a replace() method for creating modified copies

  • Will be properly handled by tree.flatten(), tree.map(), etc.

  • Can be nested within other pytree nodes
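To illustrate how a map over nested nodes can touch only data fields, here is a stdlib sketch of a recursive map over frozen dataclasses. `tree_map_sketch` is a hypothetical helper, not arc.tree.map. It reuses the metadata-tag assumption ("static" key) and rebuilds each node with dataclasses.replace so that static fields are copied over unchanged.

```python
import dataclasses
from typing import Any, Callable


def tree_map_sketch(fn: Callable[[Any], Any], obj: Any) -> Any:
    """Apply fn to every non-static leaf, recursing into nested dataclasses."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        changes = {}
        for f in dataclasses.fields(obj):
            if f.metadata.get("static", False):
                continue  # static fields are copied over unchanged by replace()
            changes[f.name] = tree_map_sketch(fn, getattr(obj, f.name))
        return dataclasses.replace(obj, **changes)
    return fn(obj)  # anything else is a leaf


@dataclasses.dataclass(frozen=True)
class Vehicle:
    velocity: float
    mass: float = dataclasses.field(default=1000.0, metadata={"static": True})


@dataclasses.dataclass(frozen=True)
class System:
    vehicle1: Vehicle
    vehicle2: Vehicle


system = System(Vehicle(velocity=10.0), Vehicle(velocity=4.0))
halved = tree_map_sketch(lambda x: x * 0.5, system)
print(halved.vehicle1.velocity, halved.vehicle2.velocity)  # 5.0 2.0
print(halved.vehicle1.mass)  # 1000.0
```

Because every node is rebuilt rather than mutated, the original `system` is left as it was, which fits the frozen design described above.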

Examples

>>> import archimedes as arc
>>> import numpy as np
>>>
>>> @arc.struct.pytree_node
... class Vehicle:
...     # Dynamic state variables (included in transformations)
...     position: np.ndarray
...     velocity: np.ndarray
...
...     # Static configuration parameters (preserved during transformations)
...     mass: float = arc.struct.field(static=True, default=1000.0)
...     drag_coef: float = arc.struct.field(static=True, default=0.3)
...
...     def kinetic_energy(self):
...         return 0.5 * self.mass * np.sum(self.velocity**2)
...
>>> # Create an instance
>>> car = Vehicle(
...     position=np.array([0.0, 0.0]),
...     velocity=np.array([10.0, 0.0]),
... )
>>>
>>> # Create a modified copy
>>> car2 = car.replace(position=np.array([5.0, 0.0]))
>>>
>>> # Apply a transformation (only to dynamic fields)
>>> scaled = arc.tree.map(lambda x: x * 2, car)
>>> print(scaled.position)    # [0. 0.] -> [0. 0.]
>>> print(scaled.velocity)    # [10. 0.] -> [20. 0.]
>>> print(scaled.mass)        # 1000.0 (unchanged)
>>>
>>> # Nested pytree nodes
>>> @arc.struct.pytree_node
... class System:
...     vehicle1: Vehicle
...     vehicle2: Vehicle
...
...     def total_energy(self):
...         return self.vehicle1.kinetic_energy() + self.vehicle2.kinetic_energy()
...
>>> system = System(car, car2)
>>> # This transformation applies to all dynamic fields in the entire hierarchy
>>> scaled_system = arc.tree.map(lambda x: x * 0.5, system)

See also

field

Define fields with pytree-specific metadata