Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rewrite to merge multiple SVD Ops with different settings #769

Merged
merged 18 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ pytensor-venv/
.vscode/
testing-report.html
coverage.xml
.coverage.*
.coverage.*
63 changes: 62 additions & 1 deletion pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,25 @@

from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
node_rewriter,
)
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
det,
inv,
kron,
pinv,
svd,
)
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
Expand Down Expand Up @@ -377,3 +382,59 @@ def local_lift_through_linalg(
return [block_diag(*inner_matrices)]
else:
raise NotImplementedError # pragma: no cover


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def svd_uv_merge(fgraph, node):
    """Merge multiple `SVD` `Op`s acting on the same input with different settings.

    If the graph contains more than one `SVD` of the same matrix and at least one
    of them has ``compute_uv=True`` whose `u`/`v` outputs are actually used, the
    singular values `s` can be taken from that decomposition instead of running a
    second one. Conversely, an SVD whose `u`/`v` outputs are unused can be
    downgraded to ``compute_uv=False`` (or its `s` taken from an existing
    ``compute_uv=False`` node), letting pytensor avoid recomputing the
    decomposition.

    Returns a replacement dict/list for the rewriter, or ``None`` when no
    rewrite applies.
    """
    if not isinstance(node.op.core_op, SVD):
        return

    (x,) = node.inputs

    if node.op.core_op.compute_uv:
        # compute_uv=True returns [u, s, v].
        # If at least u or v is used, no need to rewrite this node.
        if (
            len(fgraph.clients[node.outputs[0]]) > 0
            or len(fgraph.clients[node.outputs[2]]) > 0
        ):
            return

        # Else, replace the s of this node with s of an SVD Op with compute_uv=False.
        # First, iterate over the clients of x to see if there is such an SVD Op
        # that can be reused.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if not cl.op.core_op.compute_uv:
                    return {
                        node.outputs[1]: cl.outputs[0],
                    }

        # If no SVD is reusable, return a new one.
        # NOTE(review): full_matrices is carried over for faithfulness, although
        # it should play no role when compute_uv=False.
        return {
            node.outputs[1]: svd(
                x, full_matrices=node.op.core_op.full_matrices, compute_uv=False
            ),
        }

    else:
        # compute_uv=False returns [s].
        # We want to rewrite if there is another SVD of x with compute_uv=True
        # whose u or v outputs are used; in that case just reuse its `s`.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if cl.op.core_op.compute_uv and (
                    len(fgraph.clients[cl.outputs[0]]) > 0
                    or len(fgraph.clients[cl.outputs[2]]) > 0
                ):
                    return [cl.outputs[1]]
66 changes: 66 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import (
SVD,
Det,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
matrix_inverse,
svd,
)
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import (
Expand Down Expand Up @@ -390,3 +392,67 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]

np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)


def test_svd_uv_merge():
    """Exercise the `svd_uv_merge` rewrite across its branches.

    Builds several SVDs of the same matrix with different
    ``compute_uv``/``full_matrices`` settings and checks that the compiled
    graphs contain exactly one SVD node with the expected settings.
    """
    a = matrix("a")
    s_1 = svd(a, full_matrices=False, compute_uv=False)
    _, s_2, _ = svd(a, full_matrices=False, compute_uv=True)
    _, s_3, _ = svd(a, full_matrices=True, compute_uv=True)
    u_4, s_4, v_4 = svd(a, full_matrices=True, compute_uv=True)
    # `grad` will introduce an SVD Op with compute_uv=True.
    # full_matrices=True is not supported for grad of svd.
    gs = pt.grad(pt.sum(s_1), a)

    # 1. compute_uv=False needs rewriting with compute_uv=True
    f_1 = pytensor.function([a], gs)
    nodes = f_1.maker.fgraph.apply_nodes
    svd_counter = 0
    for node in nodes:
        if isinstance(node.op, SVD):
            assert node.op.compute_uv
            svd_counter += 1
    assert svd_counter == 1

    # 2. compute_uv=True needs rewriting with compute_uv=False, reuse node
    f_2 = pytensor.function([a], [s_1, s_2])
    nodes = f_2.maker.fgraph.apply_nodes
    svd_counter = 0
    for node in nodes:
        if isinstance(node.op, SVD):
            assert not node.op.compute_uv
            svd_counter += 1
    assert svd_counter == 1

    # 3. compute_uv=True needs rewriting with compute_uv=False, create new node
    # full_matrices needs to retain the value
    f_3 = pytensor.function([a], [s_2])
    nodes = f_3.maker.fgraph.apply_nodes
    svd_counter = 0
    for node in nodes:
        if isinstance(node.op, SVD):
            assert not node.op.compute_uv
            svd_counter += 1
    assert svd_counter == 1

    # Case 2 of 3. for a different full_matrices
    f_4 = pytensor.function([a], [s_3])
    nodes = f_4.maker.fgraph.apply_nodes
    svd_counter = 0
    for node in nodes:
        if isinstance(node.op, SVD):
            assert not node.op.compute_uv
            # full_matrices is checked to confirm the rewrite reused the Op.
            assert node.op.full_matrices
            svd_counter += 1
    assert svd_counter == 1

    # 4. No rewrite should happen
    f_5 = pytensor.function([a], [u_4])
    nodes = f_5.maker.fgraph.apply_nodes
    svd_counter = 0
    for node in nodes:
        if isinstance(node.op, SVD):
            assert node.op.full_matrices
            assert node.op.compute_uv
            svd_counter += 1
    assert svd_counter == 1
Loading