diff --git a/tests/test_katsu_math.py b/tests/test_katsu_math.py index a22361a..09b4bdd 100644 --- a/tests/test_katsu_math.py +++ b/tests/test_katsu_math.py @@ -1,4 +1,3 @@ -import jax.numpy as jnp import pytest from katsu.katsu_math import ( BackendShim, @@ -18,7 +17,7 @@ cupy_installed = False try: - import jax.numpy + import jax.numpy as jnp jax_installed = True except Exception: