Skip to content

Commit

Permalink
Check for runtime broadcasting in Blockwise Ops
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 9, 2023
1 parent 893dc18 commit 9a0e937
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
18 changes: 18 additions & 0 deletions pytensor/tensor/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,30 @@ def core_func(*inner_inputs):
self._gufunc = np.vectorize(core_func, signature=self.signature)
return self._gufunc

def _check_runtime_broadcast(self, node, inputs):
    """Reject runtime broadcasting along batch dimensions.

    Raises
    ------
    ValueError
        If, for some batch dimension, one input has runtime length 1 while
        another has length != 1, and the length-1 input was not statically
        marked as broadcastable on that dimension.
    """
    batch_ndim = self._batch_ndim_from_outputs(node.outputs)

    # For each input, pair the runtime length of every batch dimension with
    # the corresponding static broadcastable flag of its symbolic type.
    per_input_dims = [
        zip(value.shape[:batch_ndim], variable.type.broadcastable[:batch_ndim])
        for value, variable in zip(inputs, node.inputs)
    ]
    # Transpose: walk the batch dimensions one at a time, seeing every
    # input's (length, broadcastable) pair for that dimension together.
    for dim_entries in zip(*per_input_dims):
        has_non_unit_length = any(length != 1 for length, _ in dim_entries)
        has_unflagged_unit = (1, False) in dim_entries
        if has_non_unit_length and has_unflagged_unit:
            raise ValueError(
                "Runtime broadcasting not allowed. "
                "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
                "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
            )

def perform(self, node, inputs, output_storage):
gufunc = self._gufunc

if gufunc is None:
gufunc = self._create_gufunc(node)

self._check_runtime_broadcast(node, inputs)

res = gufunc(*inputs)
if not isinstance(res, tuple):
res = (res,)
Expand Down
52 changes: 51 additions & 1 deletion tests/tensor/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

import pytensor
from pytensor import config
from pytensor import config, function
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
Expand Down Expand Up @@ -38,6 +38,56 @@ def test_vectorize_blockwise():
assert new_vect_node.inputs[0] is tns4


def check_blockwise_runtime_broadcasting(mode):
    """Compile a batched matmul and check that runtime broadcasting of the
    batch dimension raises an error instead of silently broadcasting."""
    a = tensor("a", shape=(None, 3, 5))
    b = tensor("b", shape=(None, 5, 3))

    out = a @ b
    fn = function([a, b], out, mode=mode)
    # The matmul must be lowered to a Blockwise for this to exercise the check.
    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)

    # Equal batch dimensions — including both equal to 1 — are valid.
    for batch_dim in (2, 1):
        left = np.ones((batch_dim, 3, 5)).astype(config.floatX)
        right = np.ones((batch_dim, 5, 3)).astype(config.floatX)
        expected = np.full((batch_dim, 3, 3), 5.0)
        np.testing.assert_allclose(fn(left, right), expected)

    # A batch dimension of 1 paired with a larger one would require runtime
    # broadcasting, which is forbidden — in either argument order.
    for left_batch, right_batch in ((1, 2), (2, 1)):
        left = np.ones((left_batch, 3, 5)).astype(config.floatX)
        right = np.ones((right_batch, 5, 3)).astype(config.floatX)
        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
            fn(left, right)

    # Plainly incompatible batch dimensions also fail, but the error message
    # is backend specific, so only the exception type is checked.
    left = np.ones((2, 3, 5)).astype(config.floatX)
    right = np.ones((3, 5, 3)).astype(config.floatX)
    with pytest.raises(ValueError):
        fn(left, right)


@pytest.mark.parametrize("mode", ("FAST_COMPILE", "FAST_RUN"))
def test_runtime_broadcast(mode):
    """Runtime broadcasting must be rejected in both compilation modes."""
    check_blockwise_runtime_broadcasting(mode=mode)


class TestOp(Op):
    """Minimal Op used as a stand-in in the Blockwise tests below."""

    def make_node(self, *inputs):
        # One output per input, each with the same type as its input.
        return Apply(self, inputs, [i.type() for i in inputs])
Expand Down

0 comments on commit 9a0e937

Please sign in to comment.