Skip to content

Commit

Permalink
changed default for float32 to float64 with jax
Browse files Browse the repository at this point in the history
  • Loading branch information
kenjim21 committed Oct 11, 2024
1 parent f1e36cc commit 165bfb0
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions katsu/katsu_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ def set_backend_to_cupy():

def set_backend_to_jax():
"""Convenience method to automatically configure katsu's backend to jax."""
import jax.numpy as jnp
import jax as jax

np._srcmodule = jnp
jax.config.update("jax_enable_x64", True)
np._srcmodule = jax.numpy

return

Expand Down

0 comments on commit 165bfb0

Please sign in to comment.