Represent code expressions as data structures, then transform them.

Install

pip install vexpr

Vexpr is currently in technical preview and may raise “Not Implemented” errors if you try something it does not yet support.

Get started

Example: a custom distance metric between two lists of vectors, x1 and x2.

The example below uses the NumPy backend; the PyTorch and JAX backends work analogously.

1. Create a Vexpr

import vexpr as vp
import vexpr.numpy as vnp
import vexpr.scipy.spatial.distance as vsd

w1 = vp.symbol("w1")
w2 = vp.symbol("w2")
x1 = vp.symbol("x1")
x2 = vp.symbol("x2")

expr = vnp.sum([w1 * vsd.cdist(x1[..., [0, 1, 2]], x2[..., [0, 1, 2]]),
                w2 * vsd.cdist(x1[..., [0, 3, 4]], x2[..., [0, 3, 4]])],
               axis=0)
print(expr)
# Output: a Vexpr data structure:
#
# numpy.sum(
#   [operator.mul(
#     symbol('w1'),
#     scipy.spatial.distance.cdist(
#       operator.getitem(
#         symbol('x1'),
#         (Ellipsis, [0, 1, 2]),
#       ),
#       operator.getitem(
#         symbol('x2'),
#         (Ellipsis, [0, 1, 2]),
#       ),
#     ),
#   ),
#   operator.mul(
#     symbol('w2'),
#     scipy.spatial.distance.cdist(
#       operator.getitem(
#         symbol('x1'),
#         (Ellipsis, [0, 3, 4]),
#       ),
#       operator.getitem(
#         symbol('x2'),
#         (Ellipsis, [0, 3, 4]),
#       ),
#     ),
#   )]
#   axis=0
# )
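
A Vexpr is inert data until you evaluate it. You can already run this expression as-is with vp.eval (used again in step 3); a quick check with small random inputs:

import numpy as np

# Symbols are bound by name at evaluation time.
out = vp.eval(expr, dict(x1=np.random.randn(10, 5),
                         x2=np.random.randn(10, 5),
                         w1=np.array(0.7),
                         w2=np.array(0.3)))
print(out.shape)  # (10, 10): one weighted distance per pair of rows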

2. Transform it into a faster Vexpr that would have been difficult to write directly

import numpy as np

example_inputs = dict(
    x1=np.random.randn(10, 5),
    x2=np.random.randn(10, 5),
    w1=np.array(0.7),
    w2=np.array(0.3),
)

expr = vp.vectorize(expr, example_inputs)
print(expr)
# numpy.sum(
#   operator.mul(
#     numpy.reshape(
#       numpy.stack([symbol('w1'), symbol('w2')]),
#       (2, 1, 1),
#     ),
#     custom.scipy.cdist_multi(
#       operator.getitem(
#         symbol('x1'),
#         (Ellipsis, array([0, 1, 2, 0, 3, 4])),
#       ),
#       operator.getitem(
#         symbol('x2'),
#         (Ellipsis, array([0, 1, 2, 0, 3, 4])),
#       ),
#       lengths=array([3, 3])
#     ),
#   )
#   axis=0
# )
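
vectorize traced the expression with the example inputs and merged the parallel work: both index selections become one fancy-index, both cdist calls become one batched cdist_multi, and the weights are stacked into a single array. A rough timing sketch, reusing the symbols and example_inputs from above (a sanity check, not a benchmark; numbers vary by machine and input sizes):

import timeit

# Rebuild both versions so we can compare them side by side.
slow_expr = vnp.sum([w1 * vsd.cdist(x1[..., [0, 1, 2]], x2[..., [0, 1, 2]]),
                     w2 * vsd.cdist(x1[..., [0, 3, 4]], x2[..., [0, 3, 4]])],
                    axis=0)
fast_expr = vp.vectorize(slow_expr, example_inputs)

for name, e in [("original", slow_expr), ("vectorized", fast_expr)]:
    t = timeit.timeit(lambda: vp.eval(e, example_inputs), number=1000)
    print(f"{name}: {t:.3f}s for 1000 evaluations")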

3. Evaluate the Vexpr, as you would if you were training w1 and w2

inputs = dict(
    x1=np.random.randn(12, 5),
    x2=np.random.randn(4, 5),
    w1=np.array(0.6),
    w2=np.array(0.4),
)
print(vp.eval(expr, inputs))
# [[1.55860886 1.81932763 1.36601246 2.74558064]
#  [1.07449014 2.41388948 2.05383731 3.47491204]
#  [3.44607574 4.11058513 1.73149737 3.99700678]
#  [1.42342409 1.89316449 2.36516876 2.61242728]
#  [2.10589466 2.16815159 1.05028078 3.2819643 ]
#  [2.6376981  1.86969234 4.09429083 3.39908103]
#  [2.46510162 2.13610497 2.91302844 3.65995608]
#  [1.65351302 1.66339115 2.56035358 1.93349338]
#  [1.15303396 2.07962417 2.23623819 2.63961701]
#  [2.90055677 1.57172764 3.10181813 2.25698896]
#  [1.83600204 2.63654294 1.22630251 3.47381211]
#  [2.61149285 2.77062418 0.78998639 3.10032325]]
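
Note that the vectorized expression is not specialized to the shapes it was traced with: example_inputs had 10 rows, while here x1 has 12 rows and x2 has 4. A quick shape check:

# One weighted distance per (x1 row, x2 row) pair.
print(vp.eval(expr, inputs).shape)  # (12, 4)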

4. Use partial evaluation to precompute intermediate state, as you would before inference

parameters = dict(w1=0.6, w2=0.4)
expr = vp.partial_eval(expr, parameters)
print(expr)
# numpy.sum(
#   operator.mul(
#     array([[[0.6]],
#            [[0.4]]]),
#     custom.scipy.cdist_multi(
#       operator.getitem(
#         symbol('x1'),
#         (Ellipsis, array([0, 1, 2, 0, 3, 4])),
#       ),
#       operator.getitem(
#         symbol('x2'),
#         (Ellipsis, array([0, 1, 2, 0, 3, 4])),
#       ),
#       lengths=array([3, 3])
#     ),
#   )
#   axis=0
# )
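
After partial evaluation the expression contains only the x1 and x2 symbols, with the weights baked in as a constant array, so inference needs just the data inputs:

print(vp.eval(expr, dict(x1=np.random.randn(12, 5),
                         x2=np.random.randn(4, 5))))
# A (12, 4) matrix of weighted distances, as in step 3.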