archimedes.callback
archimedes.callback(func: Callable, result_shape_dtypes, *args) → Any
Execute an arbitrary Python function within a symbolic computational graph.
This function allows arbitrary Python functions to be incorporated into computational graphs. This makes it possible to use functions that cannot be traced symbolically within functions created with compile().
Parameters:
func (callable) – The Python function to wrap. This function should accept the same number of arguments as are provided in *args and should return values that can be converted to NumPy arrays.
result_shape_dtypes (PyTree) – A PyTree structure that defines the expected shapes and data types of the function's output. This is used to determine the output shape of the callback wrapper without calling the function itself.
*args (Any) – Arguments to pass to func. These are used to determine the input and output shapes for the callback wrapper.
Returns:
The result of calling func(*args), structured as a PyTree if applicable.
Return type:
Any
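To make the role of result_shape_dtypes more concrete, the sketch below checks, in plain NumPy, whether a function's numerical output has the nesting structure, shapes, and dtypes declared by a template PyTree. The helper matches_template is hypothetical and is not part of the Archimedes API; it only handles dicts, lists, tuples, and array-like leaves, and its exact-dtype comparison is an assumption (the library itself may instead convert the returned values to arrays).

```python
import numpy as np


# Hypothetical helper (not part of Archimedes): check that `output` matches
# the structure, shapes, and dtypes of a template PyTree. Only dicts,
# lists/tuples, and array-like leaves are handled in this sketch.
def matches_template(output, template):
    if isinstance(template, dict):
        return (
            isinstance(output, dict)
            and set(output) == set(template)
            and all(matches_template(output[k], template[k]) for k in template)
        )
    if isinstance(template, (list, tuple)):
        return (
            isinstance(output, (list, tuple))
            and len(output) == len(template)
            and all(matches_template(o, t) for o, t in zip(output, template))
        )
    # Leaf: compare shape and dtype after conversion to NumPy arrays
    out, tmpl = np.asarray(output), np.asarray(template)
    return out.shape == tmpl.shape and out.dtype == tmpl.dtype


# Usage: a template describing a length-2 float vector and a float scalar
x = np.array([0.5, 1.5])
template = {"value": x, "norm": np.float64(0.0)}
result = {"value": np.tanh(x), "norm": np.linalg.norm(x)}
print(matches_template(result, template))
```

A template like this only carries shape and dtype information, which is why the wrapper can know the output shape before func is ever called.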
Notes
When to use this function:
When you need to incorporate external functions that cannot be evaluated symbolically into Archimedes computational graphs
When interfacing with legacy code or external libraries that need to be called during symbolic execution
When implementing custom numerical algorithms that don't map cleanly to Archimedes' symbolic operations
For testing and debugging, to inspect numerical values at some point in an otherwise symbolically compiled function
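For the debugging use case, one option is a callback that prints a summary of the values it receives and returns its input unchanged. The function below, debug_tap, is a hypothetical name used only for illustration; it is plain NumPy and does not depend on Archimedes.

```python
import numpy as np


# Hypothetical debugging callback: print a summary of the numerical values
# it receives and return them unchanged, so the output has the same shape
# and dtype as the input.
def debug_tap(x):
    x = np.asarray(x)
    print(f"debug_tap: shape={x.shape}, min={x.min():.4g}, max={x.max():.4g}")
    return x


y = debug_tap(np.array([0.5, 1.5]))
```

Because the output matches the input, a compiled function could splice it in as `x = arc.callback(debug_tap, x, x)`, passing `x` itself as result_shape_dtypes in the same way as the Examples section. Since the number of evaluations is not guaranteed, the printed lines are best treated as a diagnostic aid rather than a reliable log.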
The callback is executed numerically in interpreted Python at each evaluation, which means:
It won’t benefit from symbolic optimization
It cannot be differentiated through automatically
It may be slower than native symbolic operations
Note that while it is possible to use this function to circumvent the requirement that Archimedes code be functionally pure, doing so is strongly discouraged, primarily because the number of times the callback is evaluated is not guaranteed, so side effects may be unpredictable.
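The sketch below illustrates why side effects are unreliable. The function evaluate is a hypothetical stand-in for a runtime that may invoke a callback any number of times; it is not how Archimedes schedules evaluations. The returned value of a pure function is unaffected by the evaluation count, but the external state mutated by an impure one is not.

```python
import numpy as np

calls = []  # external state mutated by the impure callback


def impure_scale(x):
    calls.append(1)  # side effect: grows once per evaluation
    return 2.0 * np.asarray(x)


def pure_scale(x):
    return 2.0 * np.asarray(x)  # depends only on its input


# Hypothetical stand-in for a runtime that evaluates `func` n_evals times;
# only the final result is returned to the caller.
def evaluate(func, x, n_evals):
    result = None
    for _ in range(n_evals):
        result = func(x)
    return result


x = np.array([0.5, 1.5])
evaluate(impure_scale, x, 1)
evaluate(impure_scale, x, 3)
print(len(calls))  # reflects the hidden evaluation count, not the call sites
```

Both functions return the same array regardless of how often they run, but `calls` records every evaluation, so any logic that relies on it would depend on evaluation details the caller does not control.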
Examples
>>> import numpy as np
>>> import archimedes as arc
>>>
>>> # Define an external function
>>> def custom_nonlinearity(x):
...     print("Evaluating custom_nonlinearity")
...     return np.tanh(x) * np.exp(-0.1 * x**2)
>>>
>>> # Use in a compiled function
>>> @arc.compile
... def model(x):
...     result_shape_dtypes = x  # Output has same type as input
...     y = arc.callback(custom_nonlinearity, result_shape_dtypes, x)
...     return y * 2
>>>
>>> model(np.array([0.5, 1.5]))
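Since the callback simply evaluates custom_nonlinearity numerically, the compiled model in the example should agree with a direct NumPy evaluation of y = 2 tanh(x) exp(-0.1 x²). The reference below (model_reference is a hypothetical name) can serve as a sanity check; comparing against it assumes that the output of model can be converted with np.asarray.

```python
import numpy as np


# Plain NumPy reference for the compiled `model` in the example:
# y = 2 * tanh(x) * exp(-0.1 * x**2), evaluated without Archimedes.
def model_reference(x):
    x = np.asarray(x, dtype=float)
    return 2.0 * np.tanh(x) * np.exp(-0.1 * x**2)


expected = model_reference([0.5, 1.5])
```

A check could then look like `np.allclose(np.asarray(model(x)), model_reference(x))` for a test input `x`.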
See also
compile
Function for symbolically compiling Python functions
integrator
Specialized solver transformation for ODEs
implicit
Specialized solver transformation for implicit functions