archimedes.tree.struct.pytree_node
- archimedes.tree.struct.pytree_node(cls: T | None = None, **kwargs) → T | Callable
Decorator to convert a class into a frozen dataclass registered as a pytree node.
This decorator creates a dataclass that can be seamlessly used with Archimedes’ pytree functions. The class will be registered with the pytree system, allowing its instances to be flattened, mapped over, and transformed while preserving its structure.
- Parameters:
cls (type, optional) – The class to convert into a pytree
**kwargs (dict) – Additional keyword arguments passed to dataclasses.dataclass(). By default, frozen=True is set unless explicitly overridden.
- Returns:
decorated_class – The decorated class, now a frozen dataclass registered as a pytree node.
- Return type:
type
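The keyword-argument handling described above can be illustrated with a standard-library sketch. `sketch_pytree_node` is a hypothetical stand-in, not the Archimedes implementation, and it omits pytree registration entirely. It shows one way `frozen=True` could be defaulted while still letting callers override it, and how an optional `cls` argument (consistent with the `cls: T | None = None` signature) lets the same function be used both as `@decorator` and as `@decorator(...)`.

```python
import dataclasses
from typing import Any, Optional


def sketch_pytree_node(cls: Optional[type] = None, **kwargs: Any):
    """Hypothetical stand-in for the decorator's argument handling.

    Forwards keyword arguments to dataclasses.dataclass with frozen=True as
    the default. Pytree registration is deliberately omitted.
    """
    # Default to frozen=True unless the caller passes frozen explicitly.
    kwargs.setdefault("frozen", True)

    def wrap(c: type) -> type:
        return dataclasses.dataclass(c, **kwargs)

    # Bare usage (@sketch_pytree_node): cls is the class being decorated.
    if cls is not None:
        return wrap(cls)
    # Called with arguments (@sketch_pytree_node(...)): return a decorator.
    return wrap


@sketch_pytree_node
class Point:
    x: float
    y: float


@sketch_pytree_node(frozen=False)
class MutablePoint:
    x: float
    y: float


p = Point(1.0, 2.0)
try:
    p.x = 3.0
    frozen_rejected = False
except dataclasses.FrozenInstanceError:
    frozen_rejected = True

m = MutablePoint(1.0, 2.0)
m.x = 3.0  # allowed, because frozen=False overrides the default
```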
Notes
When to use:
- To create structured data objects for use in Archimedes models and simulations
- To define state containers that work with pytree-based transformations
- To create modular, composable model components with clear interfaces
- To define parameter structures for optimization problems
The “frozen” attribute makes the class immutable: once an instance is created, its fields cannot be reassigned. This helps keep the object's state consistent during operations. The replace() method creates a modified copy of the object with new values for specific fields.
Fields are automatically classified as either “data” (dynamic values that change during operations) or “static” (configuration parameters). By default, all fields are treated as data unless marked with field(static=True).
The decorated class:
- Is frozen (immutable) by default
- Has a replace() method for creating modified copies
- Works with tree.flatten(), tree.map(), and the other pytree functions
- Can be nested within other pytree nodes
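The frozen/replace() behavior described above can be sketched with a plain standard-library frozen dataclass, assuming the decorated class behaves like one in this respect. The `State` class and its `replace` method are hypothetical illustrations and are not registered as pytree nodes. Note that for standard frozen dataclasses, freezing is shallow: reassigning a field raises an error, but a mutable field value such as a NumPy array can still be modified in place.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class State:
    position: float
    velocity: float

    def replace(self, **changes):
        # Hypothetical helper mirroring the documented replace() method:
        # delegates to dataclasses.replace, which builds a new instance.
        return dataclasses.replace(self, **changes)


s1 = State(position=0.0, velocity=10.0)

# Reassigning a field on a frozen instance raises FrozenInstanceError.
try:
    s1.position = 1.0
    reassigned = True
except dataclasses.FrozenInstanceError:
    reassigned = False

# replace() returns a modified copy and leaves the original untouched.
s2 = s1.replace(position=5.0)
print(reassigned, s1.position, s2.position)
```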
Examples
>>> import archimedes as arc
>>> import numpy as np
>>>
>>> @arc.struct.pytree_node
... class Vehicle:
...     # Dynamic state variables (included in transformations)
...     position: np.ndarray
...     velocity: np.ndarray
...
...     # Static configuration parameters (preserved during transformations)
...     mass: float = arc.struct.field(static=True, default=1000.0)
...     drag_coef: float = arc.struct.field(static=True, default=0.3)
...
...     def kinetic_energy(self):
...         return 0.5 * self.mass * np.sum(self.velocity**2)
>>>
>>> # Create an instance
>>> car = Vehicle(
...     position=np.array([0.0, 0.0]),
...     velocity=np.array([10.0, 0.0]),
... )
>>>
>>> # Create a modified copy
>>> car2 = car.replace(position=np.array([5.0, 0.0]))
>>>
>>> # Apply a transformation (only to dynamic fields)
>>> scaled = arc.tree.map(lambda x: x * 2, car)
>>> print(scaled.position)  # [0. 0.] -> [0. 0.]
>>> print(scaled.velocity)  # [10. 0.] -> [20. 0.]
>>> print(scaled.mass)  # 1000.0 (unchanged)
>>>
>>> # Nested pytree nodes
>>> @arc.struct.pytree_node
... class System:
...     vehicle1: Vehicle
...     vehicle2: Vehicle
...
...     def total_energy(self):
...         return self.vehicle1.kinetic_energy() + self.vehicle2.kinetic_energy()
>>>
>>> system = System(car, car2)
>>> # This transformation applies to all dynamic fields in the entire hierarchy
>>> scaled_system = arc.tree.map(lambda x: x * 0.5, system)
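The split between data and static fields, and the way a map recurses through nested nodes as in the example above, can be sketched with standard dataclasses. Everything here is hypothetical: `static_field`, `tree_map`, `tree_leaves`, and the `"static"` metadata key are illustrative stand-ins for `arc.struct.field`, `arc.tree.map`, and `tree.flatten()`, not their actual implementations. Unlike a full pytree system, this sketch only recurses into dataclass instances and treats every other value (including lists and dicts) as a leaf.

```python
import dataclasses
from typing import Any, Callable

import numpy as np

# Hypothetical metadata key marking a field as static (not a leaf).
STATIC = "static"


def static_field(**kwargs: Any):
    """Stand-in for a field(static=True) helper, using dataclass metadata."""
    return dataclasses.field(metadata={STATIC: True}, **kwargs)


def _is_node(x: Any) -> bool:
    # Dataclass instances are nodes; dataclass types themselves are not.
    return dataclasses.is_dataclass(x) and not isinstance(x, type)


def tree_map(fn: Callable[[Any], Any], node: Any) -> Any:
    """Apply fn to every data leaf, recursing into nested dataclasses."""
    if _is_node(node):
        changes = {}
        for f in dataclasses.fields(node):
            if f.metadata.get(STATIC, False):
                continue  # static fields are copied through unchanged
            changes[f.name] = tree_map(fn, getattr(node, f.name))
        return dataclasses.replace(node, **changes)
    return fn(node)  # anything else is treated as a leaf


def tree_leaves(node: Any) -> list:
    """Collect the data leaves in field order, skipping static fields."""
    if _is_node(node):
        leaves = []
        for f in dataclasses.fields(node):
            if not f.metadata.get(STATIC, False):
                leaves.extend(tree_leaves(getattr(node, f.name)))
        return leaves
    return [node]


@dataclasses.dataclass(frozen=True)
class Vehicle:
    position: np.ndarray
    velocity: np.ndarray
    mass: float = static_field(default=1000.0)


@dataclasses.dataclass(frozen=True)
class System:
    vehicle1: Vehicle
    vehicle2: Vehicle


car = Vehicle(np.array([0.0, 0.0]), np.array([10.0, 0.0]))
car2 = dataclasses.replace(car, position=np.array([5.0, 0.0]))
system = System(car, car2)

# Halves every position and velocity; mass is left as-is.
scaled = tree_map(lambda x: x * 0.5, system)
print(scaled.vehicle1.velocity, scaled.vehicle2.position, scaled.vehicle1.mass)
```

Keeping static values out of the leaf list is what lets a transformation touch only the dynamic state while configuration such as `mass` rides along unchanged in the structure.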
See also
field
Define fields with pytree-specific metadata