Add vector_jacobian_product and jacobian_vector_product functions #623

Merged · 5 commits · May 6, 2024
Changes from 2 commits
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.5.41"
version = "0.5.42"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3 changes: 2 additions & 1 deletion README.md
@@ -25,7 +25,8 @@ The 🔥 Deep Learning Framework
## Installation

```julia
] add Lux
import Pkg
Pkg.add("Lux")
```

## Getting Started
29 changes: 22 additions & 7 deletions ext/LuxForwardDiffExt.jl
@@ -1,5 +1,6 @@
module LuxForwardDiffExt

using ADTypes: AutoForwardDiff
using ChainRulesCore: ChainRulesCore
using Lux: Lux
using FastClosures: @closure
@@ -10,6 +11,7 @@ const CRC = ChainRulesCore

@inline Lux._is_extension_loaded(::Val{:ForwardDiff}) = true

# Low-Level functions
@inline function Lux.__partials(::Type{Tag}, x, i) where {Tag}
x isa AbstractArray && return ForwardDiff.partials.(Tag, x, i)
map_fn = @closure(xᵢ->Lux.__partials(Tag, xᵢ, i))
@@ -20,15 +22,28 @@ const CRC = ChainRulesCore
return fmap(map_fn, x)
end

@inline function Lux.__dualify(::Type{Tag}, ::Type{T}, x, u) where {Tag, T}
if x isa AbstractArray
return ForwardDiff.Dual{
Tag, T, 1}.(x, ForwardDiff.Partials{1, T}.(tuple.(reshape(u, size(x)))))
end
x isa Tuple && return map((xᵢ, uᵢ) -> Lux.__dualify(Tag, T, xᵢ, uᵢ), x, u)
x isa NamedTuple &&
return NamedTuple{keys(x)}(map((xᵢ, uᵢ) -> Lux.__dualify(Tag, T, xᵢ, uᵢ), x, u))
return fmap((xᵢ, uᵢ) -> Lux.__dualify(Tag, T, xᵢ, uᵢ), x, u)
end

# This is not general JVP code; it is meant to be efficient for nested AD calls
function Lux.__forwarddiff_jvp(
f::F, x::AbstractArray{xT}, Δx::AbstractArray{ΔxT}, ps) where {F, xT, ΔxT}
T = promote_type(xT, ΔxT)
function Lux.__forwarddiff_jvp(f::F, x, Δx, args...) where {F}
T = promote_type(Lux.__recursive_eltype(x), Lux.__recursive_eltype(Δx))
Tag = typeof(ForwardDiff.Tag(f, T))
partials = ForwardDiff.Partials{1, T}.(tuple.(Δx))
x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x)))
y_dual, ps_dual = f(x_dual, ps)
return Lux.__partials(Tag, y_dual, 1), Lux.__partials(Tag, ps_dual, 1)
y_dual, args_duals... = f(Lux.__dualify(Tag, T, x, Δx), args...)
return (Lux.__partials(Tag, y_dual, 1), Lux.__partials.((Tag,), args_duals, 1)...)
end

# jvp
function Lux.__jacobian_vector_product_impl(f::F, ::AutoForwardDiff, x, u) where {F}
return only(Lux.__forwarddiff_jvp(f, x, u))
end
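For context, the seeding/extraction pattern that `__dualify` and `__forwarddiff_jvp` implement can be sketched on a plain array-valued function (an illustrative example, not part of this diff; `simple_jvp` is a hypothetical name):

```julia
using ForwardDiff

# Minimal JVP sketch: seed each input with a single partial carrying the
# direction `u`, run `f` on the Dual inputs, and read the partial back out.
function simple_jvp(f, x::AbstractVector, u::AbstractVector)
    T = promote_type(eltype(x), eltype(u))
    Tag = typeof(ForwardDiff.Tag(f, T))
    partials = ForwardDiff.Partials{1, T}.(tuple.(u))
    x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, partials)
    y_dual = f(x_dual)
    return ForwardDiff.partials.(y_dual, 1)  # = J(x) * u
end

simple_jvp(x -> x .^ 2, [1.0, 2.0, 3.0], [1.0, 0.0, 1.0])  # ≈ [2.0, 0.0, 6.0]
```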

# Capture ForwardDiff.jacobian call and replace it with forward over reverse mode AD
20 changes: 16 additions & 4 deletions ext/LuxZygoteExt.jl
@@ -8,6 +8,8 @@ using Zygote: Zygote

const CRC = ChainRulesCore

Lux._is_extension_loaded(::Val{:Zygote}) = true

function Lux.Experimental.compute_gradients(::AutoZygote, objective_function::F, data,
ts::Lux.Experimental.TrainState) where {F}
(loss, st, stats), back = Zygote.pullback(
@@ -17,18 +19,28 @@ function Lux.Experimental.compute_gradients(::AutoZygote, objective_function::F,
return grads, loss, stats, ts
end

function Lux.__vector_jacobian_product_impl(f::F, ::AutoZygote, x, u) where {F}
_, pb_f = Zygote.pullback(f, x)
return only(pb_f(u))
end

# Nested AD Handling
for fType in Lux.AD_CONVERTIBLE_FUNCTIONS
@eval begin
@inline function Zygote.gradient(f::$fType, x)
f_internal, ps = Lux.__rewrite_ad_call(f)
return Lux.__internal_ad_gradient_call(Zygote.gradient, f_internal, x, ps)
f_internal, y = Lux.__rewrite_ad_call(f)
return Lux.__internal_ad_gradient_call(Zygote.gradient, f_internal, x, y)
end

@inline function Zygote.jacobian(f::$fType, x::AbstractArray)
f_internal, ps = Lux.__rewrite_ad_call(f)
f_internal, y = Lux.__rewrite_ad_call(f)
return Lux.__internal_ad_jacobian_call(
Zygote.jacobian, Zygote.gradient, f_internal, x, ps)
Zygote.jacobian, Zygote.gradient, f_internal, x, y)
end

@inline function Lux.__vector_jacobian_product_impl(f::$fType, ::AutoZygote, x, u)
f_internal, y = Lux.__rewrite_ad_call(f)
return Lux.__internal_ad_pullback_call(Zygote.pullback, f_internal, x, y, u)
end
end
end
3 changes: 3 additions & 0 deletions src/Lux.jl
@@ -3,6 +3,7 @@ module Lux
using PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ADTypes: AbstractADType, AutoForwardDiff, AutoZygote
using Adapt: Adapt, adapt
using ArrayInterface: ArrayInterface
using ChainRulesCore: ChainRulesCore, AbstractZero, HasReverseMode, NoTangent,
@@ -69,6 +70,7 @@ include("contrib/contrib.jl")
# Helpful Functionalities
include("helpers/stateful.jl")
include("helpers/compact.jl")
include("helpers/autodiff.jl")
include("helpers/nested_ad.jl")

# Transform to and from other frameworks
@@ -99,6 +101,7 @@ export SamePad, TimeLastIndex, BatchLastIndex

export StatefulLuxLayer
export @compact, CompactLuxLayer
export jacobian_vector_product, vector_jacobian_product

export f16, f32, f64

75 changes: 75 additions & 0 deletions src/helpers/autodiff.jl
@@ -0,0 +1,75 @@
@doc doc"""
vector_jacobian_product(f, backend::AbstractADType, x, u)

Compute the Vector Jacobian Product ``\left(\frac{\partial f}{\partial x}\right)^T u``.
This is a wrapper around AD backends, but it additionally allows computing gradients of
vector-Jacobian products efficiently using mixed-mode AD.

The following backends are supported:

- `AutoZygote`: `Zygote.jl` must be loaded.

!!! warning

Gradient wrt `u` in the reverse pass is always dropped.

## Arguments

- `f`: The function to compute the jacobian of.
- `backend`: The backend to use for computing the VJP.
- `x`: The input to the function.
- `u`: An object of the same structure as `f(x)`.

## Returns

- `v`: The Vector Jacobian Product.
"""
function vector_jacobian_product(f::F, backend::AbstractADType, x, u) where {F}
@assert backend isa AutoZygote "Only `AutoZygote` is supported for \
`vector_jacobian_product`."
if !_is_extension_loaded(Val(:Zygote))
error("`Zygote.jl` must be loaded for `vector_jacobian_product` \
to work with `$(backend)`.")
end
return __vector_jacobian_product_impl(f, backend, x, u)
end

function __vector_jacobian_product_impl end
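As a usage sketch of the documented API (illustrative only; assumes `Zygote.jl` is loaded so `AutoZygote` dispatches to the extension method, and that `ADTypes.jl` is available in the environment):

```julia
using Lux, Zygote            # Zygote.jl must be loaded for AutoZygote
using ADTypes: AutoZygote

f(x) = x .^ 2 .+ sum(x)      # f(x) has the same structure as x
x = [1.0, 2.0, 3.0]
u = [1.0, 0.0, 0.0]          # same structure as f(x)

# v = (∂f/∂x)ᵀ u; for this f the Jacobian is diag(2x) .+ 1, so v ≈ [3.0, 1.0, 1.0]
v = vector_jacobian_product(f, AutoZygote(), x, u)
```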

@doc doc"""
jacobian_vector_product(f, backend::AbstractADType, x, u)

Compute the Jacobian Vector Product ``\left(\frac{\partial f}{\partial x}\right) u``.
This is a wrapper around AD backends, but it additionally allows computing gradients of
Jacobian-vector products efficiently using mixed-mode AD.

The following backends are supported:

- `AutoForwardDiff`: `ForwardDiff.jl` must be loaded.

!!! warning

Gradient wrt `u` in the reverse pass is always dropped.

## Arguments

- `f`: The function to compute the jacobian of.
- `backend`: The backend to use for computing the JVP.
- `x`: The input to the function.
- `u`: An object of the same structure as `x`.

## Returns

- `v`: The Jacobian Vector Product.
"""
function jacobian_vector_product(f::F, backend::AbstractADType, x, u) where {F}
@assert backend isa AutoForwardDiff "Only `AutoForwardDiff` is supported for \
`jacobian_vector_product`."
if !_is_extension_loaded(Val(:ForwardDiff))
error("`ForwardDiff.jl` must be loaded for `jacobian_vector_product` \
to work with `$(backend)`.")
end
return __jacobian_vector_product_impl(f, backend, x, u)
end

function __jacobian_vector_product_impl end
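And the forward-mode counterpart (again an illustrative sketch, not part of this diff; assumes `ForwardDiff.jl` is loaded so `AutoForwardDiff` dispatches to the extension):

```julia
using Lux, ForwardDiff       # ForwardDiff.jl must be loaded for AutoForwardDiff
using ADTypes: AutoForwardDiff

f(x) = x .^ 2 .+ sum(x)
x = [1.0, 2.0, 3.0]
u = [0.0, 1.0, 0.0]          # same structure as x

# v = (∂f/∂x) u = 2 .* x .* u .+ sum(u), so v ≈ [1.0, 5.0, 1.0]
v = jacobian_vector_product(f, AutoForwardDiff(), x, u)
```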
79 changes: 59 additions & 20 deletions src/helpers/nested_ad.jl
@@ -1,6 +1,7 @@
function __forwarddiff_jvp end # Defined in ForwardDiff.jl extension

function __partials end # DON'T REMOVE THIS (DEQs.jl is using it)
function __dualify end

#! format: off
const AD_CONVERTIBLE_FUNCTIONS = [
@@ -32,23 +33,22 @@ const AD_CONVERTIBLE_FUNCTIONS = [
error("Unknown function type: $(typeof(f))")
end

# Essentially computes the gradient of `f(x, y)` wrt x using the function `grad_fn`
# To compute the gradient of `f(x, y)` wrt y, just reorder the arguments with a wrapper
# over `f`
@inline function __internal_ad_gradient_call(grad_fn::G, f::F, x, y) where {G, F}
return grad_fn(Base.Fix2(f, y), x)
end
@inline function __internal_ad_gradient_call_no_custom_rrule(
grad_fn::G, f::F, x, y) where {G, F}
return grad_fn(Base.Fix2(f, y), x) # Don't call `__internal_ad_gradient_call`
# Nested Gradients
## Essentially computes the gradient of `f(x, y)` wrt x using the function `grad_fn`
## To compute the gradient of `f(x, y)` wrt y, just reorder the arguments with a wrapper
## over `f`
for fname in (:__internal_ad_gradient_call, :__internal_ad_gradient_call_no_custom_rrule)
@eval @inline function $fname(grad_fn::G, f::F, x, y) where {G, F}
return grad_fn(Base.Fix2(f, y), x)
end
end

function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
::typeof(__internal_ad_gradient_call), grad_fn::G, f::F, x, y) where {G, F}
# Check if we can use the faster implementation
if !Lux._is_extension_loaded(Val(:ForwardDiff)) || DISABLE_AUTOMATIC_NESTED_AD_SWITCH
if !DISABLE_AUTOMATIC_NESTED_AD_SWITCH
@warn "Load ForwardDiff.jl for better nested AD handling." maxlog=1
@warn "Load `ForwardDiff.jl` for better nested AD handling." maxlog=1
end
# Use the AD itself for whatever reason
return CRC.rrule_via_ad(
Expand All @@ -60,7 +60,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
(Δ_ isa CRC.NoTangent || Δ_ isa CRC.ZeroTangent) &&
return ntuple(Returns(CRC.NoTangent()), 5)

Δ = CRC.backing(CRC.unthunk(Δ_))
Δ = CRC.unthunk(Δ_)
Δ isa Tuple && (Δ = only(Δ)) # For Zygote and such which return a tuple
∂x, ∂y = __forwarddiff_jvp(@closure((x, y)->grad_fn(f, x, y)), x, Δ, y)
return CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), ∂x, ∂y
@@ -69,14 +69,53 @@
return res, ∇internal_gradient_capture
end

# `grad_fn` is not needed for the forward pass, we need it for the reverse pass HVP
function __internal_ad_jacobian_call(
jac_fn::J, grad_fn::G, f::F, x::AbstractArray, y) where {J, G, F}
return jac_fn(Base.Fix2(f, y), x)
# Nested Pullbacks
for fname in (:__internal_ad_pullback_call, :__internal_ad_pullback_call_no_custom_rrule)
@eval @inline function $fname(pullback_fn::P, f::F, x, y, u) where {P, F}
return only(last(pullback_fn(Base.Fix2(f, y), x))(u))
end
end
@inline function __internal_ad_jacobian_call_no_custom_rrule(
jac_fn::J, grad_fn::G, f::F, x::AbstractArray, y) where {J, G, F}
return jac_fn(Base.Fix2(f, y), x) # Don't call `__internal_ad_jacobian_call`

function CRC.rrule(
cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(__internal_ad_pullback_call),
pullback_fn::P, f::F, x, y, u) where {P, F}
# Check if we can use the faster implementation
if !Lux._is_extension_loaded(Val(:ForwardDiff)) || DISABLE_AUTOMATIC_NESTED_AD_SWITCH
if !DISABLE_AUTOMATIC_NESTED_AD_SWITCH
@warn "Load `ForwardDiff.jl` for better nested AD handling." maxlog=1
end
# Use the AD itself for whatever reason
return CRC.rrule_via_ad(
cfg, __internal_ad_pullback_call_no_custom_rrule, pullback_fn, f, x, y, u)
end

res = __internal_ad_pullback_call(pullback_fn, f, x, y, u)
∇internal_pullback_capture = let pullback_fn = pullback_fn, f = f, x = x, y = y, u = u
Δ_ -> begin
(Δ_ isa CRC.NoTangent || Δ_ isa CRC.ZeroTangent) &&
return ntuple(Returns(CRC.NoTangent()), 6)

Δ = CRC.unthunk(Δ_)
Δ isa Tuple && (Δ = only(Δ)) # For Zygote and such which return a tuple
∂x, ∂y = __forwarddiff_jvp(x, Δ, y) do x_dual, y_
_, pb_f = pullback_fn(f, x_dual, y_)
return pb_f(u)
end
return (
CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), ∂x, ∂y, CRC.NoTangent())
end
end

return res, ∇internal_pullback_capture
end
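As with the gradient rule above, this `rrule` replaces reverse-over-reverse AD with forward-over-reverse: the incoming cotangent `Δ` is pushed through the pullback itself via a single ForwardDiff JVP, which is what `__forwarddiff_jvp` computes here. In symbols (a restatement of the code path, not additional API):

```math
\frac{\partial}{\partial x}\!\left[\left(\frac{\partial f}{\partial x}\right)^{T} u\right] \Delta
  \;=\; \frac{d}{d\epsilon}\!\left[\left.\left(\frac{\partial f}{\partial x}\right)^{T}\right|_{x + \epsilon \Delta} u\right]_{\epsilon = 0}
```

i.e. a single forward-mode directional derivative of the reverse-mode pullback (symmetry of the second derivative is what lets this JVP stand in for the transposed product the reverse pass needs), so only one extra Dual partial per input is required instead of a second reverse-mode sweep.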

# Nested Jacobians
## `grad_fn` is not needed for the forward pass, we need it for the reverse pass HVP
for fname in (:__internal_ad_jacobian_call, :__internal_ad_jacobian_call_no_custom_rrule)
@eval @inline function $fname(
jac_fn::J, grad_fn::G, f::F, x::AbstractArray, y) where {J, G, F}
return jac_fn(Base.Fix2(f, y), x)
end
end

function CRC.rrule(
@@ -85,7 +124,7 @@ function CRC.rrule(
# Check if we can use the faster implementation
if !Lux._is_extension_loaded(Val(:ForwardDiff)) || DISABLE_AUTOMATIC_NESTED_AD_SWITCH
if !DISABLE_AUTOMATIC_NESTED_AD_SWITCH
@warn "Load ForwardDiff.jl for better nested AD handling." maxlog=1
@warn "Load `ForwardDiff.jl` for better nested AD handling." maxlog=1
end
# Use the AD itself for whatever reason
return CRC.rrule_via_ad(
@@ -98,7 +137,7 @@ function CRC.rrule(
(Δ_ isa CRC.NoTangent || Δ_ isa CRC.ZeroTangent) &&
return ntuple(Returns(CRC.NoTangent()), 6)

Δ = CRC.backing(CRC.unthunk(Δ_))
Δ = CRC.unthunk(Δ_)
Δ isa Tuple && (Δ = only(Δ)) # For Zygote and such which return a tuple
Δ = __compactify_if_structured_matrix(res isa Tuple ? only(res) : res, Δ)

16 changes: 16 additions & 0 deletions src/utils.jl
@@ -259,3 +259,19 @@ __named_tuple(nt::NamedTuple) = nt

@inline _vec(x::AbstractArray) = vec(x)
@inline _vec(::Nothing) = nothing

# recursive_eltype
@inline __recursive_eltype(x::AbstractArray) = eltype(x)
@inline __recursive_eltype(x::Tuple) = promote_type(__recursive_eltype.(x)...)
@inline __recursive_eltype(x::NamedTuple) = promote_type(__recursive_eltype.(values(x))...)
@inline __recursive_eltype(::Nothing) = Bool
@inline __recursive_eltype(x::Number) = eltype(x)
@inline function __recursive_eltype(x)
_eltype = Ref(Bool)
function __internal_recursive_eltype(x)
_eltype[] = promote_type(_eltype[], __recursive_eltype(x))
return x
end
fmap(__internal_recursive_eltype, x)
return _eltype[]
end
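A quick sketch of the intended behaviour of this helper (hypothetical REPL output, not part of this diff):

```julia
julia> Lux.__recursive_eltype((weight = rand(Float32, 3, 3), bias = rand(Float64, 3)))
Float64

julia> Lux.__recursive_eltype(([1.0f0, 2.0f0], nothing))
Float32
```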