Support Blockwise in JAX backend
ricardoV94 committed Nov 9, 2023
1 parent 9a0e937 commit 2af4e05
Showing 3 changed files with 71 additions and 0 deletions.
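For context, a minimal usage sketch of what this commit enables (assuming a PyTensor build that includes it): a graph containing a Blockwise op can now be compiled with the JAX backend.

import numpy as np

import pytensor
from pytensor.tensor import tensor
from pytensor.tensor.math import matmul

# matmul on batched inputs produces a Blockwise(Dot) node,
# which the JAX backend can now compile.
a = tensor("a", shape=(2, 3, 5))
b = tensor("b", shape=(2, 5, 3))
out = matmul(a, b)

fn = pytensor.function([a, b], out, mode="JAX")
print(fn(np.ones((2, 3, 5)), np.ones((2, 5, 3))).shape)  # -> (2, 3, 3)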
1 change: 1 addition & 0 deletions pytensor/link/jax/dispatch/__init__.py
@@ -13,5 +13,6 @@
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.sparse
import pytensor.link.jax.dispatch.blockwise

# isort: on
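The one-line change above exists purely for its side effect: importing pytensor.link.jax.dispatch.blockwise executes the @jax_funcify.register(Blockwise) decorator in the new module, adding a Blockwise case to the jax_funcify dispatcher. A toy sketch of that registration pattern, with hypothetical names:

from functools import singledispatch


@singledispatch
def funcify(op, **kwargs):
    raise NotImplementedError(f"No JAX conversion for {op}")


class ToyOp:
    pass  # stands in for an Op such as Blockwise


# Runs at import time, which is why the module only needs to be imported.
@funcify.register(ToyOp)
def funcify_ToyOp(op, **kwargs):
    return lambda *inputs: inputs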
29 changes: 29 additions & 0 deletions pytensor/link/jax/dispatch/blockwise.py
@@ -0,0 +1,29 @@
import jax.numpy as jnp

from pytensor.graph import FunctionGraph
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blockwise import Blockwise


@jax_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
    signature = op.signature
    # Build a JAX function for the core (unbatched) operation.
    core_node = op._create_dummy_core_node(node.inputs)
    core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
    tuple_core_fn = jax_funcify(core_fgraph)

    if len(node.outputs) == 1:

        def core_fn(*inputs):
            return tuple_core_fn(*inputs)[0]

    else:
        core_fn = tuple_core_fn

    # Vectorize the core function over batch dimensions per the gufunc signature.
    vect_fn = jnp.vectorize(core_fn, signature=signature)

    def blockwise_fn(*inputs):
        op._check_runtime_broadcast(node, inputs)
        return vect_fn(*inputs)

    return blockwise_fn
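The heavy lifting above is done by jnp.vectorize, which maps a function written for core (unbatched) operands over any leading batch dimensions according to a NumPy gufunc signature. A standalone sketch of that mechanism, independent of PyTensor:

import jax.numpy as jnp


def core_dot(x, y):
    # Written for single matrices of shape (i, j) and (j, k).
    return x @ y


batched_dot = jnp.vectorize(core_dot, signature="(i,j),(j,k)->(i,k)")

x = jnp.ones((4, 2, 3))  # a batch of four (2, 3) matrices
y = jnp.ones((4, 3, 5))
print(batched_dot(x, y).shape)  # -> (4, 2, 5); batch dims also broadcast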
41 changes: 41 additions & 0 deletions tests/link/jax/test_blockwise.py
@@ -0,0 +1,41 @@
import jax
import numpy as np
import pytest
from jax import make_jaxpr
from tests.tensor.test_blockwise import check_blockwise_runtime_broadcasting

from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot, matmul
from tests.link.jax.test_basic import compare_jax_and_py


def test_runtime_broadcasting():
    check_blockwise_runtime_broadcasting("JAX")


# Blockwise equivalent of matmul, but with a deliberately non-standard signature
odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")


@pytest.mark.parametrize("matmul_op", (matmul, odd_matmul))
def test_matmul(matmul_op):
    rng = np.random.default_rng(14)
    a = tensor("a", shape=(2, 3, 5))
    b = tensor("b", shape=(2, 5, 3))
    test_values = [
        rng.normal(size=inp.type.shape).astype(config.floatX) for inp in (a, b)
    ]

    out = matmul_op(a, b)
    assert isinstance(out.owner.op, Blockwise)
    fg = FunctionGraph([a, b], [out])
    fn, _ = compare_jax_and_py(fg, test_values)

    # Check that the compiled function adds nothing beyond a plain jnp.matmul:
    # after renaming, the two jaxprs must be string-identical.
    jaxpr = str(make_jaxpr(fn.vm.jit_fn)(*test_values))
    jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul")
    expected_jaxpr = str(make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values))
    assert jaxpr == expected_jaxpr
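The jaxpr comparison at the end asserts that the program traced from the compiled PyTensor function is string-identical to that of a plain jax.numpy.matmul, i.e. the Blockwise wrapper introduces no extra primitives. A small illustration of the underlying tool:

import jax
import jax.numpy as jnp

# make_jaxpr traces a function into JAX's intermediate representation;
# its string form can be compared for structural equality.
print(jax.make_jaxpr(jnp.matmul)(jnp.ones((2, 3, 5)), jnp.ones((2, 5, 3))))
# Output along these lines:
# { lambda ; a:f32[2,3,5] b:f32[2,5,3]. let
#     c:f32[2,3,3] = dot_general[...] a b
#   in (c,) }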
