archimedes.switch
- archimedes.switch(index: int, branches: tuple[Callable, ...], *args: PyTree, name: str | None = None, kind: str = 'MX') → PyTree
Selectively apply one of several functions based on an index.
This function provides a conditional branching mechanism that selects and applies one of the provided branch functions based on the index value. The function is similar to a switch/case statement but can be embedded within computational graphs.
Semantically, this function is equivalent to the following Python code:
    def switch(index, branches, *args):
        index = min(max(index, 0), len(branches) - 1)
        return branches[index](*args)
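The clamping rule is easiest to see by running that reference semantics directly in plain Python, with no symbolic values involved. The sketch below is only the pure-Python model shown above, renamed to the hypothetical `reference_switch`; it is not the Archimedes implementation. Note that a negative index is clamped to 0 rather than wrapping around the way Python sequence indexing does.

```python
from typing import Any, Callable, Sequence


def reference_switch(index: int, branches: Sequence[Callable[..., Any]], *args: Any) -> Any:
    """Pure-Python model of the documented switch semantics (hypothetical helper)."""
    # Clamp the index into [0, len(branches) - 1]; negative values do not wrap around.
    index = min(max(index, 0), len(branches) - 1)
    # Only the selected branch is called with the arguments.
    return branches[index](*args)


branches = (
    lambda x: x**2,  # branch 0
    lambda x: x + 1,  # branch 1
    lambda x: -x,  # branch 2
)

print(reference_switch(1, branches, 3.0))   # in range, branch 1: 4.0
print(reference_switch(-5, branches, 3.0))  # clamped to branch 0: 9.0
print(reference_switch(10, branches, 3.0))  # clamped to branch 2: -3.0
```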
- Parameters:
index (int) – The branch selector. Must be an integer. If the index is out of bounds, it is clamped to the valid range [0, len(branches) - 1].
branches (tuple of callables) – A tuple of functions to choose from. Each function must accept the same arguments and return the same PyTree structure.
*args (PyTree) – Arguments to pass to the selected branch function. All branches must accept these arguments.
name (str, optional) – Name for the resulting function. Used for debugging and visualization. Default is “switch_{index}”.
kind (str, optional) – The kind of symbolics to use when constructing the function. Default is “MX”.
- Returns:
The result of applying the selected branch function to the provided arguments. All branches must return the same structure.
- Return type:
PyTree
Notes
This function converts conditional branching into a computational graph construct. Unlike Python’s if/else, which doesn’t work with symbolic values (a symbolic index has no concrete value to branch on while the graph is being built), switch is compatible with both symbolic and numeric execution. The function evaluates every branch at compilation time to ensure that they have compatible output structures. At runtime, only the selected branch is executed.
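The split between compile time and run time can be illustrated with a toy model in plain Python. In this sketch, "tracing" is approximated by calling every branch once on example arguments and comparing the nesting of the outputs (dicts, lists, tuples and dict keys); the returned function then calls only the selected branch. The helpers `tree_structure` and `toy_compile_switch` are hypothetical and only stand in for machinery that, in Archimedes, works on symbolic values rather than concrete ones.

```python
from typing import Any, Callable, Sequence


def tree_structure(tree: Any) -> Any:
    """Hashable description of a nested dict/list/tuple container (leaves become '*')."""
    if isinstance(tree, dict):
        # Keys are sorted, so this toy ignores dict insertion order.
        return ("dict", tuple((k, tree_structure(tree[k])) for k in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(tree_structure(v) for v in tree))
    return "*"  # any non-container value is treated as a leaf


def toy_compile_switch(branches: Sequence[Callable[..., Any]], *example_args: Any) -> Callable[..., Any]:
    """Trace every branch once, check output structures, then return a runtime selector."""
    if len(branches) < 2:
        raise ValueError("switch requires at least two branches")
    # "Compile time": every branch runs once so that its output structure can be inspected.
    structures = [tree_structure(branch(*example_args)) for branch in branches]
    if any(s != structures[0] for s in structures[1:]):
        raise ValueError("all branches must return the same PyTree structure")

    def run(index: int, *args: Any) -> Any:
        # "Run time": clamp the index and call only the selected branch.
        index = min(max(index, 0), len(branches) - 1)
        return branches[index](*args)

    return run


select = toy_compile_switch(
    (lambda d: {"a": d["a"] * 2}, lambda d: {"a": d["a"] + 1}),
    {"a": 1.0},  # example arguments used only for tracing
)
print(select(0, {"a": 5.0}))  # {'a': 10.0}
print(select(1, {"a": 5.0}))  # {'a': 6.0}
```

Pairing a dict-returning branch with a tuple-returning one makes `toy_compile_switch` raise `ValueError` before any runtime call, which mirrors the documented behavior for mismatched branch outputs.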
Behavior notes:
- If index is out of bounds, it is clamped to the valid range.
- All branches must return the same PyTree structure, or a ValueError is raised.
- At least two branches must be provided, or a ValueError is raised.
- Functions are traced at compilation time, meaning any side effects occur for all branches during tracing, even though only one branch executes at runtime. It is strongly recommended to avoid side effects.
- This function supports automatic differentiation.
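The side-effect caveat can be reproduced with a toy model in plain Python: if "tracing" means calling every branch once, a side effect placed in any branch fires during tracing even when that branch is never selected afterwards. The `noisy_branch` and `trace_all` helpers and the `calls` log below are hypothetical illustrations, not Archimedes internals.

```python
from typing import Any, Callable, List, Sequence

calls: List[str] = []  # records every time a branch body runs


def noisy_branch(tag: str, fn: Callable[[float], float]) -> Callable[[float], float]:
    """Wrap fn so that each call appends tag to the log (a deliberate side effect)."""
    def wrapped(x: float) -> float:
        calls.append(tag)
        return fn(x)
    return wrapped


def trace_all(branches: Sequence[Callable[..., Any]], *example_args: Any) -> None:
    """Toy stand-in for compile-time tracing: every branch is evaluated once."""
    for branch in branches:
        branch(*example_args)


branches = (noisy_branch("square", lambda x: x * x), noisy_branch("negate", lambda x: -x))

trace_all(branches, 1.0)
print(calls)  # ['square', 'negate']: both side effects fired while tracing

calls.clear()
print(branches[0](3.0))  # 9.0: at run time only the selected branch (index 0) executes
print(calls)  # ['square']
```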
Examples
>>> import numpy as np
>>> import archimedes as arc
>>>
>>> # Define functions for each branch
>>> def branch0(x):
...     return x**2
...
>>> def branch1(x):
...     return np.sin(x)
...
>>> def branch2(x):
...     return -x
...
>>> # Create a switch function
>>> @arc.compile
... def apply_operation(x, op_index):
...     return arc.switch(op_index, (branch0, branch1, branch2), x)
...
>>> # Apply different branches based on the index
>>> x = np.array([0.5, 1.0, 1.5])
>>> apply_operation(x, 0)  # Returns x**2
>>> apply_operation(x, 1)  # Returns sin(x)
>>> apply_operation(x, 2)  # Returns -x
>>> # Example with a PyTree
>>> def multiply(data, factor):
...     return {k: v * factor for k, v in data.items()}
...
>>> def add_offset(data, offset):
...     return {k: v + offset for k, v in data.items()}
...
>>> @arc.compile
... def process_data(data, op_index, param):
...     return arc.switch(op_index, (multiply, add_offset), data, param)
...
>>> data = {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}
>>> process_data(data, 0, 2.0)  # Multiplies all values by 2.0
>>> process_data(data, 1, 1.0)  # Adds 1.0 to all values
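To see what the PyTree example computes without building a symbolic graph, the same two branches can be run through the pure-Python reference semantics given near the top of this page (re-declared here under the hypothetical name `reference_switch`). Scalar leaves stand in for the NumPy arrays so the sketch needs only the standard library. It shows that every element of `*args` (here `data` and `param`) is forwarded to whichever branch is selected, and that both branches return a dict with the same keys.

```python
from typing import Any, Callable, Dict, Sequence


def reference_switch(index: int, branches: Sequence[Callable[..., Any]], *args: Any) -> Any:
    # Documented semantics: clamp the index, then call only the selected branch.
    index = min(max(index, 0), len(branches) - 1)
    return branches[index](*args)


def multiply(data: Dict[str, float], factor: float) -> Dict[str, float]:
    return {k: v * factor for k, v in data.items()}


def add_offset(data: Dict[str, float], offset: float) -> Dict[str, float]:
    return {k: v + offset for k, v in data.items()}


data = {"a": 1.0, "b": 3.0}  # scalar leaves in place of the NumPy arrays
print(reference_switch(0, (multiply, add_offset), data, 2.0))  # {'a': 2.0, 'b': 6.0}
print(reference_switch(1, (multiply, add_offset), data, 1.0))  # {'a': 2.0, 'b': 4.0}
```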
See also
np.where
Element-wise conditional selection between two arrays
scan
Functional for-loop construct for repeated operations