The following code
```python
!pip install equinox

import jax
from jax import numpy as jnp
import equinox as eqx

print(f"{eqx.filter_hessian(jax.nn.relu)(jnp.ones(()))=}")
print(f"{jax.hessian(jax.nn.relu)(jnp.ones(()))=}")
```
results in
```
eqx.filter_hessian(jax.nn.relu)(jnp.ones(()))=Array(1., dtype=float32)
jax.hessian(jax.nn.relu)(jnp.ones(()))=Array(0., dtype=float32)
```
Here, the result of `jax.hessian` is correct but that of `eqx.filter_hessian` is not: ReLU is linear for x > 0, so its second derivative at x = 1 is 0.
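As a cross-check on which value is right, the minimal sketch below (assuming only `jax` is installed; it does not use Equinox) computes the second derivative of `jax.nn.relu` at the same input in several independent ways. All of them should agree with `jax.hessian`. The 1.0 returned by `eqx.filter_hessian` in the report happens to equal the *first* derivative at that point, which the sketch also prints for comparison.

```python
import jax
import jax.numpy as jnp

# ReLU is piecewise linear: relu(x) = max(x, 0).
# For x > 0 its first derivative is 1 and its second derivative is 0.
x = jnp.ones(())  # same scalar input as in the report

first = jax.grad(jax.nn.relu)(x)                     # relu'(1)
second_grad = jax.grad(jax.grad(jax.nn.relu))(x)     # relu''(1), reverse-over-reverse
second_fwd = jax.jacfwd(jax.jacfwd(jax.nn.relu))(x)  # relu''(1), forward-over-forward
second_hess = jax.hessian(jax.nn.relu)(x)            # relu''(1) via jax.hessian

print("relu'(1)  =", first)
print("relu''(1) =", second_grad, second_fwd, second_hess)
```

If any second-derivative route disagrees with the others, that points to a bug in that differentiation path rather than in the function itself.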
I tested this code in Google Colab with Equinox 0.11.4, JAX 0.4.26, and Python 3.10.12.
Looks like this is fixed on `main`: I see 0.0 as the output for both JAX and Equinox there. I believe this was an error similar to the `jacfwd` one in #734.