Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Misc testing fixes #1021

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ jobs:
install-jax: [0]
install-torch: [0]
part:
- "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link"
- "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just odd to have pytensor in the middle of the ignore files

- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
- "tests/sparse"
Expand All @@ -97,9 +97,9 @@ jobs:
part: "tests/tensor/test_math.py"
- fast-compile: 1
float32: 1
- part: "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link"
- part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
float32: 1
- part: "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link"
- part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
fast-compile: 1
include:
- install-numba: 1
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,10 @@ versionfile_source = "pytensor/_version.py"
versionfile_build = "pytensor/_version.py"
tag_prefix = "rel-"

[tool.pytest]
addopts = "--durations=50 --doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
testpaths = "tests/"
[tool.pytest.ini_options]
addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py"
testpaths = ["pytensor/", "tests/"]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytensor is in the testpath because of --doctest-modules, so we test the code examples in the docstrings by default

xfail_strict = true

[tool.ruff]
line-length = 88
Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/pytorch/dispatch/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def reshape(x, shape):
@pytorch_funcify.register(Shape)
def pytorch_funcify_Shape(op, **kwargs):
def shape(x):
return x.shape
return torch.tensor(x.shape)

return shape

Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ class SLogDet(Op):
"""

__props__ = ()
gufunc_signature = "(m, m)->(),()"
gufunc_signature = "(m,m)->(),()"
gufunc_spec = ("numpy.linalg.slogdet", 1, 2)

def make_node(self, x):
Expand Down
10 changes: 6 additions & 4 deletions tests/link/jax/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.mode import JAX, Mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op
from pytensor.tensor.type import dscalar, matrices, scalar, vector

Expand All @@ -26,9 +28,9 @@ def set_pytensor_flags():
jax = pytest.importorskip("jax")


# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs
jax_mode = get_mode("JAX")
py_mode = get_mode("FAST_COMPILE")
optimizer = RewriteDatabaseQuery(include=["jax"], exclude=JAX._optimizer.exclude)
jax_mode = Mode(linker=JAXLinker(), optimizer=optimizer)
py_mode = Mode(linker="py", optimizer=None)


def compare_jax_and_py(
Expand Down
14 changes: 6 additions & 8 deletions tests/link/jax/test_einsum.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import pytest

import pytensor
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py


jax = pytest.importorskip("jax")
Expand All @@ -19,12 +20,10 @@ def test_jax_einsum():
pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes)
)
out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
f = pytensor.function([x_pt, y_pt, z_pt], out, mode="JAX")
fg = FunctionGraph([x_pt, y_pt, z_pt], [out])
compare_jax_and_py(fg, [x, y, z])

np.testing.assert_allclose(f(x, y, z), np.einsum(subscripts, x, y, z))


@pytest.mark.xfail(raises=NotImplementedError)
def test_ellipsis_einsum():
subscripts = "...i,...i->..."
x = np.random.rand(2, 5)
Expand All @@ -33,6 +32,5 @@ def test_ellipsis_einsum():
x_pt = pt.tensor("x", shape=x.shape)
y_pt = pt.tensor("y", shape=y.shape)
out = pt.einsum(subscripts, x_pt, y_pt)
f = pytensor.function([x_pt, y_pt], out, mode="JAX")

np.testing.assert_allclose(f(x, y), np.einsum(subscripts, x, y))
fg = FunctionGraph([x_pt, y_pt], [out])
compare_jax_and_py(fg, [x, y])
67 changes: 25 additions & 42 deletions tests/link/jax/test_extra_ops.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,52 @@
import numpy as np
import pytest
from packaging.version import parse as version_parse

import pytensor.tensor.basic as ptb
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.type import matrix
from pytensor.tensor.type import matrix, tensor
from tests.link.jax.test_basic import compare_jax_and_py


jax = pytest.importorskip("jax")


def set_test_value(x, v):
x.tag.test_value = v
return x


def test_extra_ops():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))

out = pt_extra_ops.cumsum(a, axis=0)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])

out = pt_extra_ops.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])

out = pt_extra_ops.diff(a, n=2, axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])

out = pt_extra_ops.repeat(a, (3, 3), axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])

c = ptb.as_tensor(5)

out = pt_extra_ops.fill_diagonal(a, c)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])

with pytest.raises(NotImplementedError):
out = pt_extra_ops.fill_diagonal_offset(a, c, c)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])

with pytest.raises(NotImplementedError):
out = pt_extra_ops.Unique(axis=1)(a)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])

indices = np.arange(np.prod((3, 4)))
out = pt_extra_ops.unravel_index(indices, (3, 4), order="C")
Expand All @@ -63,40 +56,30 @@ def test_extra_ops():
)


@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="JAX Numpy API does not support dynamic shapes",
)
def test_extra_ops_dynamic_shapes():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

# This function also cannot take symbolic input.
c = ptb.as_tensor(5)
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_bartlett_dynamic_shape():
c = tensor(shape=(), dtype=int)
out = pt_extra_ops.bartlett(c)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [np.array(5)])

multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
out = pt_extra_ops.ravel_multi_index(multi_index, (3, 4))
fgraph = FunctionGraph([], [out])
compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)

# The inputs are "concrete", yet it still has problems?
out = pt_extra_ops.Unique()(
ptb.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2)))
)
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_ravel_multi_index_dynamic_shape():
x_test, y_test = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))

x = tensor(shape=(None,), dtype=int)
y = tensor(shape=(None,), dtype=int)
out = pt_extra_ops.ravel_multi_index((x, y), (3, 4))
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [])
compare_jax_and_py(fgraph, [x_test, y_test])


@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_unique_nonconcrete():
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_unique_dynamic_shape():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))

out = pt_extra_ops.Unique()(a)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])
2 changes: 1 addition & 1 deletion tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ def test_multinomial():
n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
np.testing.assert_allclose(
Expand Down
23 changes: 10 additions & 13 deletions tests/link/jax/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_scan_sit_sot(view):
xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.e]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")


@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
Expand All @@ -47,7 +47,7 @@ def test_scan_mit_sot(view):
xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.full((3,), np.e)]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")


@pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)])
Expand All @@ -74,7 +74,7 @@ def step(xtm3, xtm1, ytm4, ytm2):

fg = FunctionGraph([x0, y0], [xs, ys])
test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")


@pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)])
Expand Down Expand Up @@ -283,7 +283,7 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
gamma_val,
delta_val,
]
compare_jax_and_py(out_fg, test_input_vals)
compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")


def test_scan_mitsot_with_nonseq():
Expand Down Expand Up @@ -316,7 +316,7 @@ def input_step_fn(y_tm1, y_tm3, a):
out_fg = FunctionGraph([a_pt], [y_scan_pt])

test_input_vals = [np.array(10.0).astype(config.floatX)]
compare_jax_and_py(out_fg, test_input_vals)
compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")


@pytest.mark.parametrize("x0_func", [dvector, dmatrix])
Expand All @@ -334,7 +334,6 @@ def test_nd_scan_sit_sot(x0_func, A_func):
non_sequences=[A],
outputs_info=[x0],
n_steps=n_steps,
mode=get_mode("JAX"),
)

x0_val = (
Expand All @@ -346,7 +345,7 @@ def test_nd_scan_sit_sot(x0_func, A_func):

fg = FunctionGraph([x0, A], [xs])
test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")


def test_nd_scan_sit_sot_with_seq():
Expand All @@ -362,15 +361,14 @@ def test_nd_scan_sit_sot_with_seq():
non_sequences=[A],
sequences=[x],
n_steps=n_steps,
mode=get_mode("JAX"),
)

x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
A_val = np.eye(k, dtype=config.floatX)

fg = FunctionGraph([x, A], [xs])
test_input_vals = [x_val, A_val]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")


def test_nd_scan_mit_sot():
Expand All @@ -384,7 +382,6 @@ def test_nd_scan_mit_sot():
outputs_info=[{"initial": x0, "taps": [-3, -1]}],
non_sequences=[A, B],
n_steps=10,
mode=get_mode("JAX"),
)

fg = FunctionGraph([x0, A, B], [xs])
Expand All @@ -393,7 +390,7 @@ def test_nd_scan_mit_sot():
B_val = np.eye(3, dtype=config.floatX)

test_input_vals = [x0_val, A_val, B_val]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")


def test_nd_scan_sit_sot_with_carry():
Expand All @@ -417,7 +414,7 @@ def step(x, A):
A_val = np.eye(3, dtype=config.floatX)

test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")


def test_default_mode_excludes_incompatible_rewrites():
Expand All @@ -426,7 +423,7 @@ def test_default_mode_excludes_incompatible_rewrites():
B = matrix("B")
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
fg = FunctionGraph([A, B], [out])
compare_jax_and_py(fg, [np.eye(3), np.eye(3)])
compare_jax_and_py(fg, [np.eye(3), np.eye(3)], jax_mode="JAX")


def test_dynamic_sequence_length():
Expand Down
2 changes: 1 addition & 1 deletion tests/link/jax/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):

dot_pt = op(x_pt, y_pt)
fgraph = FunctionGraph(inputs, [dot_pt])
compare_jax_and_py(fgraph, test_values)
compare_jax_and_py(fgraph, test_values, jax_mode="JAX")


def test_sparse_dot_non_const_raises():
Expand Down
2 changes: 1 addition & 1 deletion tests/link/jax/test_tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_arange_of_shape():
x = vector("x")
out = ptb.arange(1, x.shape[-1], 2)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [np.zeros((5,))])
compare_jax_and_py(fgraph, [np.zeros((5,))], jax_mode="JAX")


def test_arange_nonconcrete():
Expand Down
Loading
Loading