diff --git a/pytensor/link/jax/dispatch/__init__.py b/pytensor/link/jax/dispatch/__init__.py
index 22cdb25821..0a12442a97 100644
--- a/pytensor/link/jax/dispatch/__init__.py
+++ b/pytensor/link/jax/dispatch/__init__.py
@@ -13,5 +13,6 @@
 import pytensor.link.jax.dispatch.elemwise
 import pytensor.link.jax.dispatch.scan
 import pytensor.link.jax.dispatch.sparse
+import pytensor.link.jax.dispatch.blockwise
 
 # isort: on
diff --git a/pytensor/link/jax/dispatch/blockwise.py b/pytensor/link/jax/dispatch/blockwise.py
new file mode 100644
index 0000000000..5e691c141b
--- /dev/null
+++ b/pytensor/link/jax/dispatch/blockwise.py
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+
+from pytensor.graph import FunctionGraph
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.blockwise import Blockwise
+
+
+@jax_funcify.register(Blockwise)
+def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
+    signature = op.signature
+    core_node = op._create_dummy_core_node(node.inputs)
+    core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
+    tuple_core_fn = jax_funcify(core_fgraph)
+
+    if len(node.outputs) == 1:
+
+        def core_fn(*inputs):
+            return tuple_core_fn(*inputs)[0]
+
+    else:
+        core_fn = tuple_core_fn
+
+    vect_fn = jnp.vectorize(core_fn, signature=signature)
+
+    def blockwise_fn(*inputs):
+        op._check_runtime_broadcast(node, inputs)
+        return vect_fn(*inputs)
+
+    return blockwise_fn
diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index 59dcff5200..320cce170b 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -355,12 +355,30 @@ def core_func(*inner_inputs):
         self._gufunc = np.vectorize(core_func, signature=self.signature)
         return self._gufunc
 
+    def _check_runtime_broadcast(self, node, inputs):
+        batch_ndim = self._batch_ndim_from_outputs(node.outputs)
+
+        for dims_and_bcast in zip(
+            *[
+                zip(input.shape[:batch_ndim], sinput.type.broadcastable[:batch_ndim])
+                for input, sinput in zip(inputs, node.inputs)
+            ]
+        ):
+            if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
+                raise ValueError(
+                    "Runtime broadcasting not allowed. "
+                    "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
+                    "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+                )
+
     def perform(self, node, inputs, output_storage):
         gufunc = self._gufunc
 
         if gufunc is None:
             gufunc = self._create_gufunc(node)
 
+        self._check_runtime_broadcast(node, inputs)
+
         res = gufunc(*inputs)
         if not isinstance(res, tuple):
             res = (res,)
diff --git a/tests/link/jax/test_blockwise.py b/tests/link/jax/test_blockwise.py
new file mode 100644
index 0000000000..64569b0274
--- /dev/null
+++ b/tests/link/jax/test_blockwise.py
@@ -0,0 +1,42 @@
+import numpy as np
+import pytest
+
+from pytensor import config
+from pytensor.graph import FunctionGraph
+from pytensor.tensor import tensor
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.math import Dot, matmul
+from tests.link.jax.test_basic import compare_jax_and_py
+from tests.tensor.test_blockwise import check_blockwise_runtime_broadcasting
+
+
+jax = pytest.importorskip("jax")
+
+
+def test_runtime_broadcasting():
+    check_blockwise_runtime_broadcasting("JAX")
+
+
+# Equivalent blockwise to matmul but with dumb signature
+odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")
+
+
+@pytest.mark.parametrize("matmul_op", (matmul, odd_matmul))
+def test_matmul(matmul_op):
+    rng = np.random.default_rng(14)
+    a = tensor("a", shape=(2, 3, 5))
+    b = tensor("b", shape=(2, 5, 3))
+    test_values = [
+        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (a, b)
+    ]
+
+    out = matmul_op(a, b)
+    assert isinstance(out.owner.op, Blockwise)
+    fg = FunctionGraph([a, b], [out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+
+    # Check we are not adding any unnecessary stuff
+    jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))
+    jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul")
+    expected_jaxpr = str(jax.make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values))
+    assert jaxpr == expected_jaxpr
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index 11855c4048..1a36c57f45 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -5,7 +5,7 @@
 import pytest
 
 import pytensor
-from pytensor import config
+from pytensor import config, function
 from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
 from pytensor.graph.replace import vectorize_node
@@ -38,6 +38,56 @@ def test_vectorize_blockwise():
     assert new_vect_node.inputs[0] is tns4
 
 
+def check_blockwise_runtime_broadcasting(mode):
+    a = tensor("a", shape=(None, 3, 5))
+    b = tensor("b", shape=(None, 5, 3))
+
+    out = a @ b
+    fn = function([a, b], out, mode=mode)
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
+
+    for valid_test_values in [
+        (
+            np.ones((2, 3, 5)).astype(config.floatX),
+            np.ones((2, 5, 3)).astype(config.floatX),
+        ),
+        (
+            np.ones((1, 3, 5)).astype(config.floatX),
+            np.ones((1, 5, 3)).astype(config.floatX),
+        ),
+    ]:
+        batch_dim = valid_test_values[0].shape[0]
+        np.testing.assert_allclose(
+            fn(*valid_test_values), np.full((batch_dim, 3, 3), 5.0)
+        )
+
+    for invalid_test_values in [
+        (
+            np.ones((1, 3, 5)).astype(config.floatX),
+            np.ones((2, 5, 3)).astype(config.floatX),
+        ),
+        (
+            np.ones((2, 3, 5)).astype(config.floatX),
+            np.ones((1, 5, 3)).astype(config.floatX),
+        ),
+    ]:
+        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
+            fn(*invalid_test_values)
+
+    invalid_test_values = (
+        np.ones((2, 3, 5)).astype(config.floatX),
+        np.ones((3, 5, 3)).astype(config.floatX),
+    )
+    # Error message is backend specific
+    with pytest.raises(ValueError):
+        fn(*invalid_test_values)
+
+
+@pytest.mark.parametrize("mode", ("FAST_COMPILE", "FAST_RUN"))
+def test_runtime_broadcast(mode):
+    check_blockwise_runtime_broadcasting(mode)
+
+
 class TestOp(Op):
     def make_node(self, *inputs):
         return Apply(self, inputs, [i.type() for i in inputs])