Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Blockwise in JAX backend #487

Merged
merged 2 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/link/jax/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.sparse
import pytensor.link.jax.dispatch.blockwise

# isort: on
29 changes: 29 additions & 0 deletions pytensor/link/jax/dispatch/blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import jax.numpy as jnp

from pytensor.graph import FunctionGraph
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blockwise import Blockwise


@jax_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
signature = op.signature
core_node = op._create_dummy_core_node(node.inputs)
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
tuple_core_fn = jax_funcify(core_fgraph)

if len(node.outputs) == 1:

def core_fn(*inputs):
return tuple_core_fn(*inputs)[0]

else:
core_fn = tuple_core_fn

vect_fn = jnp.vectorize(core_fn, signature=signature)

def blockwise_fn(*inputs):
op._check_runtime_broadcast(node, inputs)
return vect_fn(*inputs)

return blockwise_fn
18 changes: 18 additions & 0 deletions pytensor/tensor/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,30 @@ def core_func(*inner_inputs):
self._gufunc = np.vectorize(core_func, signature=self.signature)
return self._gufunc

def _check_runtime_broadcast(self, node, inputs):
batch_ndim = self._batch_ndim_from_outputs(node.outputs)

for dims_and_bcast in zip(
*[
zip(input.shape[:batch_ndim], sinput.type.broadcastable[:batch_ndim])
for input, sinput in zip(inputs, node.inputs)
]
):
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
raise ValueError(
"Runtime broadcasting not allowed. "
"At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)

def perform(self, node, inputs, output_storage):
gufunc = self._gufunc

if gufunc is None:
gufunc = self._create_gufunc(node)

self._check_runtime_broadcast(node, inputs)

res = gufunc(*inputs)
if not isinstance(res, tuple):
res = (res,)
Expand Down
42 changes: 42 additions & 0 deletions tests/link/jax/test_blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
import pytest

from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot, matmul
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_blockwise import check_blockwise_runtime_broadcasting


jax = pytest.importorskip("jax")


def test_runtime_broadcasting():
check_blockwise_runtime_broadcasting("JAX")


# Equivalent blockwise to matmul but with dumb signature
odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")


@pytest.mark.parametrize("matmul_op", (matmul, odd_matmul))
def test_matmul(matmul_op):
rng = np.random.default_rng(14)
a = tensor("a", shape=(2, 3, 5))
b = tensor("b", shape=(2, 5, 3))
test_values = [
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (a, b)
]

out = matmul_op(a, b)
assert isinstance(out.owner.op, Blockwise)
fg = FunctionGraph([a, b], [out])
fn, _ = compare_jax_and_py(fg, test_values)

# Check we are not adding any unnecessary stuff
jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))
jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul")
expected_jaxpr = str(jax.make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values))
assert jaxpr == expected_jaxpr
52 changes: 51 additions & 1 deletion tests/tensor/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

import pytensor
from pytensor import config
from pytensor import config, function
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
Expand Down Expand Up @@ -38,6 +38,56 @@ def test_vectorize_blockwise():
assert new_vect_node.inputs[0] is tns4


def check_blockwise_runtime_broadcasting(mode):
a = tensor("a", shape=(None, 3, 5))
b = tensor("b", shape=(None, 5, 3))

out = a @ b
fn = function([a, b], out, mode=mode)
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)

for valid_test_values in [
(
np.ones((2, 3, 5)).astype(config.floatX),
np.ones((2, 5, 3)).astype(config.floatX),
),
(
np.ones((1, 3, 5)).astype(config.floatX),
np.ones((1, 5, 3)).astype(config.floatX),
),
]:
batch_dim = valid_test_values[0].shape[0]
np.testing.assert_allclose(
fn(*valid_test_values), np.full((batch_dim, 3, 3), 5.0)
)

for invalid_test_values in [
(
np.ones((1, 3, 5)).astype(config.floatX),
np.ones((2, 5, 3)).astype(config.floatX),
),
(
np.ones((2, 3, 5)).astype(config.floatX),
np.ones((1, 5, 3)).astype(config.floatX),
),
]:
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
fn(*invalid_test_values)

invalid_test_values = (
np.ones((2, 3, 5)).astype(config.floatX),
np.ones((3, 5, 3)).astype(config.floatX),
)
# Error message is backend specific
with pytest.raises(ValueError):
fn(*invalid_test_values)


@pytest.mark.parametrize("mode", ("FAST_COMPILE", "FAST_RUN"))
def test_runtime_broadcast(mode):
check_blockwise_runtime_broadcasting(mode)


class TestOp(Op):
def make_node(self, *inputs):
return Apply(self, inputs, [i.type() for i in inputs])
Expand Down
Loading