archimedes.vmap

archimedes.vmap(func: Callable, in_axes: int | None | tuple[int | None, ...] = 0, out_axes: int = 0, name: str | None = None) → Callable

Vectorize a function along specified argument axes.

The vmap transformation takes a function that operates on individual elements and transforms it into one that operates on batches of elements in a vectorized manner. This enables efficient computation without writing explicit loops or broadcasting logic.

Parameters:
  • func (callable) – Function to be vectorized. The function can accept ordinary NumPy arrays or PyTree-structured data.

  • in_axes (int, None, or tuple of ints/None, optional) –

    Specifies which axis of each input argument should be mapped over.

    • int: Use the same axis for all arguments (e.g., 0 for the first dimension)

    • None: Don’t map this argument (broadcast it to all mapped elements)

    • tuple: Specify a different axis for each argument

    Default is 0 (map over the first axis of each argument).

  • out_axes (int, optional) – Specifies where the mapped axis should appear in the output. Default is 0 (mapped axis is the first dimension of the output).

  • name (str, optional) – Name for the transformed function. If None, derives a name from the original function.

Returns:

vectorized_func – A function with the same signature as func that operates on batches of inputs.

Return type:

callable

Notes

When to use this function:

  • When you need to apply the same operation to many inputs efficiently

  • To convert a single-example function into one that handles batches

  • To selectively vectorize over some arguments while broadcasting others

  • To “unflatten” tree-structured data by mapping the unravel function
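The last use case can be illustrated without archimedes. Given a function that rebuilds one structured state from a flat vector, mapping it over the rows of a 2-D array rebuilds a whole batch. In this sketch, `unravel` is a hypothetical function written by hand for a fixed layout (position followed by velocity), and the explicit loop stands in for the mapped call.

```python
import numpy as np

def unravel(flat):
    """Hypothetical unravel: split a flat length-6 vector into named fields."""
    return {"x": flat[:3], "v": flat[3:]}

flat_batch = np.arange(24.0).reshape(4, 6)  # 4 flattened states

# Apply unravel row by row, then stack each field along a new leading axis
states = [unravel(row) for row in flat_batch]
batched = {key: np.stack([s[key] for s in states]) for key in ("x", "v")}

print(batched["x"].shape, batched["v"].shape)  # (4, 3) (4, 3)
```

Each field of the result carries the batch as its leading axis, which is the layout a mapped unravel function would be expected to produce with the default out_axes of 0.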

Conceptual model:

vmap transforms functions to operate along array axes. For example, a function f(x) that takes a vector and returns a scalar can be transformed into one that takes a batch of vectors (an array) and returns a batch of scalars (a vector), without explicitly writing loops.

Each argument can be mapped differently:

  • Mapped arguments (in_axes is an int): Batched processing along the specified axis

  • Broadcasted arguments (in_axes is None): Same value used for all batch elements

The vectorized function requires all mapped arguments to have the same size along their mapped axes.
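The mapping rules above can be sketched as an explicit loop in plain NumPy. This is only a reference for the semantics, not how archimedes implements vmap. `loop_vmap` and `affine` are hypothetical helpers introduced here, and the sketch handles array arguments only (no PyTrees).

```python
import numpy as np

def loop_vmap(func, in_axes=0, out_axes=0):
    """Loop-based reference for vmap semantics (hypothetical helper, arrays only)."""
    def mapped(*args):
        # A single int (or None) applies to every argument
        axes = in_axes if isinstance(in_axes, tuple) else (in_axes,) * len(args)
        if len(axes) != len(args):
            raise ValueError("in_axes must have one entry per argument")
        # All mapped arguments must agree on the batch size
        sizes = {np.shape(a)[ax] for a, ax in zip(args, axes) if ax is not None}
        if len(sizes) != 1:
            raise ValueError(f"inconsistent or missing mapped sizes: {sizes}")
        (n,) = sizes
        results = []
        for i in range(n):
            # Slice mapped arguments; pass broadcast arguments unchanged
            call_args = [
                a if ax is None else np.take(a, i, axis=ax)
                for a, ax in zip(args, axes)
            ]
            results.append(func(*call_args))
        # Place the batch dimension at out_axes
        return np.stack(results, axis=out_axes)
    return mapped

def affine(x, w):
    return np.dot(w, x) + 1.0

xs = np.arange(6.0).reshape(3, 2)  # batch of 3 vectors of length 2
w = np.array([1.0, 2.0])           # shared weights, reused for every call

out = loop_vmap(affine, in_axes=(0, None))(xs, w)
print(out)  # [ 3.  9. 15.]
```

Here `xs` is sliced along axis 0 while `w` is passed whole to each call, which mirrors the mapped versus broadcast distinction. A real vectorizing transform avoids the Python loop; the loop only pins down what the result should be.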

Examples

Basic vectorization of a dot product:

>>> import numpy as np
>>> import archimedes as arc
>>>
>>> def dot(a, b):
...     return np.dot(a, b)
>>>
>>> # Vectorize to compute multiple dot products at once
>>> batched_dot = arc.vmap(dot)
>>>
>>> # Input: batch of vectors (3 vectors of length 2)
>>> x = np.array([[1, 2], [3, 4], [5, 6]])
>>> y = np.array([[7, 8], [9, 10], [11, 12]])
>>>
>>> # Output: batch of scalars (3 dot products)
>>> batched_dot(x, y)
array([ 23,  67, 127])
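As a sanity check, the batched result should match an explicit loop over rows. The sketch below uses plain NumPy only (it does not call archimedes) with the same `x` and `y`; `np.einsum` is shown as one hand-written vectorized equivalent.

```python
import numpy as np

x = np.array([[1, 2], [3, 4], [5, 6]])
y = np.array([[7, 8], [9, 10], [11, 12]])

# Explicit loop: one dot product per pair of rows
looped = np.array([np.dot(a, b) for a, b in zip(x, y)])

# Hand-vectorized equivalent: sum over the feature index j for each row i
vectorized = np.einsum("ij,ij->i", x, y)

print(looped)      # [ 23  67 127]
print(vectorized)  # [ 23  67 127]
```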

Working with structured data (PyTrees):

>>> from archimedes import struct
>>>
>>> @struct.pytree_node
... class Particle:
...     x: np.ndarray
...     v: np.ndarray
>>>
>>> def update(p, dt):
...     return p.replace(x=p.x + dt * p.v)
>>>
>>> # Vectorize to update multiple particles at once
>>> map_update = arc.vmap(update, in_axes=(0, None))
>>>
>>> # Batch of 10 particles
>>> x = np.random.randn(10, 3)  # 10 particles in 3D space
>>> v = np.random.randn(10, 3)
>>> particles = Particle(x=x, v=v)
>>>
>>> # Update all 10 particles at once
>>> new_particles = map_update(particles, 0.1)  # dt=0.1 is broadcast to every particle
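For intuition, here is what the mapped update computes, written with a standard-library dataclass and an explicit loop instead of archimedes. `PlainParticle` and `update_plain` are hypothetical stand-ins for the `Particle` PyTree and `update` above, and the value of `dt` is arbitrary. The point is that every array field is sliced along axis 0 while `dt` is shared.

```python
from dataclasses import dataclass, replace
import numpy as np

@dataclass
class PlainParticle:  # hypothetical stand-in for the Particle PyTree
    x: np.ndarray
    v: np.ndarray

def update_plain(p, dt):
    return replace(p, x=p.x + dt * p.v)

rng = np.random.default_rng(0)
batch = PlainParticle(x=rng.standard_normal((10, 3)), v=rng.standard_normal((10, 3)))
dt = 0.1

# Slice axis 0 of every field, reuse dt for each particle
singles = [
    update_plain(PlainParticle(x=batch.x[i], v=batch.v[i]), dt)
    for i in range(10)
]

# Restack the per-particle results into one batched structure
stacked = PlainParticle(
    x=np.stack([p.x for p in singles]),
    v=np.stack([p.v for p in singles]),
)

print(stacked.x.shape)  # (10, 3)
```

The result has the same structure as the input, with each field keeping its leading batch axis of size 10.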

See also

scan

Transform that applies a function sequentially to array elements