From 23ad3c4e8dfef1061ec98e1e77cce26156a51b94 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Sun, 12 May 2024 22:14:57 +0800 Subject: [PATCH 01/16] Started SVD rewrite --- .gitignore | 1 + pytensor/tensor/rewriting/linalg.py | 5 ++- test.ipynb | 53 +++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 test.ipynb diff --git a/.gitignore b/.gitignore index dfe862b868..e7dab96568 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,4 @@ pytensor-venv/ testing-report.html coverage.xml .coverage.* +pics \ No newline at end of file diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index cdb1e59101..ff2887831e 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -4,7 +4,10 @@ from pytensor import Variable from pytensor.graph import Apply, FunctionGraph -from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter +from pytensor.graph.rewriting.basic import ( + copy_stack_trace, + node_rewriter, +) from pytensor.tensor.basic import TensorVariable, diagonal from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 0000000000..73308d020f --- /dev/null +++ b/test.ipynb @@ -0,0 +1,53 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The output file is available at ./pics/symbolic_graph_rewrite.png\n" + ] + } + ], + "source": [ + "import pytensor\n", + "import pytensor.tensor as pt\n", + "import numpy as np\n", + "from pytensor.tensor.type import matrix\n", + "from pytensor.tensor.linalg import svd\n", + "\n", + "a_pt = matrix(\"a\")\n", + "s = svd(a_pt, full_matrices=False, compute_uv=False)\n", + "J, updates = pytensor.scan(lambda i, s, a_pt : pt.grad(s[i], a_pt), sequences=pt.arange(s.shape[0]), non_sequences=[s, a_pt])\n", + "f = pytensor.function([a_pt], J, updates=updates)\n", + "f([[1, 2], [3, 4]])\n", + "pytensor.printing.pydotprint(f, outfile=\"./pics/symbolic_graph_rewrite.png\", var_with_name_simple=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytensor-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From d4a89ad2362d9fc0feb44ac010417a43b155402f Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Tue, 14 May 2024 21:36:46 +0800 Subject: [PATCH 02/16] Added rewrite to fix svd_graph_rewrite --- pytensor/tensor/rewriting/linalg.py | 29 ++++++++++++++++ test.ipynb | 53 ----------------------------- 2 files changed, 29 insertions(+), 53 deletions(-) delete mode 100644 test.ipynb diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index ff2887831e..99ff4ee075 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -14,6 +14,7 @@ from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod from pytensor.tensor.nlinalg import ( + SVD, KroneckerProduct, MatrixInverse, MatrixPinv, @@ -380,3 +381,31 @@ def local_lift_through_linalg( return [block_diag(*inner_matrices)] else: raise NotImplementedError # pragma: no cover + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([SVD]) +def local_svd_uv_simplify(fgraph, node): + """If we have more than one `SVD` `Op`s and at least one has keyword argument + `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere + and allow `pytensor` to re-use the decomposition outputs instead of recomputing. + """ + (x,) = node.inputs + svd_count = 0 + compute_uv = False + not_compute_uv_svd_list = [] + + for cl, _ in fgraph.clients[x]: + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): + svd_count += 1 + if (not compute_uv) and cl.op.core_op.compute_uv: + compute_uv = True + if not cl.op.core_op.compute_uv: + not_compute_uv_svd_list.append(cl) + + if svd_count > 1 and compute_uv: + for cl in not_compute_uv_svd_list: + cl.op.core_op.compute_uv = True + return [cl.outputs[0] for cl in not_compute_uv_svd_list] diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index 73308d020f..0000000000 --- a/test.ipynb +++ /dev/null @@ -1,53 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The output file is available at ./pics/symbolic_graph_rewrite.png\n" - ] - } - ], - "source": [ - "import pytensor\n", - "import pytensor.tensor as pt\n", - "import numpy as np\n", - "from pytensor.tensor.type import matrix\n", - "from pytensor.tensor.linalg import svd\n", - "\n", - "a_pt = matrix(\"a\")\n", - "s = svd(a_pt, full_matrices=False, compute_uv=False)\n", - "J, updates = pytensor.scan(lambda i, s, a_pt : pt.grad(s[i], a_pt), sequences=pt.arange(s.shape[0]), non_sequences=[s, a_pt])\n", - "f = pytensor.function([a_pt], J, updates=updates)\n", - "f([[1, 2], [3, 4]])\n", - "pytensor.printing.pydotprint(f, outfile=\"./pics/symbolic_graph_rewrite.png\", var_with_name_simple=True)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pytensor-dev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From d4a2f2f9a104bc89ba11017043295861a41b5463 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Sun, 12 May 2024 22:14:57 +0800 Subject: [PATCH 03/16] Started SVD rewrite --- .gitignore | 1 + pytensor/tensor/rewriting/linalg.py | 5 ++- test.ipynb | 53 +++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 test.ipynb diff --git a/.gitignore b/.gitignore index dfe862b868..e7dab96568 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,4 @@ pytensor-venv/ testing-report.html coverage.xml .coverage.* +pics \ No newline at end of file diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index cdb1e59101..ff2887831e 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -4,7 +4,10 @@ from pytensor import Variable from pytensor.graph import Apply, FunctionGraph -from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter +from pytensor.graph.rewriting.basic import ( + copy_stack_trace, + node_rewriter, +) from pytensor.tensor.basic import TensorVariable, diagonal from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 0000000000..73308d020f --- /dev/null +++ b/test.ipynb @@ -0,0 +1,53 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The output file is available at ./pics/symbolic_graph_rewrite.png\n" + ] + } + ], + "source": [ + "import pytensor\n", + "import pytensor.tensor as pt\n", + "import numpy as np\n", + "from pytensor.tensor.type import matrix\n", + "from pytensor.tensor.linalg import svd\n", + "\n", + "a_pt = matrix(\"a\")\n", + "s = svd(a_pt, full_matrices=False, compute_uv=False)\n", + "J, updates = pytensor.scan(lambda i, s, a_pt : pt.grad(s[i], a_pt), sequences=pt.arange(s.shape[0]), non_sequences=[s, a_pt])\n", + "f = pytensor.function([a_pt], J, updates=updates)\n", + "f([[1, 2], [3, 4]])\n", + "pytensor.printing.pydotprint(f, outfile=\"./pics/symbolic_graph_rewrite.png\", var_with_name_simple=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytensor-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 03c8e6fe3f769023e589a3152584f4e1320c3272 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Tue, 14 May 2024 21:36:46 +0800 Subject: [PATCH 04/16] Added rewrite to fix svd_graph_rewrite --- pytensor/tensor/rewriting/linalg.py | 29 ++++++++++++++++ test.ipynb | 53 ----------------------------- 2 files changed, 29 insertions(+), 53 deletions(-) delete mode 100644 test.ipynb diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index ff2887831e..99ff4ee075 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -14,6 +14,7 @@ from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod from pytensor.tensor.nlinalg import ( + SVD, KroneckerProduct, MatrixInverse, MatrixPinv, @@ -380,3 +381,31 @@ def local_lift_through_linalg( return [block_diag(*inner_matrices)] else: raise NotImplementedError # pragma: no cover + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([SVD]) +def local_svd_uv_simplify(fgraph, node): + """If we have more than one `SVD` `Op`s and at least one has keyword argument + `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere + and allow `pytensor` to re-use the decomposition outputs instead of recomputing. + """ + (x,) = node.inputs + svd_count = 0 + compute_uv = False + not_compute_uv_svd_list = [] + + for cl, _ in fgraph.clients[x]: + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): + svd_count += 1 + if (not compute_uv) and cl.op.core_op.compute_uv: + compute_uv = True + if not cl.op.core_op.compute_uv: + not_compute_uv_svd_list.append(cl) + + if svd_count > 1 and compute_uv: + for cl in not_compute_uv_svd_list: + cl.op.core_op.compute_uv = True + return [cl.outputs[0] for cl in not_compute_uv_svd_list] diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index 73308d020f..0000000000 --- a/test.ipynb +++ /dev/null @@ -1,53 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The output file is available at ./pics/symbolic_graph_rewrite.png\n" - ] - } - ], - "source": [ - "import pytensor\n", - "import pytensor.tensor as pt\n", - "import numpy as np\n", - "from pytensor.tensor.type import matrix\n", - "from pytensor.tensor.linalg import svd\n", - "\n", - "a_pt = matrix(\"a\")\n", - "s = svd(a_pt, full_matrices=False, compute_uv=False)\n", - "J, updates = pytensor.scan(lambda i, s, a_pt : pt.grad(s[i], a_pt), sequences=pt.arange(s.shape[0]), non_sequences=[s, a_pt])\n", - "f = pytensor.function([a_pt], J, updates=updates)\n", - "f([[1, 2], [3, 4]])\n", - "pytensor.printing.pydotprint(f, outfile=\"./pics/symbolic_graph_rewrite.png\", var_with_name_simple=True)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pytensor-dev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 4a1be4bb39666cad036caf06aeb2499fd0dafa70 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Thu, 16 May 2024 21:12:09 +0800 Subject: [PATCH 05/16] Fixed logic error of SVD node local rewrite: tried to rewrite globally --- .gitignore | 3 ++- pytensor/tensor/rewriting/linalg.py | 14 ++++---------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index e7dab96568..f8f6072b6a 100644 --- a/.gitignore +++ b/.gitignore @@ -56,4 +56,5 @@ pytensor-venv/ testing-report.html coverage.xml .coverage.* -pics \ No newline at end of file +pics +*.ipynb \ No newline at end of file diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 99ff4ee075..e2b1a2a491 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -393,19 +393,13 @@ def local_svd_uv_simplify(fgraph, node): and allow `pytensor` to re-use the decomposition outputs instead of recomputing. """ (x,) = node.inputs - svd_count = 0 compute_uv = False - not_compute_uv_svd_list = [] for cl, _ in fgraph.clients[x]: if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): - svd_count += 1 if (not compute_uv) and cl.op.core_op.compute_uv: compute_uv = True - if not cl.op.core_op.compute_uv: - not_compute_uv_svd_list.append(cl) - - if svd_count > 1 and compute_uv: - for cl in not_compute_uv_svd_list: - cl.op.core_op.compute_uv = True - return [cl.outputs[0] for cl in not_compute_uv_svd_list] + break + if compute_uv and not node.op.compute_uv: + full_matrices = node.op.full_matrices + return [SVD(x, full_matrices=full_matrices, compute_uv=compute_uv)] From b999d683ffd92eee2860d0cd4f456120c0fd079f Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Thu, 16 May 2024 21:30:21 +0800 Subject: [PATCH 06/16] Fixed typo error --- pytensor/tensor/rewriting/linalg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index e2b1a2a491..a1801d39f3 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -400,6 +400,7 @@ def local_svd_uv_simplify(fgraph, node): if (not compute_uv) and cl.op.core_op.compute_uv: compute_uv = True break + if compute_uv and not node.op.compute_uv: full_matrices = node.op.full_matrices - return [SVD(x, full_matrices=full_matrices, compute_uv=compute_uv)] + return [SVD(full_matrices=full_matrices, compute_uv=compute_uv)] From 55ad9311b4c80dc9517cb228b9a914c57e911780 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Wed, 22 May 2024 15:39:56 +0800 Subject: [PATCH 07/16] Refactored logic for SVD to support all 3 cases --- pytensor/tensor/rewriting/linalg.py | 42 ++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index a1801d39f3..9215e9b94b 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -22,6 +22,7 @@ inv, kron, pinv, + svd, ) from pytensor.tensor.rewriting.basic import ( register_canonicalize, @@ -393,14 +394,35 @@ def local_svd_uv_simplify(fgraph, node): and allow `pytensor` to re-use the decomposition outputs instead of recomputing. """ (x,) = node.inputs - compute_uv = False - for cl, _ in fgraph.clients[x]: - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): - if (not compute_uv) and cl.op.core_op.compute_uv: - compute_uv = True - break - - if compute_uv and not node.op.compute_uv: - full_matrices = node.op.full_matrices - return [SVD(full_matrices=full_matrices, compute_uv=compute_uv)] + if node.compute_uv: + # compute_uv=True returns [u, s, v]. + # if at least u or v is used, no need to rewrite this node. + if ( + fgraph.clients[node.outputs[0]] is not None + or fgraph.clients[node.outputs[2]] is not None + ): + return + + # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. + # First, iterate to see if there is an SVD Op that can be reused. + for cl, _ in fgraph.clients[x]: + if cl == "output": + continue + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): + if not cl.op.core_op.compute_uv: + return {fgraph.clients[node.outputs[1]]: cl.outputs[0]} + + # If no SVD reusable, return a new one. + return [svd(x, full_matrices=node.full_matrices, compute_uv=False)] + + else: + # compute_uv=False returns [s]. + # We want rewrite if there is another one with compute_uv=True. + # For this case, just reuse the `s` from the one with compute_uv=True. + for cl, _ in fgraph.clients[x]: + if cl == "output": + continue + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): + if cl.op.core_op.compute_uv: + return [cl.outputs[1]] From 0337e9dd6547df7a033c86f8567771524f7567f2 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Wed, 22 May 2024 15:40:54 +0800 Subject: [PATCH 08/16] Reverted .gitignore --- .gitignore | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index f8f6072b6a..13336ea2fc 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,4 @@ pytensor-venv/ .vscode/ testing-report.html coverage.xml -.coverage.* -pics -*.ipynb \ No newline at end of file +.coverage.* \ No newline at end of file From ecc62ae35a4671d0291bb417bb35839f3e0acd10 Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Sat, 25 May 2024 08:59:12 +0800 Subject: [PATCH 09/16] Update pytensor/tensor/rewriting/linalg.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/tensor/rewriting/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 9215e9b94b..5f5cde19b4 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -414,7 +414,7 @@ def local_svd_uv_simplify(fgraph, node): return {fgraph.clients[node.outputs[1]]: cl.outputs[0]} # If no SVD reusable, return a new one. - return [svd(x, full_matrices=node.full_matrices, compute_uv=False)] + return {"remove": [node.outputs[0], node.ouputs[2]], node.outputs[1]: svd(x, full_matrices=node.full_matrices, compute_uv=False)} else: # compute_uv=False returns [s]. From 8ba51198ff1e2bdb572b5a58c66d2ffacdd6447b Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Sat, 25 May 2024 10:44:17 +0800 Subject: [PATCH 10/16] Added unittest for SVD rewrite --- pytensor/tensor/rewriting/linalg.py | 5 ++++- tests/tensor/rewriting/test_linalg.py | 31 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 5f5cde19b4..d4ded1f6c7 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -414,7 +414,10 @@ def local_svd_uv_simplify(fgraph, node): return {fgraph.clients[node.outputs[1]]: cl.outputs[0]} # If no SVD reusable, return a new one. - return {"remove": [node.outputs[0], node.ouputs[2]], node.outputs[1]: svd(x, full_matrices=node.full_matrices, compute_uv=False)} + return { + "remove": [node.outputs[0], node.ouputs[2]], + node.outputs[1]: svd(x, full_matrices=node.full_matrices, compute_uv=False), + } else: # compute_uv=False returns [s]. diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 1e9d6194db..563145c5a4 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -15,11 +15,13 @@ from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import _allclose, dot, matmul from pytensor.tensor.nlinalg import ( + SVD, Det, KroneckerProduct, MatrixInverse, MatrixPinv, matrix_inverse, + svd, ) from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.slinalg import ( @@ -390,3 +392,32 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals] np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8) + + +def test_local_svd_uv_simplify(): + a = matrix("a") + s_1 = svd(a, full_matrices=False, compute_uv=False) + _, s_2, _ = svd(a, full_matrices=False, compute_uv=True) + # full_matrices = True is not supported for grad of svd + gs = pt.grad(pt.sum(s_1), a) + + # 1. compute_uv=False needs rewriting with compute_uv=True + f_1 = pytensor.function([a], gs) + nodes = f_1.maker.fgraph.toposort() + for node in nodes: + if isinstance(node, SVD): + assert node.compute_uv + + # 2. compute_uv=True needs rewriting with compute=False, reuse node + f_2 = pytensor.function([a], [s_1, s_2]) + nodes = f_2.maker.fgraph.toposort() + for node in nodes: + if isinstance(node, SVD): + assert not node.compute_uv + + # 3. compute_uv=True needs rewriting with compute=False, create new node + f_3 = pytensor.function([a], [s_2]) + nodes = f_3.maker.fgraph.toposort() + for node in nodes: + if isinstance(node, SVD): + assert not node.compute_uv From 1c30ee97b40b5ca91b6c07e5ba62dd9e0b9d0192 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Sun, 2 Jun 2024 20:53:33 +0800 Subject: [PATCH 11/16] Added test cases for SVD rewrite --- pytensor/tensor/rewriting/linalg.py | 9 ++--- tests/tensor/rewriting/test_linalg.py | 55 ++++++++++++++++++++++----- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index d4ded1f6c7..7b9886c15b 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -388,20 +388,17 @@ def local_lift_through_linalg( @register_stabilize @register_specialize @node_rewriter([SVD]) -def local_svd_uv_simplify(fgraph, node): +def svd_uv_merge(fgraph, node): """If we have more than one `SVD` `Op`s and at least one has keyword argument `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere and allow `pytensor` to re-use the decomposition outputs instead of recomputing. """ (x,) = node.inputs - if node.compute_uv: + if node.op.compute_uv: # compute_uv=True returns [u, s, v]. # if at least u or v is used, no need to rewrite this node. - if ( - fgraph.clients[node.outputs[0]] is not None - or fgraph.clients[node.outputs[2]] is not None - ): + if fgraph.clients[node.outputs[0]] or fgraph.clients[node.outputs[2]]: return # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 563145c5a4..ad3cd109ac 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -394,30 +394,65 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8) -def test_local_svd_uv_simplify(): +def test_svd_uv_merge(): a = matrix("a") s_1 = svd(a, full_matrices=False, compute_uv=False) _, s_2, _ = svd(a, full_matrices=False, compute_uv=True) + _, s_3, _ = svd(a, full_matrices=True, compute_uv=True) + u_4, s_4, v_4 = svd(a, full_matrices=False, compute_uv=True) + # `grad` will introduces an SVD Op with compute_uv=True # full_matrices = True is not supported for grad of svd gs = pt.grad(pt.sum(s_1), a) # 1. compute_uv=False needs rewriting with compute_uv=True f_1 = pytensor.function([a], gs) - nodes = f_1.maker.fgraph.toposort() + nodes = f_1.maker.fgraph.apply_nodes + svd_counter = 0 for node in nodes: - if isinstance(node, SVD): - assert node.compute_uv + if isinstance(node.op, SVD): + assert node.op.compute_uv + svd_counter += 1 + assert svd_counter == 1 # 2. compute_uv=True needs rewriting with compute=False, reuse node f_2 = pytensor.function([a], [s_1, s_2]) - nodes = f_2.maker.fgraph.toposort() + nodes = f_2.maker.fgraph.apply_nodes + svd_counter = 0 for node in nodes: - if isinstance(node, SVD): - assert not node.compute_uv + if isinstance(node.op, SVD): + assert not node.op.compute_uv + svd_counter += 1 + assert svd_counter == 1 # 3. compute_uv=True needs rewriting with compute=False, create new node + # full_matrices needs to retain the value f_3 = pytensor.function([a], [s_2]) - nodes = f_3.maker.fgraph.toposort() + nodes = f_3.maker.fgraph.apply_nodes + svd_counter = 0 for node in nodes: - if isinstance(node, SVD): - assert not node.compute_uv + if isinstance(node.op, SVD): + assert not node.op.compute_uv + assert not node.op.full_matrices + svd_counter += 1 + assert svd_counter == 1 + + # Case 2 of 3. for a different full_matrices + f_4 = pytensor.function([a], [s_3]) + nodes = f_4.maker.fgraph.apply_nodes + svd_counter = 0 + for node in nodes: + if isinstance(node.op, SVD): + assert not node.op.compute_uv + assert node.op.full_matrices + svd_counter += 1 + assert svd_counter == 1 + + # 4. No rewrite should happen + f_5 = pytensor.function([a], [u_4]) + nodes = f_5.maker.fgraph.apply_nodes + svd_counter = 0 + for node in nodes: + if isinstance(node.op, SVD): + assert node.op.compute_uv + svd_counter += 1 + assert svd_counter == 1 From 551a3505ad23a0348325883834dbed065e7ad99e Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Fri, 7 Jun 2024 13:42:00 +0800 Subject: [PATCH 12/16] Fix logic error in linalg rewriting --- pytensor/tensor/rewriting/linalg.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 7b9886c15b..f3ccc74eb7 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -398,7 +398,10 @@ def svd_uv_merge(fgraph, node): if node.op.compute_uv: # compute_uv=True returns [u, s, v]. # if at least u or v is used, no need to rewrite this node. - if fgraph.clients[node.outputs[0]] or fgraph.clients[node.outputs[2]]: + if ( + len(fgraph.clients[node.outputs[0]]) > 0 + or len(fgraph.clients[node.outputs[2]]) > 0 + ): return # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. @@ -408,7 +411,10 @@ def svd_uv_merge(fgraph, node): continue if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): if not cl.op.core_op.compute_uv: - return {fgraph.clients[node.outputs[1]]: cl.outputs[0]} + return { + "remove": [node.outputs[0], node.ouputs[2]], + node.outputs[1]: cl.outputs[0], + } # If no SVD reusable, return a new one. return { From 27ff606a68801e42e44196a904963a177b1d5502 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Fri, 7 Jun 2024 13:48:37 +0800 Subject: [PATCH 13/16] Fix logic error in linalg rewriting --- pytensor/tensor/rewriting/linalg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index f3ccc74eb7..58dac6334e 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -430,5 +430,8 @@ def svd_uv_merge(fgraph, node): if cl == "output": continue if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): - if cl.op.core_op.compute_uv: + if cl.op.core_op.compute_uv and ( + len(fgraph.clients[cl.outputs[0]]) > 0 + or len(fgraph.clients[cl.outputs[2]]) > 0 + ): return [cl.outputs[1]] From 67e1f06e33a224c4e17625739e0e3fe9defe159a Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Mon, 17 Jun 2024 12:49:27 +0000 Subject: [PATCH 14/16] Changed tracking SVD to tracking Blockwise Ric's comment --- pytensor/tensor/rewriting/linalg.py | 87 ++++++++++++++------------- tests/tensor/rewriting/test_linalg.py | 4 +- 2 files changed, 47 insertions(+), 44 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 58dac6334e..769225d815 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -387,51 +387,54 @@ def local_lift_through_linalg( @register_canonicalize @register_stabilize @register_specialize -@node_rewriter([SVD]) +@node_rewriter([Blockwise]) def svd_uv_merge(fgraph, node): """If we have more than one `SVD` `Op`s and at least one has keyword argument `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere and allow `pytensor` to re-use the decomposition outputs instead of recomputing. """ - (x,) = node.inputs + if isinstance(node.op.core_op, SVD): + (x,) = node.inputs - if node.op.compute_uv: - # compute_uv=True returns [u, s, v]. - # if at least u or v is used, no need to rewrite this node. - if ( - len(fgraph.clients[node.outputs[0]]) > 0 - or len(fgraph.clients[node.outputs[2]]) > 0 - ): - return - - # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. - # First, iterate to see if there is an SVD Op that can be reused. - for cl, _ in fgraph.clients[x]: - if cl == "output": - continue - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): - if not cl.op.core_op.compute_uv: - return { - "remove": [node.outputs[0], node.ouputs[2]], - node.outputs[1]: cl.outputs[0], - } - - # If no SVD reusable, return a new one. - return { - "remove": [node.outputs[0], node.ouputs[2]], - node.outputs[1]: svd(x, full_matrices=node.full_matrices, compute_uv=False), - } - - else: - # compute_uv=False returns [s]. - # We want rewrite if there is another one with compute_uv=True. - # For this case, just reuse the `s` from the one with compute_uv=True. - for cl, _ in fgraph.clients[x]: - if cl == "output": - continue - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): - if cl.op.core_op.compute_uv and ( - len(fgraph.clients[cl.outputs[0]]) > 0 - or len(fgraph.clients[cl.outputs[2]]) > 0 - ): - return [cl.outputs[1]] + if node.op.core_op.compute_uv: + # compute_uv=True returns [u, s, v]. + # if at least u or v is used, no need to rewrite this node. + if ( + len(fgraph.clients[node.outputs[0]]) > 0 + or len(fgraph.clients[node.outputs[2]]) > 0 + ): + return + + # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. + # First, iterate to see if there is an SVD Op that can be reused. + for cl, _ in fgraph.clients[x]: + if cl == "output": + continue + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): + if not cl.op.core_op.compute_uv: + return { + "remove": [node.outputs[0], node.outputs[2]], + node.outputs[1]: cl.outputs[0], + } + + # If no SVD reusable, return a new one. + return { + "remove": [node.outputs[0], node.outputs[2]], + node.outputs[1]: svd( + x, full_matrices=node.op.core_op.full_matrices, compute_uv=False + ), + } + + else: + # compute_uv=False returns [s]. + # We want rewrite if there is another one with compute_uv=True. + # For this case, just reuse the `s` from the one with compute_uv=True. + for cl, _ in fgraph.clients[x]: + if cl == "output": + continue + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): + if cl.op.core_op.compute_uv and ( + len(fgraph.clients[cl.outputs[0]]) > 0 + or len(fgraph.clients[cl.outputs[2]]) > 0 + ): + return [cl.outputs[1]] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index ad3cd109ac..523742e356 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -399,7 +399,7 @@ def test_svd_uv_merge(): s_1 = svd(a, full_matrices=False, compute_uv=False) _, s_2, _ = svd(a, full_matrices=False, compute_uv=True) _, s_3, _ = svd(a, full_matrices=True, compute_uv=True) - u_4, s_4, v_4 = svd(a, full_matrices=False, compute_uv=True) + u_4, s_4, v_4 = svd(a, full_matrices=True, compute_uv=True) # `grad` will introduces an SVD Op with compute_uv=True # full_matrices = True is not supported for grad of svd gs = pt.grad(pt.sum(s_1), a) @@ -432,7 +432,6 @@ def test_svd_uv_merge(): for node in nodes: if isinstance(node.op, SVD): assert not node.op.compute_uv - assert not node.op.full_matrices svd_counter += 1 assert svd_counter == 1 @@ -453,6 +452,7 @@ def test_svd_uv_merge(): svd_counter = 0 for node in nodes: if isinstance(node.op, SVD): + assert node.op.full_matrices assert node.op.compute_uv svd_counter += 1 assert svd_counter == 1 From 9e21635bb214bf7e7524d4b1f27168ead8924364 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 28 Jun 2024 11:20:09 +0200 Subject: [PATCH 15/16] Rely on implicit remove --- pytensor/tensor/rewriting/linalg.py | 88 ++++++++++++++--------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 769225d815..30d9084449 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -393,48 +393,48 @@ def svd_uv_merge(fgraph, node): `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere and allow `pytensor` to re-use the decomposition outputs instead of recomputing. """ - if isinstance(node.op.core_op, SVD): - (x,) = node.inputs + if not isinstance(node.op.core_op, SVD): + return - if node.op.core_op.compute_uv: - # compute_uv=True returns [u, s, v]. - # if at least u or v is used, no need to rewrite this node. - if ( - len(fgraph.clients[node.outputs[0]]) > 0 - or len(fgraph.clients[node.outputs[2]]) > 0 - ): - return - - # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. - # First, iterate to see if there is an SVD Op that can be reused. - for cl, _ in fgraph.clients[x]: - if cl == "output": - continue - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): - if not cl.op.core_op.compute_uv: - return { - "remove": [node.outputs[0], node.outputs[2]], - node.outputs[1]: cl.outputs[0], - } - - # If no SVD reusable, return a new one. - return { - "remove": [node.outputs[0], node.outputs[2]], - node.outputs[1]: svd( - x, full_matrices=node.op.core_op.full_matrices, compute_uv=False - ), - } - - else: - # compute_uv=False returns [s]. - # We want rewrite if there is another one with compute_uv=True. - # For this case, just reuse the `s` from the one with compute_uv=True. - for cl, _ in fgraph.clients[x]: - if cl == "output": - continue - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): - if cl.op.core_op.compute_uv and ( - len(fgraph.clients[cl.outputs[0]]) > 0 - or len(fgraph.clients[cl.outputs[2]]) > 0 - ): - return [cl.outputs[1]] + (x,) = node.inputs + + if node.op.core_op.compute_uv: + # compute_uv=True returns [u, s, v]. + # if at least u or v is used, no need to rewrite this node. + if ( + len(fgraph.clients[node.outputs[0]]) > 0 + or len(fgraph.clients[node.outputs[2]]) > 0 + ): + return + + # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. + # First, iterate to see if there is an SVD Op that can be reused. + for cl, _ in fgraph.clients[x]: + if cl == "output": + continue + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): + if not cl.op.core_op.compute_uv: + return { + node.outputs[1]: cl.outputs[0], + } + + # If no SVD reusable, return a new one. + return { + node.outputs[1]: svd( + x, full_matrices=node.op.core_op.full_matrices, compute_uv=False + ), + } + + else: + # compute_uv=False returns [s]. + # We want rewrite if there is another one with compute_uv=True. + # For this case, just reuse the `s` from the one with compute_uv=True. + for cl, _ in fgraph.clients[x]: + if cl == "output": + continue + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): + if cl.op.core_op.compute_uv and ( + len(fgraph.clients[cl.outputs[0]]) > 0 + or len(fgraph.clients[cl.outputs[2]]) > 0 + ): + return [cl.outputs[1]] From 3ba3ba4f8ebfd80953ce0d399c11f94e90991c82 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 28 Jun 2024 11:22:30 +0200 Subject: [PATCH 16/16] Add back empty line --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 13336ea2fc..dfe862b868 100644 --- a/.gitignore +++ b/.gitignore @@ -55,4 +55,4 @@ pytensor-venv/ .vscode/ testing-report.html coverage.xml -.coverage.* \ No newline at end of file +.coverage.*