From 9a0e937faac1e3459b1ecca03155b53f0652d238 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 9 Nov 2023 13:05:25 +0100
Subject: [PATCH] Check for runtime broadcasting in Blockwise Ops

---
Reviewer notes (text between the `---` separator and the diff is not part of
the commit message):
 * Adds `_check_runtime_broadcast`, which `perform` now calls right before the
   gufunc is evaluated. For each batch dimension it raises ValueError when one
   input has length 1 without being marked broadcastable in its static type
   while another input has a length other than 1.
 * Adds `test_runtime_broadcast`, parametrized over the FAST_COMPILE and
   FAST_RUN modes; mismatched batch lengths that involve no length-1 input are
   only expected to raise some ValueError (message is backend specific).

 pytensor/tensor/blockwise.py   | 18 ++++++++++++
 tests/tensor/test_blockwise.py | 52 +++++++++++++++++++++++++++++++++-
 2 files changed, 69 insertions(+), 1 deletion(-)

diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index 59dcff5200..320cce170b 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -355,12 +355,30 @@ def core_func(*inner_inputs):
         self._gufunc = np.vectorize(core_func, signature=self.signature)
         return self._gufunc

+    def _check_runtime_broadcast(self, node, inputs):
+        batch_ndim = self._batch_ndim_from_outputs(node.outputs)
+
+        for dims_and_bcast in zip(
+            *[
+                zip(input.shape[:batch_ndim], sinput.type.broadcastable[:batch_ndim])
+                for input, sinput in zip(inputs, node.inputs)
+            ]
+        ):
+            if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
+                raise ValueError(
+                    "Runtime broadcasting not allowed. "
+                    "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
+                    "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+                )
+
     def perform(self, node, inputs, output_storage):
         gufunc = self._gufunc

         if gufunc is None:
             gufunc = self._create_gufunc(node)

+        self._check_runtime_broadcast(node, inputs)
+
         res = gufunc(*inputs)
         if not isinstance(res, tuple):
             res = (res,)
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index 11855c4048..1a36c57f45 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -5,7 +5,7 @@
 import pytest

 import pytensor
-from pytensor import config
+from pytensor import config, function
 from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
 from pytensor.graph.replace import vectorize_node
@@ -38,6 +38,56 @@ def test_vectorize_blockwise():
     assert new_vect_node.inputs[0] is tns4


+def check_blockwise_runtime_broadcasting(mode):
+    a = tensor("a", shape=(None, 3, 5))
+    b = tensor("b", shape=(None, 5, 3))
+
+    out = a @ b
+    fn = function([a, b], out, mode=mode)
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
+
+    for valid_test_values in [
+        (
+            np.ones((2, 3, 5)).astype(config.floatX),
+            np.ones((2, 5, 3)).astype(config.floatX),
+        ),
+        (
+            np.ones((1, 3, 5)).astype(config.floatX),
+            np.ones((1, 5, 3)).astype(config.floatX),
+        ),
+    ]:
+        batch_dim = valid_test_values[0].shape[0]
+        np.testing.assert_allclose(
+            fn(*valid_test_values), np.full((batch_dim, 3, 3), 5.0)
+        )
+
+    for invalid_test_values in [
+        (
+            np.ones((1, 3, 5)).astype(config.floatX),
+            np.ones((2, 5, 3)).astype(config.floatX),
+        ),
+        (
+            np.ones((2, 3, 5)).astype(config.floatX),
+            np.ones((1, 5, 3)).astype(config.floatX),
+        ),
+    ]:
+        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
+            fn(*invalid_test_values)
+
+    invalid_test_values = (
+        np.ones((2, 3, 5)).astype(config.floatX),
+        np.ones((3, 5, 3)).astype(config.floatX),
+    )
+    # Error message is backend specific
+    with pytest.raises(ValueError):
+        fn(*invalid_test_values)
+
+
+@pytest.mark.parametrize("mode", ("FAST_COMPILE", "FAST_RUN"))
+def test_runtime_broadcast(mode):
+    check_blockwise_runtime_broadcasting(mode)
+
+
 class TestOp(Op):
     def make_node(self, *inputs):
         return Apply(self, inputs, [i.type() for i in inputs])