diff --git a/Project.toml b/Project.toml index 8eb30c966..6d1292c7e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.41" +version = "0.5.42" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/README.md b/README.md index 2cb4bca6e..9cf60629b 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,8 @@ The 🔥 Deep Learning Framework ## Installation ```julia -] add Lux +import Pkg +Pkg.add("Lux") ``` ## Getting Started diff --git a/docs/Project.toml b/docs/Project.toml index edecd9885..55cbc2e11 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index 0c8aa7dc5..9646e0c0f 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -60,6 +60,13 @@ StatefulLuxLayer @compact ``` +## AutoDiff Helpers + +```@docs +jacobian_vector_product +vector_jacobian_product +``` + ## Truncated Stacktraces ```@docs diff --git a/docs/src/manual/nested_autodiff.md b/docs/src/manual/nested_autodiff.md index 3b048fc15..e1906eb7e 100644 --- a/docs/src/manual/nested_autodiff.md +++ b/docs/src/manual/nested_autodiff.md @@ -22,7 +22,7 @@ Let's explore this using some questions that were posted on the [Julia Discourse forum](https://discourse.julialang.org/). ```@example nested_ad -using Lux, LinearAlgebra, Zygote, ForwardDiff, Random +using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random using ComponentArrays, FiniteDiff ``` @@ -42,17 +42,19 @@ work: - Currently we have custom routines implemented for: - `Zygote.` - `ForwardDiff.` + - [`vector_jacobian_product`](@ref) + - [`jacobian_vector_product`](@ref) - Switching only happens for `ChainRules` compatible AD libraries. -We plan to capture `DifferentiationInterface`, `Zygote.pullback`, and `Enzyme.autodiff` -calls in the future (PRs are welcome). +We plan to capture `DifferentiationInterface`, and `Enzyme.autodiff` calls in the +future (PRs are welcome). !!! tip [`@compact`](@ref) uses [`StatefulLuxLayer`](@ref)s internally, so you can directly use these features inside a layer generated by [`@compact`](@ref). -## Nested AD for Neural Differential Equations (DEs) +## Loss Function containing Jacobian Computation This problem comes from `@facusapienza` on [Discourse](https://discourse.julialang.org/t/nested-and-different-ad-methods-altogether-how-to-add-ad-calculations-inside-my-loss-function-when-using-neural-differential-equations/108985). In this case, we want to add a regularization term to the neural DE based on first-order @@ -103,7 +105,7 @@ nothing; # hide That's pretty good, of course you will have some error from the finite differences calculation. -## Loss Function contains Gradient Calculation +## Loss Function contains Gradient Computation Ok here I am going to cheat a bit. This comes from a discussion on nested AD for PINNs on [Discourse](https://discourse.julialang.org/t/is-it-possible-to-do-nested-ad-elegantly-in-julia-pinns/98888/21). 
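For orientation (the body of this section is unchanged by the patch and therefore not shown in the diff), here is a minimal sketch of the pattern the section discusses: an outer `Zygote.gradient` over a loss function that itself calls `Zygote.gradient` through a `StatefulLuxLayer`. The toy model, shapes, and names below are illustrative assumptions, not code from this patch.

```julia
# Hypothetical toy setup showing the nested-AD pattern discussed in this section.
using Lux, Random, Zygote, ForwardDiff

model = Chain(Dense(1 => 8, tanh), Dense(8 => 1))
ps, st = Lux.setup(Xoshiro(0), model)
x = rand(Xoshiro(0), Float32, 1, 16)

function loss_with_inner_gradient(model, x, ps, st)
    smodel = StatefulLuxLayer(model, ps, st)
    # Inner gradient with respect to the input (the PINN-style term)
    ∂u_∂x = only(Zygote.gradient(sum ∘ smodel, x))
    return sum(abs2, ∂u_∂x)
end

# Outer gradient with respect to the parameters
∂ps = only(Zygote.gradient(ps -> loss_with_inner_gradient(model, x, ps, st), ps))
```

When `ForwardDiff.jl` is loaded, the inner gradient call is differentiated with forward-over-reverse AD; otherwise Lux warns and falls back to reverse-over-reverse.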
@@ -184,3 +186,120 @@ println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, @assert norm(ComponentArray(∂ps) .- ∂ps_fd, Inf) < 1e-3 # hide nothing; # hide ``` + +## Hutchinson Trace Estimation + +Hutchinson Trace Estimation often shows up in machine learning literature to provide a fast +estimate of the trace of a Jacobian Matrix. This is based on +[Hutchinson 1990](https://www.researchgate.net/publication/243668757_A_Stochastic_Estimator_of_the_Trace_of_the_Influence_Matrix_for_Laplacian_Smoothing_Splines), which +computes the estimated trace of a matrix ``A \in \mathbb{R}^{D \times D}`` using random +vectors ``v \in \mathbb{R}^{D}`` such that ``\mathbb{E}\left[v v^T\right] = I``. + +```math +\text{Tr}(A) = \mathbb{E}\left[v^T A v\right] \approx \frac{1}{V} \sum_{i = 1}^V v_i^T A v_i +``` + +We can use this to estimate the trace of a Jacobian Matrix ``J \in \mathbb{R}^{D \times D}`` +using the following estimator: + +```math +\text{Tr}(J) \approx \frac{1}{V} \sum_{i = 1}^V v_i^T J v_i +``` + +Note that we can compute this using two methods: + +1. Compute ``v_i^T J`` using a Vector-Jacobian product and then take its dot product with + ``v_i`` to get the trace estimate. +2. Compute ``J v_i`` using a Jacobian-Vector product and then take its dot product with ``v_i`` + to get the trace estimate. + +For simplicity, we will use a single sample of ``v_i`` to compute the trace. Additionally, +we will fix the sample to ensure that our tests against the finite difference implementation +are not affected by the randomness in the sample. + +### Computing using the Vector-Jacobian Product + +```@example nested_ad +function hutchinson_trace_vjp(model, x, ps, st, v) + smodel = StatefulLuxLayer(model, ps, st) + vjp = vector_jacobian_product(smodel, AutoZygote(), x, v) + # ⊠ is the shorthand for `NNlib.batched_mul` + return sum(reshape(vjp, 1, :, size(vjp, ndims(vjp))) ⊠ + reshape(v, :, 1, size(v, ndims(v)))) +end +``` + +This VJP version is generally the fastest and most scalable, and hence is the recommended way +to compute the Hutchinson trace. + +### Computing using the Jacobian-Vector Product + +```@example nested_ad +function hutchinson_trace_jvp(model, x, ps, st, v) + smodel = StatefulLuxLayer(model, ps, st) + jvp = jacobian_vector_product(smodel, AutoForwardDiff(), x, v) + # ⊠ is the shorthand for `NNlib.batched_mul` + return sum(reshape(v, 1, :, size(v, ndims(v))) ⊠ + reshape(jvp, :, 1, size(jvp, ndims(jvp)))) +end +``` + +### Computing using the Full Jacobian + +This is definitely not recommended, but we are showing it for completeness: materializing the +full ``D \times D`` Jacobian requires on the order of ``D`` extra passes through the model and +``\mathcal{O}(D^2)`` memory, whereas the VJP and JVP versions above need only a single pass.
+ +```@example nested_ad +function hutchinson_trace_full_jacobian(model, x, ps, st, v) + smodel = StatefulLuxLayer(model, ps, st) + J = ForwardDiff.jacobian(smodel, x) + return vec(v)' * J * vec(v) +end +``` + +Now let's compute the trace and compare the results: + +```@example nested_ad +model = Chain(Dense(4 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh), + Dense(12 => 4)) +ps, st = Lux.setup(Xoshiro(0), model) +x = rand(Xoshiro(0), Float32, 4, 12) +v = (rand(Xoshiro(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0 # rademacher sample +nothing; # hide +``` + +```@example nested_ad +tr_vjp = hutchinson_trace_vjp(model, x, ps, st, v) +tr_jvp = hutchinson_trace_jvp(model, x, ps, st, v) +tr_full_jacobian = hutchinson_trace_full_jacobian(model, x, ps, st, v) +println("Tr(J) using vjp: ", tr_vjp) +println("Tr(J) using jvp: ", tr_jvp) +println("Tr(J) using full jacobian: ", tr_full_jacobian) +@assert tr_vjp ≈ tr_jvp ≈ tr_full_jacobian # hide +nothing; # hide +``` + +Now that we have verified that the results are the same, let's try to differentiate the +trace estimate. This often shows up as a regularization term in neural networks. + +```@example nested_ad +_, ∂x_vjp, ∂ps_vjp, _, _ = Zygote.gradient(hutchinson_trace_vjp, model, x, ps, st, v) +_, ∂x_jvp, ∂ps_jvp, _, _ = Zygote.gradient(hutchinson_trace_jvp, model, x, ps, st, v) +_, ∂x_full_jacobian, ∂ps_full_jacobian, _, _ = Zygote.gradient(hutchinson_trace_full_jacobian, + model, x, ps, st, v) +nothing; # hide +``` + +For sanity check, let's verify that the gradients are the same: + +```@example nested_ad +println("∂x using vjp: ", norm(∂x_vjp .- ∂x_jvp, Inf)) +println("∂ps using vjp: ", norm(ComponentArray(∂ps_vjp) .- ComponentArray(∂ps_jvp), Inf)) +println("∂x using full jacobian: ", norm(∂x_full_jacobian .- ∂x_vjp, Inf)) +println("∂ps using full jacobian: ", + norm(ComponentArray(∂ps_full_jacobian) .- ComponentArray(∂ps_vjp), Inf)) +@assert norm(∂x_vjp .- ∂x_jvp, Inf) < 1e-3 # hide +@assert norm(ComponentArray(∂ps_vjp) .- ComponentArray(∂ps_jvp), Inf) < 1e-3 # hide +@assert norm(∂x_full_jacobian .- ∂x_vjp, Inf) < 1e-3 # hide +@assert norm(ComponentArray(∂ps_full_jacobian) .- ComponentArray(∂ps_vjp), Inf) < 1e-3 # hide +nothing; # hide +``` diff --git a/ext/LuxForwardDiffExt.jl b/ext/LuxForwardDiffExt.jl index 9e79c947e..3ca1b2332 100644 --- a/ext/LuxForwardDiffExt.jl +++ b/ext/LuxForwardDiffExt.jl @@ -1,5 +1,6 @@ module LuxForwardDiffExt +using ADTypes: AutoForwardDiff using ChainRulesCore: ChainRulesCore using Lux: Lux using FastClosures: @closure @@ -10,7 +11,9 @@ const CRC = ChainRulesCore @inline Lux._is_extension_loaded(::Val{:ForwardDiff}) = true +# Low-Level functions @inline function Lux.__partials(::Type{Tag}, x, i) where {Tag} + x isa ForwardDiff.Dual && return ForwardDiff.partials(Tag, x, i) x isa AbstractArray && return ForwardDiff.partials.(Tag, x, i) map_fn = @closure(xᵢ->Lux.__partials(Tag, xᵢ, i)) x isa Tuple && return map(map_fn, x) @@ -20,15 +23,63 @@ const CRC = ChainRulesCore return fmap(map_fn, x) end +@inline function Lux.__dualify(::Type{Tag}, ::Type{T}, x, u) where {Tag, T} + if x isa AbstractArray + return ForwardDiff.Dual{ + Tag, T, 1}.(x, ForwardDiff.Partials{1, T}.(tuple.(reshape(u, size(x))))) + end + x isa Tuple && return map((xᵢ, uᵢ) -> Lux.__dualify(Tag, T, xᵢ, uᵢ), x, u) + x isa NamedTuple && + return NamedTuple{keys(x)}(map((xᵢ, uᵢ) -> Lux.__dualify(Tag, T, xᵢ, uᵢ), x, u)) + return fmap((xᵢ, uᵢ) -> Lux.__dualify(Tag, T, xᵢ, uᵢ), x, u) +end + # This is not a general jvp code, but rather meant to be 
efficient for nested AD calls -function Lux.__forwarddiff_jvp( - f::F, x::AbstractArray{xT}, Δx::AbstractArray{ΔxT}, ps) where {F, xT, ΔxT} - T = promote_type(xT, ΔxT) +function Lux.__forwarddiff_jvp(f::F, x, Δx, y) where {F} + T = promote_type(Lux.__recursive_eltype(x), Lux.__recursive_eltype(Δx)) Tag = typeof(ForwardDiff.Tag(f, T)) - partials = ForwardDiff.Partials{1, T}.(tuple.(Δx)) - x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x))) - y_dual, ps_dual = f(x_dual, ps) - return Lux.__partials(Tag, y_dual, 1), Lux.__partials(Tag, ps_dual, 1) + res1_dual, res2_dual = f(Lux.__dualify(Tag, T, x, Δx), y) + return (Lux.__partials(Tag, res1_dual, 1), Lux.__partials(Tag, res2_dual, 1)) +end + +# jvp +function Lux.__jacobian_vector_product_impl(f::F, ::AutoForwardDiff, x, u) where {F} + T = promote_type(Lux.__recursive_eltype(x), Lux.__recursive_eltype(u)) + Tag = typeof(ForwardDiff.Tag(f, T)) + y_dual = f(Lux.__dualify(Tag, T, x, u)) + return Lux.__partials(Tag, y_dual, 1) +end + +function __jacobian_vector_product_ad_impl(f::F, x, u, y) where {F} + return Lux.__jacobian_vector_product_impl(Base.Fix2(f, y), AutoForwardDiff(), x, u) +end + +for fType in Lux.AD_CONVERTIBLE_FUNCTIONS + @eval @inline function Lux.__jacobian_vector_product_impl( + f::$(fType), ::AutoForwardDiff, x, u) + f_internal, y = Lux.__rewrite_ad_call(f) + return __jacobian_vector_product_ad_impl(f_internal, x, u, y) + end +end + +function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, + ::typeof(__jacobian_vector_product_ad_impl), f::F, x, u, y) where {F} + res = __jacobian_vector_product_ad_impl(f, x, u, y) + + pullback_fn = (f_internal, x, args...) -> begin + res, ∂f = CRC.rrule_via_ad(cfg, f_internal, x, args...) + ∂f_internal(Δ) = ∂f(Δ)[2:end] + return res, ∂f_internal + end + + ∇internal_nested_pushforward_capture = Δ -> begin + _, pb_f = CRC.rrule_via_ad( + cfg, Lux.__internal_ad_pullback_call, pullback_fn, f, x, y, Δ) + _, _, _, ∂x, ∂y, _ = pb_f(u) + return CRC.NoTangent(), CRC.NoTangent(), ∂x, CRC.NoTangent(), ∂y + end + + return res, ∇internal_nested_pushforward_capture end # Capture ForwardDiff.jacobian call and replace it with forward over reverse mode AD @@ -49,8 +100,8 @@ for fType in Lux.AD_CONVERTIBLE_FUNCTIONS, type in (:Gradient, :Jacobian) @inline function ForwardDiff.$(fname)(f::$fType, x::AbstractArray, cfg::ForwardDiff.$(cfgname)=ForwardDiff.$(cfgname)(f, x), chk::Val=Val(true)) - f_internal, ps = Lux.__rewrite_ad_call(f) - return $(internal_fname)(f_internal, cfg, chk, x, ps) + f_internal, y = Lux.__rewrite_ad_call(f) + return $(internal_fname)(f_internal, cfg, chk, x, y) end end end diff --git a/ext/LuxZygoteExt.jl b/ext/LuxZygoteExt.jl index 516bf5e38..535b88219 100644 --- a/ext/LuxZygoteExt.jl +++ b/ext/LuxZygoteExt.jl @@ -8,6 +8,8 @@ using Zygote: Zygote const CRC = ChainRulesCore +Lux._is_extension_loaded(::Val{:Zygote}) = true + function Lux.Experimental.compute_gradients(::AutoZygote, objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} (loss, st, stats), back = Zygote.pullback( @@ -17,18 +19,28 @@ function Lux.Experimental.compute_gradients(::AutoZygote, objective_function::F, return grads, loss, stats, ts end +function Lux.__vector_jacobian_product_impl(f::F, ::AutoZygote, x, u) where {F} + _, pb_f = Zygote.pullback(f, x) + return only(pb_f(u)) +end + # Nested AD Handling for fType in Lux.AD_CONVERTIBLE_FUNCTIONS @eval begin @inline function Zygote.gradient(f::$fType, x) - f_internal, ps = Lux.__rewrite_ad_call(f) - return 
Lux.__internal_ad_gradient_call(Zygote.gradient, f_internal, x, ps) + f_internal, y = Lux.__rewrite_ad_call(f) + return Lux.__internal_ad_gradient_call(Zygote.gradient, f_internal, x, y) end @inline function Zygote.jacobian(f::$fType, x::AbstractArray) - f_internal, ps = Lux.__rewrite_ad_call(f) + f_internal, y = Lux.__rewrite_ad_call(f) return Lux.__internal_ad_jacobian_call( - Zygote.jacobian, Zygote.gradient, f_internal, x, ps) + Zygote.jacobian, Zygote.gradient, f_internal, x, y) + end + + @inline function Lux.__vector_jacobian_product_impl(f::$fType, ::AutoZygote, x, u) + f_internal, y = Lux.__rewrite_ad_call(f) + return Lux.__internal_ad_pullback_call(Zygote.pullback, f_internal, x, y, u) end end end diff --git a/src/Lux.jl b/src/Lux.jl index 95e79eed7..8feffacae 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -3,6 +3,7 @@ module Lux using PrecompileTools: @recompile_invalidations @recompile_invalidations begin + using ADTypes: AbstractADType, AutoForwardDiff, AutoZygote using Adapt: Adapt, adapt using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore, AbstractZero, HasReverseMode, NoTangent, @@ -69,6 +70,7 @@ include("contrib/contrib.jl") # Helpful Functionalities include("helpers/stateful.jl") include("helpers/compact.jl") +include("helpers/autodiff.jl") include("helpers/nested_ad.jl") # Transform to and from other frameworks @@ -99,6 +101,7 @@ export SamePad, TimeLastIndex, BatchLastIndex export StatefulLuxLayer export @compact, CompactLuxLayer +export jacobian_vector_product, vector_jacobian_product export f16, f32, f64 diff --git a/src/helpers/autodiff.jl b/src/helpers/autodiff.jl new file mode 100644 index 000000000..98ee35a5d --- /dev/null +++ b/src/helpers/autodiff.jl @@ -0,0 +1,75 @@ +@doc doc""" + vector_jacobian_product(f, backend::AbstractADType, x, u) + +Compute the Vector Jacobian Product ``\left(\frac{\partial f}{\partial x}\right)^T u``. +This is a wrapper around AD backends that additionally allows computing gradients of +vector-jacobian products efficiently using mixed-mode AD. + +The following backends are supported: + + - `AutoZygote`: `Zygote.jl` must be loaded. + +!!! warning + + The gradient with respect to `u` in the reverse pass is always dropped. + +## Arguments + + - `f`: The function to compute the jacobian of. + - `backend`: The backend to use for computing the VJP. + - `x`: The input to the function. + - `u`: An object of the same structure as `f(x)`. + +## Returns + + - `v`: The Vector Jacobian Product. +""" +function vector_jacobian_product(f::F, backend::AbstractADType, x, u) where {F} + @assert backend isa AutoZygote "Only `AutoZygote` is supported for \ + `vector_jacobian_product`." + if !_is_extension_loaded(Val(:Zygote)) + error("`Zygote.jl` must be loaded for `vector_jacobian_product` \ + to work with `$(backend)`.") + end + return __vector_jacobian_product_impl(f, backend, x, u) +end + +function __vector_jacobian_product_impl end + +@doc doc""" + jacobian_vector_product(f, backend::AbstractADType, x, u) + +Compute the Jacobian Vector Product ``\left(\frac{\partial f}{\partial x}\right) u``. +This is a wrapper around AD backends that additionally allows computing gradients of +jacobian-vector products efficiently using mixed-mode AD. + +The following backends are supported: + + - `AutoForwardDiff`: `ForwardDiff.jl` must be loaded. + +!!! warning + + The gradient with respect to `u` in the reverse pass is always dropped. + +## Arguments + + - `f`: The function to compute the jacobian of. + - `backend`: The backend to use for computing the JVP. 
+ - `x`: The input to the function. + - `u`: An object of the same structure as `x`. + +## Returns + + - `v`: The Jacobian Vector Product. +""" +function jacobian_vector_product(f::F, backend::AbstractADType, x, u) where {F} + @assert backend isa AutoForwardDiff "Only `AutoForwardDiff` is supported for \ + `jacobian_vector_product`." + if !_is_extension_loaded(Val(:ForwardDiff)) + error("`ForwardDiff.jl` must be loaded for `jacobian_vector_product` \ + to work with `$(backend)`.") + end + return __jacobian_vector_product_impl(f, backend, x, u) +end + +function __jacobian_vector_product_impl end diff --git a/src/helpers/nested_ad.jl b/src/helpers/nested_ad.jl index 6622adc9b..65a8d2610 100644 --- a/src/helpers/nested_ad.jl +++ b/src/helpers/nested_ad.jl @@ -1,6 +1,7 @@ function __forwarddiff_jvp end # Defined in ForwardDiff.jl extension function __partials end # DON'T REMOVE THIS (DEQs.jl is using it) +function __dualify end #! format: off const AD_CONVERTIBLE_FUNCTIONS = [ @@ -32,15 +33,14 @@ const AD_CONVERTIBLE_FUNCTIONS = [ error("Unknown function type: $(typeof(f))") end -# Essentially computes the gradient of `f(x, y)` wrt x using the function `grad_fn` -# To compute the gradient of `f(x, y)` wrt y, just reorder the arguments with a wrapper -# over `f` -@inline function __internal_ad_gradient_call(grad_fn::G, f::F, x, y) where {G, F} - return grad_fn(Base.Fix2(f, y), x) -end -@inline function __internal_ad_gradient_call_no_custom_rrule( - grad_fn::G, f::F, x, y) where {G, F} - return grad_fn(Base.Fix2(f, y), x) # Don' call `__internal_ad_gradient_call` +# Nested Gradients +## Essentially computes the gradient of `f(x, y)` wrt x using the function `grad_fn` +## To compute the gradient of `f(x, y)` wrt y, just reorder the arguments with a wrapper +## over `f` +for fname in (:__internal_ad_gradient_call, :__internal_ad_gradient_call_no_custom_rrule) + @eval @inline function $fname(grad_fn::G, f::F, x, y) where {G, F} + return grad_fn(Base.Fix2(f, y), x) + end end function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, @@ -48,7 +48,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, # Check if we can use the faster implementation if !Lux._is_extension_loaded(Val(:ForwardDiff)) || DISABLE_AUTOMATIC_NESTED_AD_SWITCH if !DISABLE_AUTOMATIC_NESTED_AD_SWITCH - @warn "Load ForwardDiff.jl for better nested AD handling." maxlog=1 + @warn "Load `ForwardDiff.jl` for better nested AD handling." 
maxlog=1 end # Use the AD itself for whatever reason return CRC.rrule_via_ad( @@ -60,8 +60,8 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, (Δ_ isa CRC.NoTangent || Δ_ isa CRC.ZeroTangent) && return ntuple(Returns(CRC.NoTangent()), 5) - Δ = CRC.backing(CRC.unthunk(Δ_)) - Δ isa Tuple && (Δ = only(Δ)) # For Zygote and such which return a tuple + Δ = CRC.unthunk(Δ_) + (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) # For Zygote and such which return a tuple ∂x, ∂y = __forwarddiff_jvp(@closure((x, y)->grad_fn(f, x, y)), x, Δ, y) return CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), ∂x, ∂y end @@ -69,14 +69,58 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, return res, ∇internal_gradient_capture end -# `grad_fn` is not needed for the forward pass, we need it for the reverse pass HVP -function __internal_ad_jacobian_call( - jac_fn::J, grad_fn::G, f::F, x::AbstractArray, y) where {J, G, F} - return jac_fn(Base.Fix2(f, y), x) +# Nested Pullbacks +for fname in (:__internal_ad_pullback_call, :__internal_ad_pullback_call_no_custom_rrule) + @eval @inline function $fname(pullback_fn::P, f::F, x, y, u) where {P, F} + return only(last(pullback_fn(Base.Fix2(f, y), x))(u)) + end end -@inline function __internal_ad_jacobian_call_no_custom_rrule( - jac_fn::J, grad_fn::G, f::F, x::AbstractArray, y) where {J, G, F} - return jac_fn(Base.Fix2(f, y), x) # Don' call `__internal_ad_jacobian_call` + +function CRC.rrule( + cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__internal_ad_pullback_call), + pullback_fn::P, f::F, x, y, u) where {P, F} + # Check if we can use the faster implementation + if !Lux._is_extension_loaded(Val(:ForwardDiff)) || DISABLE_AUTOMATIC_NESTED_AD_SWITCH + if !DISABLE_AUTOMATIC_NESTED_AD_SWITCH + @warn "Load `ForwardDiff.jl` for better nested AD handling." maxlog=1 + end + # Use the AD itself for whatever reason + return CRC.rrule_via_ad( + cfg, __internal_ad_pullback_call_no_custom_rrule, pullback_fn, f, x, y, u) + end + + res = __internal_ad_pullback_call(pullback_fn, f, x, y, u) + ∇internal_pullback_capture = let pullback_fn = pullback_fn, + f = f, + x = x, + y = y, + u = u, + res = res + + Δ_ -> begin + (Δ_ isa CRC.NoTangent || Δ_ isa CRC.ZeroTangent) && + return ntuple(Returns(CRC.NoTangent()), 6) + + Δ = CRC.unthunk(Δ_) + (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) # For Zygote and such which return a tuple + ∂x, ∂y = __forwarddiff_jvp(x, Δ, y) do x_dual, y_ + return last(pullback_fn(f, x_dual, y_))(u) + end + return ( + CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), ∂x, ∂y, CRC.NoTangent()) + end + end + + return res, ∇internal_pullback_capture +end + +# Nested Jacobians +## `grad_fn` is not needed for the forward pass, we need it for the reverse pass HVP +for fname in (:__internal_ad_jacobian_call, :__internal_ad_jacobian_call_no_custom_rrule) + @eval @inline function $fname( + jac_fn::J, grad_fn::G, f::F, x::AbstractArray, y) where {J, G, F} + return jac_fn(Base.Fix2(f, y), x) + end end function CRC.rrule( @@ -85,7 +129,7 @@ function CRC.rrule( # Check if we can use the faster implementation if !Lux._is_extension_loaded(Val(:ForwardDiff)) || DISABLE_AUTOMATIC_NESTED_AD_SWITCH if !DISABLE_AUTOMATIC_NESTED_AD_SWITCH - @warn "Load ForwardDiff.jl for better nested AD handling." maxlog=1 + @warn "Load `ForwardDiff.jl` for better nested AD handling." 
maxlog=1 end # Use the AD itself for whatever reason return CRC.rrule_via_ad( @@ -98,8 +142,8 @@ function CRC.rrule( (Δ_ isa CRC.NoTangent || Δ_ isa CRC.ZeroTangent) && return ntuple(Returns(CRC.NoTangent()), 6) - Δ = CRC.backing(CRC.unthunk(Δ_)) - Δ isa Tuple && (Δ = only(Δ)) # For Zygote and such which return a tuple + Δ = CRC.unthunk(Δ_) + (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) # For Zygote and such which return a tuple Δ = __compactify_if_structured_matrix(res isa Tuple ? only(res) : res, Δ) # TODO: Here we can potentially chunk the gradients for faster AD calls diff --git a/src/helpers/stateful.jl b/src/helpers/stateful.jl index dcd9f1655..e27dde2da 100644 --- a/src/helpers/stateful.jl +++ b/src/helpers/stateful.jl @@ -11,26 +11,12 @@ This is meant to be used in internal implementation of layers. ## Usecases - Internal implementation of [`@compact`](@ref) heavily uses this layer. - - In SciML codebases where propagating state might involve [`Box`ing](https://github.com/JuliaLang/julia/issues/15276). For a motivating example, see the Neural ODE tutorial. - - This layer automatically converts `Zygote.gradient(op ∘ model::StatefulLuxLayer, x)` to - a `ForwardDiff.jl` jacobian-vector product over `Zygote.gradient` call. In future, we - will overload `DifferentiationInterface.gradient` and - `DifferentiationInterface.jacobian` calls as well. For this feature to be available, - `ForwardDiff.jl` must be loaded. Additionally this feature is exclusively available - for AD backends supporting ChainRules, so ReverseDiff and Tracker won't make this - automatic conversion. For more details on this feature, see the + - Facilitates Nested AD support in Lux. For more details on this feature, see the [Nested AD Manual Page](@ref nested_autodiff). -!!! tip - - Automatic Nested AD Switching behavior can be disabled by setting the preference - `DisableAutomaticNestedADSwitching` to `true`. See documentation of - [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl) and - [PreferenceTools.jl](https://github.com/cjdoris/PreferenceTools.jl) on how to do this. - ## Arguments - `model`: A Lux layer diff --git a/src/utils.jl b/src/utils.jl index 2b2926f2f..88c3b0bec 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -259,3 +259,19 @@ __named_tuple(nt::NamedTuple) = nt @inline _vec(x::AbstractArray) = vec(x) @inline _vec(::Nothing) = nothing + +# recursive_eltype +@inline __recursive_eltype(x::AbstractArray) = eltype(x) +@inline __recursive_eltype(x::Tuple) = promote_type(__recursive_eltype.(x)...) +@inline __recursive_eltype(x::NamedTuple) = promote_type(__recursive_eltype.(values(x))...) 
+@inline __recursive_eltype(::Nothing) = Bool +@inline __recursive_eltype(x::Number) = eltype(x) +@inline function __recursive_eltype(x) + _eltype = Ref(Bool) + function __internal_recursive_eltype(x) + _eltype[] = promote_type(_eltype[], __recursive_eltype(x)) + return x + end + fmap(__internal_recursive_eltype, x) + return _eltype[] +end diff --git a/test/helpers/nestedad_tests.jl b/test/helpers/nestedad_tests.jl index 2275e289e..a5374e58c 100644 --- a/test/helpers/nestedad_tests.jl +++ b/test/helpers/nestedad_tests.jl @@ -199,3 +199,74 @@ end end end end + +@testitem "Nested AD: VJP & JVP" setup=[SharedTestSetup] tags=[:others] begin + using ComponentArrays, FiniteDifferences, ForwardDiff, LinearAlgebra, Zygote, ADTypes + + Base.isfinite(::Nothing) = true + + rng = get_stable_rng() + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + # FIXME: AMDGPU takes too long right now + mode === "AMDGPU" && continue + + models = ( + Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), BatchNorm(4), + Conv((3, 3), 4 => 1, gelu; pad=SamePad())), + Chain(Dense(2, 4, gelu), Dense(4, 1))) + Xs = (aType(randn(rng, Float32, 3, 3, 2, 4)), aType(randn(rng, Float32, 2, 4))) + + for (model, X) in zip(models, Xs) + ps, st = Lux.setup(rng, model) |> dev + X = X |> aType + + vjp_input = first(model(X, ps, st)) + jvp_input = aType(randn(rng, Float32, size(X)...)) + + function loss_function_vjp(model, X, ps, st, vjp_input) + smodel = StatefulLuxLayer(model, ps, st) + vjp = vector_jacobian_product(smodel, AutoZygote(), X, vjp_input) + return sum(vjp) + end + + function loss_function_vjp_jacobian(model, X, ps, st, vjp_input) + smodel = StatefulLuxLayer(model, ps, st) + J = only(Zygote.jacobian(smodel, X)) + return sum(J' * vec(vjp_input)) + end + + function loss_function_jvp(model, X, ps, st, jvp_input) + smodel = StatefulLuxLayer(model, ps, st) + jvp = jacobian_vector_product(smodel, AutoForwardDiff(), X, jvp_input) + return sum(jvp) + end + + function loss_function_jvp_jacobian(model, X, ps, st, jvp_input) + smodel = StatefulLuxLayer(model, ps, st) + J = only(Zygote.jacobian(smodel, X)) + return sum(J * vec(jvp_input)) + end + + @test_nowarn loss_function_vjp(model, X, ps, st, vjp_input) + @test loss_function_vjp(model, X, ps, st, vjp_input) isa Number + + _, ∂x, ∂ps, _ = Zygote.gradient(loss_function_vjp, model, X, ps, st, vjp_input) + _, ∂x_vjp, ∂ps_vjp, _, _ = Zygote.gradient( + loss_function_vjp_jacobian, model, X, ps, st, vjp_input) + + @test ∂x≈∂x_vjp rtol=1e-3 atol=1e-3 + @test check_approx(∂ps, ∂ps_vjp; rtol=1e-3, atol=1e-3) + + @test_nowarn loss_function_jvp(model, X, ps, st, jvp_input) + @test loss_function_jvp(model, X, ps, st, jvp_input) isa Number + + _, ∂x, ∂ps, _ = Zygote.gradient(loss_function_jvp, model, X, ps, st, jvp_input) + _, ∂x_jvp, ∂ps_jvp, _, _ = Zygote.gradient( + loss_function_jvp_jacobian, model, X, ps, st, jvp_input) + + @test ∂x≈∂x_jvp rtol=1e-3 atol=1e-3 + @test check_approx(∂ps, ∂ps_jvp; rtol=1e-3, atol=1e-3) + end + end +end
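To close, a small hypothetical usage sketch of the two functions exported by this patch, `jacobian_vector_product` and `vector_jacobian_product`; the model, shapes, and seeds are illustrative assumptions.

```julia
using ADTypes, ForwardDiff, Lux, Random, Zygote

model = Chain(Dense(2 => 4, gelu), Dense(4 => 2))
ps, st = Lux.setup(Xoshiro(0), model)
smodel = StatefulLuxLayer(model, ps, st)

x = randn(Xoshiro(0), Float32, 2, 3)
u_in = randn(Xoshiro(1), Float32, size(x)...)           # same structure as `x`
u_out = randn(Xoshiro(2), Float32, size(smodel(x))...)  # same structure as `smodel(x)`

# J * u_in without materializing J (requires ForwardDiff.jl to be loaded)
jvp = jacobian_vector_product(smodel, AutoForwardDiff(), x, u_in)

# Jᵀ * u_out without materializing J (requires Zygote.jl to be loaded)
vjp = vector_jacobian_product(smodel, AutoZygote(), x, u_out)
```

Both calls can themselves appear inside a loss that is differentiated with `Zygote.gradient`, which is what the nested AD manual additions above demonstrate.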
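The Hutchinson example in the manual deliberately fixes a single probe vector so the results can be checked against finite differences; for an actual trace estimate one would average over several Rademacher samples, as in the following hypothetical helper (the name `hutchinson_trace_estimate`, the default `V`, and the RNG handling are assumptions; `hutchinson_trace_vjp` is the function defined in the manual section above).

```julia
# Average the single-sample estimator from the manual over V Rademacher probes.
# Assumes `hutchinson_trace_vjp` and `using Lux, Random, Zygote` from the manual section.
function hutchinson_trace_estimate(model, x, ps, st; V=16, rng=Xoshiro(0))
    tr_est = 0.0f0
    for _ in 1:V
        v = (rand(rng, Float32, size(x)...) .> 0.5f0) .* 2.0f0 .- 1.0f0  # entries are ±1
        tr_est += hutchinson_trace_vjp(model, x, ps, st, v)
    end
    return tr_est / V
end
```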