archimedes.scan
- archimedes.scan(func: Callable, init_carry: T, xs: ArrayLike | None = None, length: int | None = None)
Apply a function repeatedly while carrying state between iterations.
Efficiently implements a loop that accumulates state and collects outputs at each iteration. Similar to functional fold/reduce operations but also accumulates the intermediate outputs. This provides a structured way to express iterative algorithms in a functional style that can be efficiently compiled and differentiated.
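The following plain-Python sketch makes the fold/reduce comparison concrete. It models the loop semantics without Archimedes. The helper scan_loop is hypothetical and exists only for illustration. With a running-sum step, the final carry matches what functools.reduce returns, and the collected outputs match itertools.accumulate.

```python
import functools
import itertools
import operator


def scan_loop(func, init_carry, xs):
    """Plain-Python model of scan: thread the carry, collect every output."""
    carry = init_carry
    ys = []
    for x in xs:
        carry, y = func(carry, x)
        ys.append(y)
    return carry, ys


def running_sum(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # emit the updated carry as this step's output


xs = [1, 2, 3, 4, 5]
final, ys = scan_loop(running_sum, 0, xs)

# A fold/reduce keeps only the final value...
folded = functools.reduce(operator.add, xs, 0)
# ...while scan also keeps every intermediate value.
accumulated = list(itertools.accumulate(xs, operator.add))

print(final, folded)    # 15 15
print(ys, accumulated)  # [1, 3, 6, 10, 15] [1, 3, 6, 10, 15]
```

The only difference from a fold is that every per-step output y is kept rather than discarded.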
- Parameters:
func (callable) – A function with signature func(carry, x) -> (new_carry, y) to be applied at each loop iteration. The function must:
  - Accept exactly two arguments: the current carry value and the loop variable.
  - Return exactly two values: the updated carry value and an output for this step.
  - Return a carry with the same structure as the input carry.
init_carry (array_like or Tree) – The initial value of the carry state. Can be a scalar, array, or structured data type. The structure of this value defines what func must return as its first output.
xs (array_like, optional) – The values to loop over, with shape (length, ...). Each value is passed as the second argument to func. Required unless length is provided.
length (int, optional) – The number of iterations to run. Required if xs is None. If both are provided, xs.shape[0] must equal length.
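The interaction between xs and length can be sketched in plain NumPy. The function reference_scan below is a hypothetical model, not Archimedes code. It follows the Python equivalent given in the Notes, so when only length is given, the loop variable is the iteration index. Raising ValueError for missing or inconsistent arguments is an assumption made for illustration; the library's actual error handling may differ.

```python
import numpy as np


def reference_scan(func, init_carry, xs=None, length=None):
    """Pure-NumPy model of the documented xs/length rules (not Archimedes code)."""
    if xs is None:
        if length is None:
            raise ValueError("either xs or length must be provided")
        xs = range(length)  # loop variable is the iteration index
    else:
        xs = np.asarray(xs)
        if length is not None and xs.shape[0] != length:
            raise ValueError(
                f"xs.shape[0] ({xs.shape[0]}) must equal length ({length})"
            )
    carry = init_carry
    ys = []
    for x in xs:
        carry, y = func(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def count_step(carry, i):
    # The carry accumulates the loop variable; the output records it.
    return carry + i, i


# Only length given: x runs over 0, 1, ..., length - 1.
total, idx = reference_scan(count_step, 0, length=4)
print(total, idx)  # 6 [0 1 2 3]

# Both given and consistent: accepted.
reference_scan(count_step, 0, xs=np.arange(3), length=3)

# Both given but inconsistent: rejected.
try:
    reference_scan(count_step, 0, xs=np.arange(3), length=5)
except ValueError as err:
    print(err)
```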
- Returns:
final_carry (same type as init_carry) – The final carry value after all iterations.
ys (array) – The stacked outputs from each iteration, with shape (length, ...).
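The sketch below illustrates the return structure. So that it runs without Archimedes, it uses a plain NumPy loop in place of arc.scan. The carry is a tuple, so the final carry keeps that tuple structure. Each step emits a 2-vector, so the stacked outputs have shape (length, 2). The step function is a hypothetical example.

```python
import numpy as np


def step(carry, x):
    # The carry is a tuple (a simple tree): running sum and running maximum.
    total, peak = carry
    total = total + x
    peak = max(peak, x)
    # The per-step output is a length-2 vector.
    return (total, peak), np.array([total, peak])


xs = np.array([3.0, -1.0, 4.0, 1.0, 5.0])
carry = (0.0, -np.inf)
ys = []
for x in xs:
    carry, y = step(carry, x)
    ys.append(y)
ys = np.stack(ys)

print(carry)     # final carry keeps the tuple structure of the initial carry
print(ys.shape)  # (5, 2): (length, ...) with the per-step output shape appended
```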
Notes
When to use this function:
To keep computational graph size manageable for large loops
For implementing recurrent computations (filters, RNNs, etc.)
For iterative numerical methods (e.g., fixed-point iterations)
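The fixed-point use case fits the length-only form. The loop variable goes unused, the carry holds the current iterate, and the per-step output can record convergence. The sketch below uses a plain Python loop in place of arc.scan so it runs on its own. The Newton map for sqrt(2) is an illustrative choice, not something taken from the library.

```python
import math

import numpy as np


def newton_sqrt2(carry, _):
    # Fixed-point map g(c) = (c + 2/c) / 2; its fixed point is sqrt(2).
    new_carry = 0.5 * (carry + 2.0 / carry)
    # Emit the step size so convergence can be inspected afterwards.
    return new_carry, abs(new_carry - carry)


# A plain loop standing in for a length-only scan with length=6.
carry, steps = 1.0, []
for i in range(6):
    carry, dx = newton_sqrt2(carry, i)  # i is the (unused) loop variable
    steps.append(dx)
steps = np.array(steps)

print(carry, math.sqrt(2))
print(steps)  # step sizes shrink rapidly toward zero
```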
Conceptual model: Each iteration applies func to the current carry value and the current loop value:
(carry, y) = func(carry, x)
The carry is threaded through all iterations, while each y output is collected. This pattern is common in many iterative algorithms and can be more efficient than explicit Python loops because it creates a single node in the computational graph regardless of the number of iterations.
The standard Python equivalent would be:
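```python
import numpy as np


def scan_equivalent(func, init_carry, xs=None, length=None):
    # Without xs, loop over the iteration indices 0, ..., length - 1
    if xs is None:
        xs = range(length)
    carry = init_carry
    ys = []
    for x in xs:
        carry, y = func(carry, x)  # thread the carry, keep each output
        ys.append(y)
    return carry, np.stack(ys)  # ys has shape (length, ...)
```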
However, the compiled scan is more efficient for long loops because it creates a fixed-size computational graph regardless of loop length.
Examples
Basic summation:
```python
>>> import numpy as np
>>> import archimedes as arc
>>>
>>> @arc.compile
... def sum_func(carry, x):
...     new_carry = carry + x
...     return new_carry, new_carry
>>>
>>> xs = np.array([1, 2, 3, 4, 5])
>>> final_sum, intermediates = arc.scan(sum_func, 0, xs)
>>> print(final_sum)  # 15
>>> print(intermediates)  # [1, 3, 6, 10, 15]
```
Implementing a discrete-time IIR filter:
```python
>>> @arc.compile
... def iir_step(state, x):
...     # Simple first-order IIR filter: y[n] = 0.9*y[n-1] + 0.1*x[n]
...     new_state = 0.9 * state + 0.1 * x
...     return new_state, new_state
>>>
>>> # Apply to a step input
>>> input_signal = np.ones(50)
>>> initial_state = 0.0
>>> final_state, filtered = arc.scan(iir_step, initial_state, input_signal)
```
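For a step input, this filter has a closed-form response that can serve as a sanity check on the filtered output. Take y[-1] = 0 and x[n] = 1. Induction on y[n] = 0.9*y[n-1] + 0.1 gives y[n] = 1 - 0.9**(n+1). The base case is y[0] = 0.1 = 1 - 0.9, and the inductive step is 0.9*(1 - 0.9**n) + 0.1 = 1 - 0.9**(n+1). The sketch below checks this with a plain NumPy loop that mirrors iir_step; it does not call Archimedes.

```python
import numpy as np

input_signal = np.ones(50)
state = 0.0
filtered = []
for x in input_signal:
    # Same update as iir_step: y[n] = 0.9*y[n-1] + 0.1*x[n]
    state = 0.9 * state + 0.1 * x
    filtered.append(state)
filtered = np.array(filtered)

# Closed-form step response with y[-1] = 0: y[n] = 1 - 0.9**(n + 1)
n = np.arange(50)
closed_form = 1.0 - 0.9 ** (n + 1)

print(np.allclose(filtered, closed_form))  # True
print(filtered[-1])  # approaches the steady-state gain of 1
```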
Implementing Euler’s method for ODE integration:
```python
>>> @arc.compile
... def euler_step(state, t):
...     # Simple harmonic oscillator: d²x/dt² = -x
...     dt = 0.001
...     x, v = state
...     new_x = x + dt * v
...     new_v = v - dt * x
...     return (new_x, new_v), new_x
>>>
>>> ts = np.linspace(0, 1.0, 1001)
>>> initial_state = (1.0, 0.0)  # x=1, v=0
>>> final_state, trajectory = arc.scan(euler_step, initial_state, ts)
```
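Explicit Euler does not conserve the oscillator's energy-like quantity x**2 + v**2. Expanding one step gives (x + dt*v)**2 + (v - dt*x)**2 = (1 + dt**2)*(x**2 + v**2), because the cross terms cancel. After N steps the quantity has therefore grown by exactly (1 + dt**2)**N in exact arithmetic. The plain NumPy loop below mirrors euler_step without calling Archimedes and compares the result against that prediction.

```python
import numpy as np

dt = 0.001
ts = np.linspace(0, 1.0, 1001)  # only the number of samples matters here
x, v = 1.0, 0.0
for _ in ts:
    # Same explicit Euler update as euler_step (both use the old x and v)
    x, v = x + dt * v, v - dt * x

energy = x**2 + v**2
predicted = (1.0 + dt**2) ** len(ts)
print(energy, predicted)
```

Since ts has 1001 samples and dt is fixed inside euler_step, the scan takes 1001 steps. The simulated duration is therefore 1.001 rather than 1.0, and the value of t itself is never used.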
See also
jax.lax.scan – JAX equivalent function
arc.tree – Module for working with structured data in scan loops