
Pure Callbacks do not support JVP. Is there a workaround? #23082

Answered by jakevdp
kgierach-hf asked this question in Q&A

There are two problems here, I think:

(1) You wrap cwh in custom_jvp, but you never define its JVP rule via cwh.defjvp. Fortunately you define the JVP elsewhere, so you can delete cwh and just use compute_weights_host directly.
(2) You use compute_weights_wrapper in a context where it is automatically differentiated, but it calls jax.pure_callback, which is not differentiable. You'll need to define a custom_jvp rule for compute_weights_wrapper so that JAX doesn't try to automatically differentiate through the pure_callback.
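Here is a minimal sketch of fix (2). It assumes compute_weights_host is a plain NumPy function. The original code isn't shown, so the tanh body below is a hypothetical stand-in, and the JVP rule uses its analytic derivative. Following fix (1), the wrapper calls compute_weights_host directly instead of going through cwh.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-in for compute_weights_host: plain NumPy code that
# JAX can neither trace nor differentiate.
def compute_weights_host(x):
    x = np.asarray(x)
    return np.tanh(x).astype(x.dtype)

@jax.custom_jvp
def compute_weights_wrapper(x):
    # pure_callback needs the shape/dtype of the result up front.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(compute_weights_host, result_shape, x)

@compute_weights_wrapper.defjvp
def compute_weights_wrapper_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = compute_weights_wrapper(x)  # primal value, computed via the callback
    # Assumed derivative of the stand-in: d/dx tanh(x) = 1 - tanh(x)**2,
    # written with traceable ops so JAX never differentiates the callback.
    y_dot = (1.0 - y**2) * x_dot
    return y, y_dot

x = jnp.linspace(-1.0, 1.0, 5)
g = jax.grad(lambda v: compute_weights_wrapper(v).sum())(x)
```

Because the tangent output is linear in x_dot and y depends only on primals, reverse mode (jax.grad) works as well: JAX runs the callback in the forward pass and only has to transpose the multiplication.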

You might find this doc helpful; it contains a worked example of a pure callback that uses custom_jvp to define autodiff rules: https://jax.readthedocs.io/en/latest/notebooks/external_callba…
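If the derivative is also only available as host code, the JVP rule can call a second pure_callback for it. The sketch below works under that assumption. host_f and host_df are hypothetical, with sin and cos as placeholders.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical host functions: value and derivative are both NumPy-only.
def host_f(x):
    return np.sin(np.asarray(x))

def host_df(x):
    return np.cos(np.asarray(x))

def _callback(fn, x):
    # Result has the same shape and dtype as the input.
    return jax.pure_callback(fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@jax.custom_jvp
def f(x):
    return _callback(host_f, x)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # The derivative is also evaluated on the host; the tangent output is
    # linear in x_dot, so jax.grad can transpose it.
    return f(x), _callback(host_df, x) * x_dot

x = jnp.linspace(0.0, 1.0, 4)
grads = jax.jit(jax.grad(lambda v: f(v).sum()))(x)  # approximately cos(x)
```

Only first derivatives work with this sketch. Differentiating the rule itself would reach the inner pure_callback for host_df, which has no rule of its own.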

Replies: 1 comment, 4 replies (@kgierach-hf, @kgierach-hf, @jakevdp, @kgierach-hf). Answer selected by kgierach-hf.
Category: Q&A · Labels: none yet · 2 participants