7. Generating Python

When iterating on a function design, it is sometimes useful to evaluate the function in a Python script or notebook. To that end, wrenfold can generate Python functions that call the NumPy, PyTorch, or JAX APIs.

To directly generate a Python callable, we can use generate_python():

from wrenfold import code_generation

# `step_clamped` is the symbolic function defined previously.
# `step_np_func` will be a callable that accepts scalar `x`.
# `step_code` is a string of Python code.
step_np_func, step_code = code_generation.generate_python(func=step_clamped)

y, df = step_np_func(x=0.32, compute_df=True)

print(y)  # prints: 0.241664
print(df) # prints: array([[1.3055999], [2.16]], dtype=float64)

The listing above produces a callable step_np_func that operates on NumPy types. Note that optional output arguments become optional return values when targeting Python: we pass compute_df=True to request that output df be computed, and it is then returned in the output tuple.
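The optional-output convention can be illustrated with a hand-written stand-in. This is a minimal sketch, not wrenfold output: the function name `smoothstep_clamped`, its clamped-smoothstep body, and the choice to return `None` for an output that was not requested are all assumptions made for illustration.

```python
import typing as T

import numpy as np


def smoothstep_clamped(x: float, compute_df: bool) -> T.Tuple[float, T.Optional[np.ndarray]]:
    """Hypothetical hand-written function mimicking an optional-output calling convention."""
    # Clamp the input to [0, 1] before evaluating the polynomial 3c^2 - 2c^3.
    c = min(max(x, 0.0), 1.0)
    y = c * c * (3.0 - 2.0 * c)
    df = None  # Assumed convention: an output that was not requested comes back as None.
    if compute_df:
        if 0.0 < x < 1.0:
            # Column vector of the first and second derivatives with respect to x.
            df = np.array([[6.0 * c * (1.0 - c)], [6.0 - 12.0 * c]], dtype=np.float64)
        else:
            # The clamped function is constant outside [0, 1], so both derivatives vanish.
            df = np.zeros((2, 1), dtype=np.float64)
    return y, df


y, df = smoothstep_clamped(x=0.32, compute_df=True)
y_only, df_skipped = smoothstep_clamped(x=0.32, compute_df=False)
```

The caller always unpacks a tuple of the same length; only the requested entries carry data.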

When targeting JAX or PyTorch, it is advantageous to leave conditionals such as wrenfold.sym.where() in ternary form, so that they are emitted as jax.numpy.where or torch.where calls instead of Python conditional logic. For example:

# JAX defaults to 32-bit floats (float64 must be enabled explicitly via `jax_enable_x64`),
# so we make sure we generate code that uses float32 types.
generator = code_generation.PythonGenerator(
    target=code_generation.PythonGeneratorTarget.JAX,
    float_width=code_generation.PythonGeneratorFloatWidth.Float32
)

step_jax_func, step_code = code_generation.generate_python(
    func=step_clamped, generator=generator, convert_ternaries=False)

print(step_code)
# The generated JAX function:
def step_clamped(x: jnp.ndarray, compute_df: bool) -> T.Tuple[jnp.ndarray, jnp.ndarray]:
    v002 = x
    v006 = jnp.where(
        (v002 < jnp.asarray(0, dtype=jnp.float32)),
        jnp.asarray(0, dtype=jnp.float32),
        v002,
    )
    # ... output truncated ...
    return (
        v009
        * v009
        * (
            jnp.asarray(3, dtype=jnp.float32) + jnp.asarray(2, dtype=jnp.float32) * v043
        ),
        df,
    )

By leaving conditionals in ternary (or “element selection”) form, we retain the ability to use them during back-propagation. This behavior can be disabled by specifying convert_ternaries=True, in which case if-statements are emitted instead.
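Why element-selection form matters can be seen with plain NumPy. The sketch below is illustrative and does not use wrenfold; `clamp_with_if` is a hypothetical helper. np.where picks a branch independently for each element, whereas a Python if-statement needs a single concrete boolean, which an array (or a value traced by JAX) does not provide.

```python
import numpy as np

x = np.array([-0.5, 0.25, 0.75])

# Element-selection form: each entry chooses between the two branches independently.
clamped = np.where(x < 0.0, 0.0, x)


def clamp_with_if(v):
    # If-statement form: works for a single scalar, but needs `v < 0.0` to be one bool.
    if v < 0.0:
        return 0.0
    return v


try:
    clamp_with_if(x)
    if_failed = False
except ValueError:
    # NumPy refuses to reduce a multi-element boolean array to a single truth value.
    if_failed = True
```

The same element-wise semantics are what jax.numpy.where and torch.where provide on their respective array types.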

Tip

For longer-form examples of Python generation, see jax_camera_model and cart-pole.

7.1. JIT compilation with Numba

wrenfold-generated NumPy code is intended to be compatible with Numba JIT compilation. To JIT-compile a generated function, use the PythonGeneratorTarget.Numpy target and then pass the resulting function to numba.njit:

import numba
from wrenfold import code_generation

# NumPy is the default target when no `generator` is passed, but here (as an example) we
# customize the generation to use float32 for all types:
generator = code_generation.PythonGenerator(
    float_width=code_generation.PythonGeneratorFloatWidth.Float32
)

step_func, _ = code_generation.generate_python(func=step_clamped, generator=generator)

jit_compiled_func = numba.njit(step_func)

# For the float64 version (the default float width), we can simply do:
step_func, _ = code_generation.generate_python(func=step_clamped)
jit_compiled_func = numba.njit(step_func)