7. Type annotations

wrenfold determines the runtime type/shape of input arguments using type annotations. Consider the function:

import wrenfold as wf
from wrenfold import sym

def rotate_point(angle: wf.FloatScalar, p: wf.Vector2):
    """Rotate a point in 2D counter-clockwise by an angle in radians."""
    R = sym.matrix([
        [sym.cos(angle), -sym.sin(angle)],
        [sym.sin(angle), sym.cos(angle)]
    ])
    p_rotated = R * p
    return [
        wf.OutputArg(p_rotated, name="p_rotated"),
        wf.OutputArg(p_rotated.jacobian([angle]), name="D_angle", is_optional=True)
    ]
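
To make the two outputs concrete, the sketch below evaluates the same expressions with plain floats. It uses only the standard library and does not call wrenfold. The name rotate_point_numeric is hypothetical, and the derivative was worked out by hand from R * p, so treat it as an illustration rather than as what wrenfold generates.

```python
import math


def rotate_point_numeric(angle, p):
    """Plain-float analogue of rotate_point (hypothetical helper, not wrenfold)."""
    c, s = math.cos(angle), math.sin(angle)
    x, y = p
    # p_rotated = R * p, with R = [[c, -s], [s, c]].
    p_rotated = (c * x - s * y, s * x + c * y)
    # D_angle = d(p_rotated) / d(angle): one row per output element, one column.
    d_angle = (-s * x - c * y, c * x - s * y)
    return p_rotated, d_angle


# Rotating (1, 0) by 90 degrees gives approximately (0, 1).
p_rot, d_ang = rotate_point_numeric(math.pi / 2, (1.0, 0.0))
print(p_rot, d_ang)
```

Both results have two elements, which corresponds to a 2x1 shape for p_rotated and for D_angle.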

The annotation wf.FloatScalar indicates that the runtime type for angle is a floating-point number. The annotation wf.Vector2 indicates that the runtime type for p is a 2x1 column vector. In C++, the generated code looks like:

template <typename Scalar, typename T1, typename T2, typename T3>
void rotate_point(const Scalar angle, const T1& p, T2&& p_rotated, T3&& D_angle)
{
  auto _p = wf::make_input_span<2, 1>(p);
  auto _p_rotated = wf::make_output_span<2, 1>(p_rotated);
  auto _D_angle = wf::make_optional_output_span<2, 1>(D_angle);

  // ...
}

Annotating p with wf.Vector2 produces the call wf::make_input_span<2, 1>(p) in the generated code. As a result, rotate_point accepts any type T1 of the appropriate size that implements wf::convert_to_span.

In fact, wf.Vector2 is a type alias built with typing.Annotated:

# Snippet from type_annotations.py
Vector2 = Annotated[sym.MatrixExpr, Shape(rows=2, cols=1), "A 2x1 column vector."]

You can define custom matrix sizes in the same manner:

import typing
import wrenfold as wf
import wrenfold.sym as sym

Matrix3x9 = typing.Annotated[sym.MatrixExpr, wf.Shape(rows=3, cols=9)]

def my_func(some_input: Matrix3x9):
    ...  # Function body omitted.
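
The following sketch shows, using only the standard library, how shape metadata attached with typing.Annotated can be read back from a function signature. It is not wrenfold's implementation: MatrixExpr and Shape here are stand-in classes defined for the example, and argument_shapes is a hypothetical helper.

```python
import typing
from dataclasses import dataclass


class MatrixExpr:
    """Stand-in for sym.MatrixExpr (assumption, not the wrenfold class)."""


@dataclass(frozen=True)
class Shape:
    """Stand-in for the Shape metadata object (assumption)."""
    rows: int
    cols: int


Vector2 = typing.Annotated[MatrixExpr, Shape(rows=2, cols=1), "A 2x1 column vector."]
Matrix3x9 = typing.Annotated[MatrixExpr, Shape(rows=3, cols=9)]


def argument_shapes(func):
    """Map each argument that carries Shape metadata to (rows, cols)."""
    # include_extras=True keeps the Annotated metadata instead of stripping it.
    hints = typing.get_type_hints(func, include_extras=True)
    shapes = {}
    for name, hint in hints.items():
        if name == "return":
            continue
        # Annotated aliases expose their extra arguments as __metadata__.
        for meta in getattr(hint, "__metadata__", ()):
            if isinstance(meta, Shape):
                shapes[name] = (meta.rows, meta.cols)
    return shapes


def my_func(p: Vector2, some_input: Matrix3x9, scale: float):
    return None


print(argument_shapes(my_func))  # {'p': (2, 1), 'some_input': (3, 9)}
```

Arguments without Shape metadata, such as scale, are skipped. The docstring in Vector2 is ignored because only Shape instances are matched.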