Add rewrite for matmul when only one of the inputs has batched dimensions #558

Merged
53 changes: 53 additions & 0 deletions pytensor/tensor/rewriting/math.py
@@ -31,11 +31,13 @@
    constant,
    extract_constant,
    get_underlying_scalar_constant_value,
    moveaxis,
    ones_like,
    register_infer_shape,
    switch,
    zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
@@ -217,6 +219,57 @@ def local_lift_transpose_through_dot(fgraph, node):
return ret


@register_stabilize
@register_specialize
@node_rewriter(tracks=[Blockwise])
def local_batched_matmul_to_core_matmul(fgraph, node):
    """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.

    For example, if x has batch dimensions but y does not:
        x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])

    It also works when y has batch dimensions but x does not.
    """

    # Check whether we have a matmul operation in this node
    if not (
        isinstance(node.op.core_op, Dot)
        and len(node.op.inputs_sig[0]) == 2
        and len(node.op.inputs_sig[1]) == 2
    ):
        return None

    x, y = node.inputs
    batch_ndim = node.op.batch_ndim(node)

    # Check if x has batch dimensions but y does not (or only broadcastable dimensions)
    if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all(
        y.type.broadcastable[:-2]
    ):
        # x.shape == (*b, m, k): stack the batch and row dimensions, do a core
        # matmul, then restore the batch shape
        x_stacked = x.reshape((-1, x.shape[-1]))  # (*b * m, k)
        out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim)))  # (*b * m, n)
        out = out_stacked.reshape((*x.shape[:-1], y.shape[-1]))  # (*b, m, n)
        return [out]

    # Otherwise, check if y has batch dimensions but x does not
    elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all(
        x.type.broadcastable[:-2]
    ):
        # For the y batch case we first need to move the batch axes and then reshape
        # y.shape == (*b, k, n)
        y_tr = moveaxis(y, -2, 0)  # (k, *b, n)
        y_stacked = y_tr.reshape((y.shape[-2], -1))  # (k, *b * n)
        out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked  # (m, *b * n)
        out_stacked_tr = out_stacked.reshape(
            (x.shape[-2], *y.shape[:-2], y.shape[-1])
        )  # (m, *b, n)
        out = moveaxis(out_stacked_tr, 0, -2)  # (*b, m, n)
        return [out]

    # Both x and y have batch dimensions; nothing to do here
    return None


def is_inverse_pair(node_op, prev_op, inv_pair):
    """
    Given two consecutive operations, check if they are the
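The reshape trick that local_batched_matmul_to_core_matmul applies can be checked directly in NumPy. Below is a minimal sketch of the equivalence for both branches of the rewrite; the concrete shapes (5, 3, 2), (2, 4), etc. are illustrative and not taken from the PR.

import numpy as np

rng = np.random.default_rng(0)

# Case 1: x is batched, y is not -> stack x's batch and row dimensions,
# do a core matmul, then restore the batch shape.
x = rng.normal(size=(5, 3, 2))  # (*b, m, k)
y = rng.normal(size=(2, 4))  # (k, n)
out = (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
np.testing.assert_allclose(out, x @ y)

# Case 2: y is batched, x is not -> move the contracted axis of y to the front,
# stack the batch and column dimensions, and undo the move on the result.
x = rng.normal(size=(3, 2))  # (m, k)
y = rng.normal(size=(5, 2, 4))  # (*b, k, n)
y_stacked = np.moveaxis(y, -2, 0).reshape(y.shape[-2], -1)  # (k, *b * n)
out_stacked = (x @ y_stacked).reshape(x.shape[-2], *y.shape[:-2], y.shape[-1])  # (m, *b, n)
out = np.moveaxis(out_stacked, 0, -2)  # (*b, m, n)
np.testing.assert_allclose(out, x @ y)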
49 changes: 49 additions & 0 deletions tests/tensor/rewriting/test_math.py
@@ -34,6 +34,7 @@
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
from pytensor.tensor.math import abs as pt_abs
@@ -4427,3 +4428,51 @@ def test_polygamma_specialization():
    assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
    assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
    assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)


@pytest.mark.skipif(
    config.mode == "FAST_COMPILE",
    reason="Rewrite is only relevant in FAST_RUN",
)
def test_local_batched_matmul_to_core_matmul():
    rng = np.random.default_rng(seed=4433)

    # x is batched but not y
    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(2, 2), dtype="float64")
    out = x @ y
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out)
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    x_test = rng.normal(size=(5, 3, 2))
    y_test = rng.normal(size=(2, 2))
    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)

    # y is batched but not x
    x = pt.tensor("x", shape=(1, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
    out = x @ y
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out)
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    x_test = rng.normal(size=(1, 3, 2))
    y_test = rng.normal(size=(5, 2, 2))
    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)

    # Both x and y are batched, rewrite does not apply
    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
    out = x @ y

    fn = pytensor.function([x, y], out)
    x_test = rng.normal(size=(5, 3, 2))
    y_test = rng.normal(size=(5, 2, 2))
    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
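
A quick way to see the rewrite in action outside the test suite is to compile a function and inspect the optimized graph; in the default FAST_RUN mode the Blockwise matmul should be replaced by a core dot plus reshapes. This is only a usage sketch mirroring the assertions above, not part of the PR.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")  # batched
y = pt.tensor("y", shape=(2, 2), dtype="float64")  # not batched

fn = pytensor.function([x, y], x @ y)
pytensor.dprint(fn)  # no Blockwise node should appear in the printed graph
fn(np.ones((5, 3, 2)), np.ones((2, 2)))  # returns an array of shape (5, 3, 2)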