From cf76824a53610e5800a387c55f9cfca12ee14dc4 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Tue, 8 Oct 2024 12:14:40 +0200 Subject: [PATCH 1/4] Fix gufunc signature of SLogDet --- pytensor/tensor/nlinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 6db6ae2638..e7093a82bd 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -240,7 +240,7 @@ class SLogDet(Op): """ __props__ = () - gufunc_signature = "(m, m)->(),()" + gufunc_signature = "(m,m)->(),()" gufunc_spec = ("numpy.linalg.slogdet", 1, 2) def make_node(self, x): From 130b4255b18f3288f3a74c5d2f391ead60da08c6 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 7 Oct 2024 18:10:29 +0200 Subject: [PATCH 2/4] Only run required rewrites in JAX and PyTorch tests Only run required rewrites in JAX tests Several tests ended up not testing the JAX implementation due to constant folding of inputs. --- pytensor/link/pytorch/dispatch/shape.py | 2 +- tests/link/jax/test_basic.py | 10 ++-- tests/link/jax/test_einsum.py | 13 +++-- tests/link/jax/test_extra_ops.py | 67 +++++++++---------------- tests/link/jax/test_random.py | 2 +- tests/link/jax/test_scan.py | 23 ++++----- tests/link/jax/test_sparse.py | 2 +- tests/link/jax/test_tensor_basic.py | 2 +- tests/link/pytorch/test_basic.py | 11 ++-- 9 files changed, 58 insertions(+), 74 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py index 7633e28e01..e249a81a70 100644 --- a/pytensor/link/pytorch/dispatch/shape.py +++ b/pytensor/link/pytorch/dispatch/shape.py @@ -15,7 +15,7 @@ def reshape(x, shape): @pytorch_funcify.register(Shape) def pytorch_funcify_Shape(op, **kwargs): def shape(x): - return x.shape + return torch.tensor(x.shape) return shape diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py index 5cd2bd54c6..5e783984e0 100644 --- a/tests/link/jax/test_basic.py +++ b/tests/link/jax/test_basic.py @@ -6,13 +6,15 @@ from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function -from pytensor.compile.mode import get_mode +from pytensor.compile.mode import JAX, Mode from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.configdefaults import config +from pytensor.graph import RewriteDatabaseQuery from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op, get_test_value from pytensor.ifelse import ifelse +from pytensor.link.jax import JAXLinker from pytensor.raise_op import assert_op from pytensor.tensor.type import dscalar, matrices, scalar, vector @@ -26,9 +28,9 @@ def set_pytensor_flags(): jax = pytest.importorskip("jax") -# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs -jax_mode = get_mode("JAX") -py_mode = get_mode("FAST_COMPILE") +optimizer = RewriteDatabaseQuery(include=["jax"], exclude=JAX._optimizer.exclude) +jax_mode = Mode(linker=JAXLinker(), optimizer=optimizer) +py_mode = Mode(linker="py", optimizer=None) def compare_jax_and_py( diff --git a/tests/link/jax/test_einsum.py b/tests/link/jax/test_einsum.py index 9a55670c64..3bd4abd7f1 100644 --- a/tests/link/jax/test_einsum.py +++ b/tests/link/jax/test_einsum.py @@ -1,8 +1,9 @@ import numpy as np import pytest -import pytensor import pytensor.tensor as pt +from pytensor.graph import FunctionGraph +from tests.link.jax.test_basic import compare_jax_and_py jax = pytest.importorskip("jax") @@ -19,9 +20,8 @@ def test_jax_einsum(): pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes) ) out = pt.einsum(subscripts, x_pt, y_pt, z_pt) - f = pytensor.function([x_pt, y_pt, z_pt], out, mode="JAX") - - np.testing.assert_allclose(f(x, y, z), np.einsum(subscripts, x, y, z)) + fg = FunctionGraph([x_pt, y_pt, z_pt], [out]) + compare_jax_and_py(fg, [x, y, z]) @pytest.mark.xfail(raises=NotImplementedError) @@ -33,6 +33,5 @@ def test_ellipsis_einsum(): x_pt = pt.tensor("x", shape=x.shape) y_pt = pt.tensor("y", shape=y.shape) out = pt.einsum(subscripts, x_pt, y_pt) - f = pytensor.function([x_pt, y_pt], out, mode="JAX") - - np.testing.assert_allclose(f(x, y), np.einsum(subscripts, x, y)) + fg = FunctionGraph([x_pt, y_pt], [out]) + compare_jax_and_py(fg, [x, y]) diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index 94c442b165..1427413379 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -1,59 +1,52 @@ import numpy as np import pytest -from packaging.version import parse as version_parse import pytensor.tensor.basic as ptb from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import get_test_value from pytensor.tensor import extra_ops as pt_extra_ops -from pytensor.tensor.type import matrix +from pytensor.tensor.type import matrix, tensor from tests.link.jax.test_basic import compare_jax_and_py jax = pytest.importorskip("jax") -def set_test_value(x, v): - x.tag.test_value = v - return x - - def test_extra_ops(): a = matrix("a") - a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) + a_test = np.arange(6, dtype=config.floatX).reshape((3, 2)) out = pt_extra_ops.cumsum(a, axis=0) fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py(fgraph, [a_test]) out = pt_extra_ops.cumprod(a, axis=1) fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py(fgraph, [a_test]) out = pt_extra_ops.diff(a, n=2, axis=1) fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py(fgraph, [a_test]) out = pt_extra_ops.repeat(a, (3, 3), axis=1) fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py(fgraph, [a_test]) c = ptb.as_tensor(5) - out = pt_extra_ops.fill_diagonal(a, c) fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py(fgraph, [a_test]) with pytest.raises(NotImplementedError): out = pt_extra_ops.fill_diagonal_offset(a, c, c) fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py(fgraph, [a_test]) with pytest.raises(NotImplementedError): out = pt_extra_ops.Unique(axis=1)(a) fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py(fgraph, [a_test]) indices = np.arange(np.prod((3, 4))) out = pt_extra_ops.unravel_index(indices, (3, 4), order="C") @@ -63,40 +56,30 @@ def test_extra_ops(): ) -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="JAX Numpy API does not support dynamic shapes", -) -def test_extra_ops_dynamic_shapes(): - a = matrix("a") - a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) - - # This function also cannot take symbolic input. - c = ptb.as_tensor(5) +@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") +def test_bartlett_dynamic_shape(): + c = tensor(shape=(), dtype=int) out = pt_extra_ops.bartlett(c) fgraph = FunctionGraph([], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py(fgraph, [np.array(5)]) - multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4)) - out = pt_extra_ops.ravel_multi_index(multi_index, (3, 4)) - fgraph = FunctionGraph([], [out]) - compare_jax_and_py( - fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False - ) - # The inputs are "concrete", yet it still has problems? - out = pt_extra_ops.Unique()( - ptb.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2))) - ) +@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") +def test_ravel_multi_index_dynamic_shape(): + x_test, y_test = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4)) + + x = tensor(shape=(None,), dtype=int) + y = tensor(shape=(None,), dtype=int) + out = pt_extra_ops.ravel_multi_index((x, y), (3, 4)) fgraph = FunctionGraph([], [out]) - compare_jax_and_py(fgraph, []) + compare_jax_and_py(fgraph, [x_test, y_test]) -@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs") -def test_unique_nonconcrete(): +@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") +def test_unique_dynamic_shape(): a = matrix("a") - a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) + a_test = np.arange(6, dtype=config.floatX).reshape((3, 2)) out = pt_extra_ops.Unique()(a) fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_jax_and_py(fgraph, [a_test]) diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index dfbc888e30..f9ae5d00c1 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -705,7 +705,7 @@ def test_multinomial(): n = np.array([10, 40]) p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g, mode="JAX") samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1) np.testing.assert_allclose( diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py index 61edacbc7b..ae64cad4c0 100644 --- a/tests/link/jax/test_scan.py +++ b/tests/link/jax/test_scan.py @@ -32,7 +32,7 @@ def test_scan_sit_sot(view): xs = xs[view] fg = FunctionGraph([x0], [xs]) test_input_vals = [np.e] - compare_jax_and_py(fg, test_input_vals) + compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") @pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)]) @@ -47,7 +47,7 @@ def test_scan_mit_sot(view): xs = xs[view] fg = FunctionGraph([x0], [xs]) test_input_vals = [np.full((3,), np.e)] - compare_jax_and_py(fg, test_input_vals) + compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") @pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)]) @@ -74,7 +74,7 @@ def step(xtm3, xtm1, ytm4, ytm2): fg = FunctionGraph([x0, y0], [xs, ys]) test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)] - compare_jax_and_py(fg, test_input_vals) + compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") @pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)]) @@ -283,7 +283,7 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): gamma_val, delta_val, ] - compare_jax_and_py(out_fg, test_input_vals) + compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX") def test_scan_mitsot_with_nonseq(): @@ -316,7 +316,7 @@ def input_step_fn(y_tm1, y_tm3, a): out_fg = FunctionGraph([a_pt], [y_scan_pt]) test_input_vals = [np.array(10.0).astype(config.floatX)] - compare_jax_and_py(out_fg, test_input_vals) + compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX") @pytest.mark.parametrize("x0_func", [dvector, dmatrix]) @@ -334,7 +334,6 @@ def test_nd_scan_sit_sot(x0_func, A_func): non_sequences=[A], outputs_info=[x0], n_steps=n_steps, - mode=get_mode("JAX"), ) x0_val = ( @@ -346,7 +345,7 @@ def test_nd_scan_sit_sot(x0_func, A_func): fg = FunctionGraph([x0, A], [xs]) test_input_vals = [x0_val, A_val] - compare_jax_and_py(fg, test_input_vals) + compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") def test_nd_scan_sit_sot_with_seq(): @@ -362,7 +361,6 @@ def test_nd_scan_sit_sot_with_seq(): non_sequences=[A], sequences=[x], n_steps=n_steps, - mode=get_mode("JAX"), ) x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k) @@ -370,7 +368,7 @@ def test_nd_scan_sit_sot_with_seq(): fg = FunctionGraph([x, A], [xs]) test_input_vals = [x_val, A_val] - compare_jax_and_py(fg, test_input_vals) + compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") def test_nd_scan_mit_sot(): @@ -384,7 +382,6 @@ def test_nd_scan_mit_sot(): outputs_info=[{"initial": x0, "taps": [-3, -1]}], non_sequences=[A, B], n_steps=10, - mode=get_mode("JAX"), ) fg = FunctionGraph([x0, A, B], [xs]) @@ -393,7 +390,7 @@ def test_nd_scan_mit_sot(): B_val = np.eye(3, dtype=config.floatX) test_input_vals = [x0_val, A_val, B_val] - compare_jax_and_py(fg, test_input_vals) + compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") def test_nd_scan_sit_sot_with_carry(): @@ -417,7 +414,7 @@ def step(x, A): A_val = np.eye(3, dtype=config.floatX) test_input_vals = [x0_val, A_val] - compare_jax_and_py(fg, test_input_vals) + compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") def test_default_mode_excludes_incompatible_rewrites(): @@ -426,7 +423,7 @@ def test_default_mode_excludes_incompatible_rewrites(): B = matrix("B") out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2) fg = FunctionGraph([A, B], [out]) - compare_jax_and_py(fg, [np.eye(3), np.eye(3)]) + compare_jax_and_py(fg, [np.eye(3), np.eye(3)], jax_mode="JAX") def test_dynamic_sequence_length(): diff --git a/tests/link/jax/test_sparse.py b/tests/link/jax/test_sparse.py index 0c377bdcd8..c53aa301af 100644 --- a/tests/link/jax/test_sparse.py +++ b/tests/link/jax/test_sparse.py @@ -51,7 +51,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op): dot_pt = op(x_pt, y_pt) fgraph = FunctionGraph(inputs, [dot_pt]) - compare_jax_and_py(fgraph, test_values) + compare_jax_and_py(fgraph, test_values, jax_mode="JAX") def test_sparse_dot_non_const_raises(): diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index afa4191b9d..0ee4a236d9 100644 --- a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -74,7 +74,7 @@ def test_arange_of_shape(): x = vector("x") out = ptb.arange(1, x.shape[-1], 2) fgraph = FunctionGraph([x], [out]) - compare_jax_and_py(fgraph, [np.zeros((5,))]) + compare_jax_and_py(fgraph, [np.zeros((5,))], jax_mode="JAX") def test_arange_nonconcrete(): diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index bb1958f43e..8e243169a6 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -7,13 +7,15 @@ import pytensor.tensor.basic as ptb from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function -from pytensor.compile.mode import get_mode +from pytensor.compile.mode import PYTORCH, Mode from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.configdefaults import config +from pytensor.graph import RewriteDatabaseQuery from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.ifelse import ifelse +from pytensor.link.pytorch.linker import PytorchLinker from pytensor.raise_op import CheckAndRaise from pytensor.tensor import alloc, arange, as_tensor, empty, eye from pytensor.tensor.type import matrices, matrix, scalar, vector @@ -22,8 +24,9 @@ torch = pytest.importorskip("torch") -pytorch_mode = get_mode("PYTORCH") -py_mode = get_mode("FAST_COMPILE") +optimizer = RewriteDatabaseQuery(include=[], exclude=PYTORCH._optimizer.exclude) +pytorch_mode = Mode(linker=PytorchLinker(), optimizer=optimizer) +py_mode = Mode(linker="py", optimizer=None) def compare_pytorch_and_py( @@ -220,7 +223,7 @@ def test_alloc_and_empty(): assert res.dtype == torch.float32 v = vector("v", shape=(3,), dtype="float64") - out = alloc(v, (dim0, dim1, 3)) + out = alloc(v, dim0, dim1, 3) compare_pytorch_and_py( FunctionGraph([v, dim1], [out]), [np.array([1, 2, 3]), np.array(7)], From c854bc3464f6c63ba76cb3d37732d832f751a798 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 7 Oct 2024 14:08:16 +0200 Subject: [PATCH 3/4] Fix pytest config in pyproject.toml --- .github/workflows/test.yml | 6 +++--- pyproject.toml | 6 +++--- tests/unittest_tools.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e3d2adf461..7298d5df61 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -78,7 +78,7 @@ jobs: install-jax: [0] install-torch: [0] part: - - "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link" + - "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link" - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse" - "tests/scan" - "tests/sparse" @@ -97,9 +97,9 @@ jobs: part: "tests/tensor/test_math.py" - fast-compile: 1 float32: 1 - - part: "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link" + - part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link" float32: 1 - - part: "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link" + - part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link" fast-compile: 1 include: - install-numba: 1 diff --git a/pyproject.toml b/pyproject.toml index 81fe82c79c..95198d656e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,9 +116,9 @@ versionfile_source = "pytensor/_version.py" versionfile_build = "pytensor/_version.py" tag_prefix = "rel-" -[tool.pytest] -addopts = "--durations=50 --doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link" -testpaths = "tests/" +[tool.pytest.ini_options] +addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py" +testpaths = ["pytensor/", "tests/"] [tool.ruff] line-length = 88 diff --git a/tests/unittest_tools.py b/tests/unittest_tools.py index a556e3a275..9134b29b65 100644 --- a/tests/unittest_tools.py +++ b/tests/unittest_tools.py @@ -27,8 +27,8 @@ def fetch_seed(pseed=None): None, which is equivalent to seeding with a random seed. Useful for seeding RandomState or Generator objects. - >>> rng = np.random.RandomState(unittest_tools.fetch_seed()) - >>> rng = np.random.default_rng(unittest_tools.fetch_seed()) + >>> rng = np.random.RandomState(fetch_seed()) + >>> rng = np.random.default_rng(fetch_seed()) """ seed = pseed or config.unittests__rseed From b5a0ca1ae782ab59df08e37df176b2f239aafa38 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Tue, 8 Oct 2024 12:17:22 +0200 Subject: [PATCH 4/4] Make xfail strict --- pyproject.toml | 1 + tests/link/jax/test_einsum.py | 1 - tests/tensor/rewriting/test_elemwise.py | 1 - 3 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 95198d656e..42c2289dde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,7 @@ tag_prefix = "rel-" [tool.pytest.ini_options] addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py" testpaths = ["pytensor/", "tests/"] +xfail_strict = true [tool.ruff] line-length = 88 diff --git a/tests/link/jax/test_einsum.py b/tests/link/jax/test_einsum.py index 3bd4abd7f1..5761563066 100644 --- a/tests/link/jax/test_einsum.py +++ b/tests/link/jax/test_einsum.py @@ -24,7 +24,6 @@ def test_jax_einsum(): compare_jax_and_py(fg, [x, y, z]) -@pytest.mark.xfail(raises=NotImplementedError) def test_ellipsis_einsum(): subscripts = "...i,...i->..." x = np.random.rand(2, 5) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 82cfa884af..fc429bb596 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1057,7 +1057,6 @@ def test_big_fusion(self): for node in dlogp.maker.fgraph.toposort() ) - @pytest.mark.xfail(reason="Fails due to #1244") def test_add_mul_fusion_precedence(self): """Test that additions and multiplications are "fused together" before a `Composite` `Op` is introduced. This fusion is done by canonicalization