archimedes.vmap
- archimedes.vmap(func: Callable, in_axes: int | None | tuple[int | None, ...] = 0, out_axes: int = 0, name: str | None = None) -> Callable
Vectorize a function along specified argument axes.
The vmap transformation takes a function that operates on individual elements and transforms it into one that operates on batches of elements in a vectorized manner. This enables efficient computation without writing explicit loops or broadcasting logic.
- Parameters:
func (callable) – Function to be vectorized. The function can accept ordinary NumPy arrays or PyTree-structured data.
in_axes (int, None, or tuple of ints/None, optional) –
Specifies which axis of each input argument should be mapped over.
int: Use the same axis for all arguments (e.g., 0 for the first dimension)
None: Don’t map this argument (broadcast it to all mapped elements)
tuple: Specify a different axis for each argument
Default is 0 (map over the first axis of each argument).
out_axes (int, optional) – Specifies where the mapped axis should appear in the output. Default is 0 (mapped axis is the first dimension of the output).
name (str, optional) – Name for the transformed function. If None, derives a name from the original function.
- Returns:
vectorized_func – A function with the same signature as func that operates on batches of inputs.
- Return type:
callable
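The roles of in_axes and out_axes can be pictured with plain NumPy. The following is a minimal sketch of the semantics described above, written as explicit loops rather than with vmap itself; the function `normalize` and the array `data` are hypothetical names chosen only for illustration.

```python
import numpy as np


def normalize(v):
    """Scale a single vector to unit length."""
    return v / np.linalg.norm(v)


# Samples are stored as columns: 2 components x 3 samples.
data = np.array([[3.0, 0.0, 1.0],
                 [4.0, 2.0, 1.0]])

# Mapping over input axis 1 means one call per column.
per_sample = [normalize(data[:, i]) for i in range(data.shape[1])]

# The output axis setting decides where the batch dimension lands.
out_first = np.stack(per_sample, axis=0)   # batch axis first: shape (3, 2)
out_second = np.stack(per_sample, axis=1)  # batch axis second: shape (2, 3)

print(out_first.shape, out_second.shape)
```

Under these assumptions, the first stacking corresponds to the description of in_axes=1 with out_axes=0, and the second to in_axes=1 with out_axes=1, which preserves the layout of the input.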
Notes
When to use this function:
When you need to apply the same operation to many inputs efficiently
To convert a single-example function into one that handles batches
To selectively vectorize over some arguments while broadcasting others
To “unflatten” tree-structured data by mapping the unravel function
Conceptual model:
vmap transforms functions to operate along array axes. For example, a function f(x) that takes a vector and returns a scalar can be transformed into one that takes a batch of vectors (an array) and returns a batch of scalars (a vector), without explicitly writing loops.
Each argument can be mapped differently:
Mapped arguments (in_axes is an int): Batched processing along the specified axis
Broadcasted arguments (in_axes is None): Same value used for all batch elements
All mapped arguments must have the same size along their mapped axes.
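The conceptual model above can be written out as an explicit loop. This is only a sketch for intuition, not the library's implementation: `loop_vmap` is a hypothetical helper, it handles plain arrays only (no PyTrees), it always places the batch axis first in the output, and the error it raises on mismatched sizes is an assumption of this sketch.

```python
import numpy as np


def loop_vmap(func, in_axes=0):
    """Hypothetical loop-based stand-in for a vectorizing transform."""

    def wrapped(*args):
        # An int or None applies to every argument; a tuple names one axis per argument.
        axes = in_axes if isinstance(in_axes, tuple) else (in_axes,) * len(args)
        if len(axes) != len(args):
            raise ValueError("in_axes needs one entry per argument")

        # Collect the batch size seen along each mapped axis.
        sizes = {np.shape(a)[ax] for a, ax in zip(args, axes) if ax is not None}
        if len(sizes) != 1:
            raise ValueError(
                f"expected exactly one batch size among mapped arguments, got {sorted(sizes)}"
            )
        (batch_size,) = sizes

        results = []
        for i in range(batch_size):
            # Slice mapped arguments; pass broadcast arguments through unchanged.
            call_args = [
                a if ax is None else np.take(a, i, axis=ax)
                for a, ax in zip(args, axes)
            ]
            results.append(func(*call_args))
        return np.stack(results)  # batch axis first

    return wrapped


def weighted_sum(x, w):
    return np.dot(x, w)


xs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # batch of 3 vectors
w = np.array([10.0, 1.0])                            # shared weights

batched = loop_vmap(weighted_sum, in_axes=(0, None))
print(batched(xs, w))  # one scalar per row of xs
```

Here `xs` plays the role of a mapped argument and `w` that of a broadcasted one. A real vectorizing transform avoids the Python loop, but the input and output shapes follow the same pattern.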
Examples
Basic vectorization of a dot product:
>>> import numpy as np
>>> import archimedes as arc
>>>
>>> def dot(a, b):
...     return np.dot(a, b)
>>>
>>> # Vectorize to compute multiple dot products at once
>>> batched_dot = arc.vmap(dot)
>>>
>>> # Input: batch of vectors (3 vectors of length 2)
>>> x = np.array([[1, 2], [3, 4], [5, 6]])
>>> y = np.array([[7, 8], [9, 10], [11, 12]])
>>>
>>> # Output: batch of scalars (3 dot products)
>>> batched_dot(x, y)
array([ 23,  67, 127])
Working with structured data (PyTrees):
>>> from archimedes import struct
>>>
>>> @struct.pytree_node
... class Particle:
...     x: np.ndarray
...     v: np.ndarray
>>>
>>> def update(p, dt):
...     return p.replace(x=p.x + dt * p.v)
>>>
>>> # Vectorize over particles; dt is broadcast to every particle
>>> map_update = arc.vmap(update, in_axes=(0, None))
>>>
>>> # Batch of 10 particles
>>> x = np.random.randn(10, 3)  # 10 particles in 3D space
>>> v = np.random.randn(10, 3)
>>> particles = Particle(x=x, v=v)
>>>
>>> # Update all 10 particles at once with a shared time step
>>> dt = 0.1
>>> new_particles = map_update(particles, dt)
See also
scan
Transform that applies a function sequentially to array elements