Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rewrite to merge multiple SVD Ops with different settings #769

Merged
merged 18 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ pytensor-venv/
.vscode/
testing-report.html
coverage.xml
.coverage.*
.coverage.*
54 changes: 53 additions & 1 deletion pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,25 @@

from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
node_rewriter,
)
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
det,
inv,
kron,
pinv,
svd,
)
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
Expand Down Expand Up @@ -377,3 +382,50 @@
return [block_diag(*inner_matrices)]
else:
raise NotImplementedError # pragma: no cover


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([SVD])
def local_svd_uv_simplify(fgraph, node):
    """Merge multiple ``SVD`` ``Op``s applied to the same input.

    If the graph decomposes the same matrix more than once with different
    ``compute_uv`` settings, keep only the cheapest decomposition that still
    provides every output actually used:

    * ``compute_uv=True`` whose ``u``/``v`` outputs are unused is replaced by
      (or merged into) a ``compute_uv=False`` decomposition of the same input.
    * ``compute_uv=False`` reuses the ``s`` output of an existing
      ``compute_uv=True`` decomposition of the same input, letting pytensor
      share the result instead of recomputing it.

    Returns ``None`` when no rewrite applies.
    """
    (x,) = node.inputs

    # The SVD op may be wrapped in Blockwise for batched inputs; unwrap it so
    # we can read its flags uniformly.
    core_op = node.op.core_op if isinstance(node.op, Blockwise) else node.op

    if core_op.compute_uv:
        # compute_uv=True returns [u, s, v].
        # If u or v is actually consumed downstream there is nothing to do.
        # NOTE: fgraph.clients[...] is a (possibly empty) list — test
        # truthiness, not `is not None`.
        if fgraph.clients[node.outputs[0]] or fgraph.clients[node.outputs[2]]:
            return None

        # Only s is used. First see whether a compute_uv=False SVD of the
        # same input already exists that we can reuse.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if not cl.op.core_op.compute_uv:
                    return {node.outputs[1]: cl.outputs[0]}

        # No reusable SVD found: introduce a fresh compute_uv=False one for s
        # and mark the now-dead u/v outputs for removal.
        return {
            "remove": [node.outputs[0], node.outputs[2]],
            node.outputs[1]: svd(
                x, full_matrices=core_op.full_matrices, compute_uv=False
            ),
        }
    else:
        # compute_uv=False returns [s] only.
        # If another SVD of the same input already computes (u, s, v), reuse
        # its s output instead of decomposing x twice.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if cl.op.core_op.compute_uv:
                    return [cl.outputs[1]]
31 changes: 31 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import (
SVD,
Det,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
matrix_inverse,
svd,
)
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import (
Expand Down Expand Up @@ -390,3 +392,32 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]

np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)


def test_local_svd_uv_simplify():
    """Check that redundant SVD decompositions of the same input are merged.

    The rewrite should (1) upgrade a lone ``compute_uv=False`` SVD when its
    gradient needs u/v, and (2)/(3) downgrade ``compute_uv=True`` SVDs whose
    u/v outputs are unused, reusing or creating a ``compute_uv=False`` node.
    """

    def _svd_core_ops(fn):
        # SVD ops appear wrapped in Blockwise; unwrap via core_op when
        # present.  Checking `isinstance(node, SVD)` on the Apply node itself
        # would never match and make the assertions vacuous.
        return [
            getattr(node.op, "core_op", node.op)
            for node in fn.maker.fgraph.toposort()
            if isinstance(getattr(node.op, "core_op", node.op), SVD)
        ]

    a = matrix("a")
    s_1 = svd(a, full_matrices=False, compute_uv=False)
    _, s_2, _ = svd(a, full_matrices=False, compute_uv=True)
    # full_matrices=True is not supported for the gradient of svd.
    gs = pt.grad(pt.sum(s_1), a)

    # 1. compute_uv=False needs rewriting with compute_uv=True.
    f_1 = pytensor.function([a], gs)
    assert all(op.compute_uv for op in _svd_core_ops(f_1))

    # 2. compute_uv=True needs rewriting with compute_uv=False, reusing the
    # existing compute_uv=False node.
    f_2 = pytensor.function([a], [s_1, s_2])
    assert not any(op.compute_uv for op in _svd_core_ops(f_2))

    # 3. compute_uv=True needs rewriting with compute_uv=False, creating a
    # new node.
    f_3 = pytensor.function([a], [s_2])
    assert not any(op.compute_uv for op in _svd_core_ops(f_3))
Loading