Nested AD for Parameter Gradient/Jacobian #610

Closed
prbzrg opened this issue Apr 29, 2024 · 7 comments · Fixed by #612

@prbzrg
Contributor

prbzrg commented Apr 29, 2024

I'm getting this error for code that was working a month ago:

ERROR: LoadError: MethodError: no method matching fast_materialize(::Static.False, ::Static.False, ::Matrix{Float32})

Closest candidates are:
  fast_materialize(::SB, ::DB, ::Base.Broadcast.Broadcasted{S}) where {S, SB, DB}
   @ FastBroadcast C:\Users\prbzr\.julia\packages\FastBroadcast\ux5mz\src\FastBroadcast.jl:22

Stacktrace:
   [1] macro expansion
     @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0 [inlined]
   [2] _pullback(::Zygote.Context{false}, ::typeof(FastBroadcast.fast_materialize), ::Static.False, ::Static.False, ::Matrix{Float32})
     @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:81
   [3] __activation_gradient
     @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\utils.jl:187 [inlined]
   [4] #44
     @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\impl\fused_dense.jl:45 [inlined]
   [5] _pullback(ctx::Zygote.Context{false}, f::LuxLib.var"#44#47"{typeof(tanh_fast), typeof(identity), Matrix{Float32}, Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Matrix{Float32}, Nothing}, args::Matrix{Float32})
     @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
   [6] ZBack
     @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
   [7] Pullback
     @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:46 [inlined]
   [8] Pullback
     @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:31 [inlined]
   [9] Pullback
     @ C:\Users\prbzr\.julia\packages\Lux\ANzxX\src\layers\basic.jl:218 [inlined]
  [10] Pullback
     @ C:\Users\prbzr\.julia\packages\LuxCore\8lRV2\src\LuxCore.jl:180 [inlined]
  [b2108857] Lux v0.5.40
  [82251201] LuxLib v0.3.18
  [bb33d45b] LuxCore v0.1.14

I will update this!

@prbzrg
Contributor Author

prbzrg commented Apr 29, 2024

Downgrading to Lux v0.5.37 worked.

@prbzrg
Contributor Author

prbzrg commented Apr 29, 2024

Error:

ERROR: LoadError: MethodError: no method matching fast_materialize(::Static.False, ::Static.False, ::Matrix{Float32})

Closest candidates are:
  fast_materialize(::SB, ::DB, ::Base.Broadcast.Broadcasted{S}) where {S, SB, DB}
   @ FastBroadcast C:\Users\prbzr\.julia\packages\FastBroadcast\ux5mz\src\FastBroadcast.jl:22

Stacktrace:
  [1] macro expansion
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context{false}, ::typeof(FastBroadcast.fast_materialize), ::Static.False, ::Static.False, ::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:81
  [3] __activation_gradient
    @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\utils.jl:187 [inlined]
  [4] #44
    @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\impl\fused_dense.jl:45 [inlined]
  [5] _pullback(ctx::Zygote.Context{false}, f::LuxLib.var"#44#47"{typeof(tanh_fast), typeof(identity), Matrix{}, Base.ReshapedArray{}, Matrix{}, SubArray{}}, args::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
  [6] ZBack
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
  [7] Pullback
    @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:46 [inlined]
  [8] Pullback
    @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:38 [inlined]
  [9] Pullback
    @ C:\Users\prbzr\.julia\packages\Lux\ANzxX\src\layers\basic.jl:218 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Tuple{Matrix{…}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [11] Pullback
    @ D:\Codes\Mine\bug-report\br-3\br-3.jl:9 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [13] #291
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [14] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{Tuple{}}, Zygote.Pullback{Tuple{}, Tuple{}}}, args::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [15] #2169#back
    @ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [16] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{}}, Zygote.Pullback{Tuple{}, Tuple{}}}}, args::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [17] Pullback
    @ .\operators.jl:1045 [inlined]
 [18] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{var"#1#2"}, Tuple{ComponentVector{Float32, Vector{}, Tuple{}}}, @Kwargs{}}, Any}, args::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [19] Pullback
    @ .\operators.jl:1044 [inlined]
 [20] Pullback
    @ .\operators.jl:1041 [inlined]
 [21] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{Base.var"##_#103", @Kwargs{}, ComposedFunction{…}, ComponentVector{…}}, Tuple{Zygote.Pullback{…}, Zygote.Pullback{…}}}, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [22] #291
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [23] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{Tuple{}, Tuple{}}, Zygote.Pullback{Tuple{}, Tuple{}}}, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [24] #2169#back
    @ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [25] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{}, Tuple{}}, Zygote.Pullback{Tuple{}, Tuple{}}}}, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [26] Pullback
    @ .\operators.jl:1041 [inlined]
 [27] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [28] #75
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91 [inlined]
 [29] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}}, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [30] withjacobian
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\grad.jl:150 [inlined]
 [31] _pullback(::Zygote.Context{false}, ::typeof(withjacobian), ::var"#1#2", ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [32] _apply(::Function, ::Vararg{Any})
    @ Core .\boot.jl:838
 [33] adjoint
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:203 [inlined]
 [34] _pullback
    @ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]
 [35] jacobian
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\grad.jl:128 [inlined]
 [36] _pullback(::Zygote.Context{false}, ::typeof(jacobian), ::var"#1#2", ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [37] fn1
    @ D:\Codes\Mine\bug-report\br-3\br-3.jl:9 [inlined]
 [38] _pullback(ctx::Zygote.Context{false}, f::typeof(fn1), args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [39] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:90
 [40] pullback
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:88 [inlined]
 [41] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:147
 [42] top-level scope
    @ D:\Codes\Mine\bug-report\br-3\br-3.jl:13
 [43] include(fname::String)
    @ Base.MainInclude .\client.jl:489
 [44] top-level scope
    @ REPL[1]:1
in expression starting at D:\Codes\Mine\bug-report\br-3\br-3.jl:13
Some type information was truncated. Use `show(err)` to see complete types.

MRE:

using ComponentArrays, Lux, Random, Zygote

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)

function fn1(z)
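    sum(first(Zygote.jacobian(x -> first(nn(r, x, st)), z)))  # inner Zygote.jacobian w.r.t. the parameters
end

fn1(ps)                   # plain evaluation runs
Zygote.gradient(fn1, ps)  # fails with the MethodError above (nested reverse-over-reverse AD)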

Environment:

Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
  [b0b7db55] ComponentArrays v0.15.11
  [b2108857] Lux v0.5.40
  [e88e6eb3] Zygote v0.6.69
  [9a3f8284] Random

@avik-pal
Member

Yeah, that is a weird quirk in how Zygote handles broadcasting.

Since v0.5.40 we use completely different backend operations, which are faster and allocate significantly less memory, but at the cost of nested reverse-over-reverse Zygote AD (which, to be fair, worked only in very limited cases and was never documented, for good reason).

The manual page https://lux.csail.mit.edu/stable/manual/nested_autodiff covers nested AD with respect to the inputs, but the equivalent for parameters hasn't been implemented yet.

@avik-pal avik-pal transferred this issue from LuxDL/LuxLib.jl Apr 29, 2024
@avik-pal avik-pal changed the title ERROR: LoadError: MethodError: no method matching fast_materialize(::Static.False, ::Static.False, ::Matrix{Float32}) Nested AD for Parameter Gradient/Jacobian Apr 29, 2024
@vavrines

I met the same issue when taking Zygote gradient of pullback

using Lux, Zygote, ComponentArrays, Random

X = collect(range(0, 1, length = 10)) |> permutedims
Y = zeros(axes(X))

nn = Chain(Dense(1 => 20, tanh), Dense(20 => 1))
ps, st = Lux.setup(Xoshiro(0), nn)
pv = ComponentArray(ps)

function loss(p)
    u(x) = 1 .+ nn(x, p, st)[1] .* x                       # u(x) = 1 + NN(x) * x
    ux(x) = Zygote.pullback(u, x)[2](ones(size(x)))[1]    # du/dx via an inner Zygote pullback

    pred = ux(X)
    loss = sum(abs2, pred)
    
    return loss
end

Zygote.gradient(loss, pv)  # outer gradient w.r.t. the parameters

This code has not worked since v0.5.40.

@prbzrg
Contributor Author

prbzrg commented Apr 30, 2024

I added

[Lux]
DisableAutomaticNestedADSwitching = true

to LocalPreferences.toml, but it didn't change anything.

@prbzrg
Contributor Author

prbzrg commented Apr 30, 2024

Also, I didn't use StatefulLuxLayer, so I don't think the "Nested Automatic Differentiation" feature gets activated.

@avik-pal
Member

avik-pal commented Apr 30, 2024

I met the same issue when taking Zygote gradient of pullback

This is a separate issue; see #544 and #600. FWIW, gradients through pullbacks are used extensively in DeepEquilibriumModels; see https://github.com/SciML/DeepEquilibriumNetworks.jl/blob/main/ext/DeepEquilibriumNetworksZygoteExt.jl. We just need to add the overload for DI.pullback.

The core problem is still the same: Zygote rewrites Broadcast.broadcasted calls in a way that prevents using FastBroadcast and similar packages. (This part is unrelated to the newly introduced nested AD rules; it comes from #591.) To fix it, we just need to add an rrule for the parameter Jacobian on Base.Fix1(::StatefulLuxLayer, x), similar to how the other Jacobian and gradient calls are captured.
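To illustrate the general idea of capturing an inner parameter Jacobian with a custom rule, here is a minimal, generic sketch. It is not the Lux change that landed in #612: `param_jacobian`, `f` and `p0` are hypothetical names, and it assumes ChainRulesCore, ForwardDiff and Zygote are installed. The inner Jacobian is computed with ForwardDiff, and a ChainRulesCore `rrule` supplies its pullback, so the outer `Zygote.gradient` uses the rule instead of tracing through the inner AD call.

```julia
using ChainRulesCore, ForwardDiff, Zygote

# Hypothetical helper (not a Lux API): Jacobian of f with respect to p,
# computed in forward mode.
param_jacobian(f, p::AbstractVector) = ForwardDiff.jacobian(f, p)

# Custom reverse rule: Zygote calls this instead of differentiating
# through ForwardDiff's internals.
function ChainRulesCore.rrule(::typeof(param_jacobian), f, p::AbstractVector)
    J = param_jacobian(f, p)
    function param_jacobian_pullback(ΔJ)
        Δ = unthunk(ΔJ)
        # Cotangent for p = gradient of <Δ, J(q)> at q = p,
        # computed with nested (forward-over-forward) ForwardDiff.
        ∂p = ForwardDiff.gradient(q -> sum(Δ .* ForwardDiff.jacobian(f, q)), p)
        # Assumption: f closes over no differentiable state.
        return NoTangent(), NoTangent(), ∂p
    end
    return J, param_jacobian_pullback
end

# Toy model standing in for a layer: f(p) = [p1^2 * p2, sin(p2)]
f(p) = [p[1]^2 * p[2], sin(p[2])]
p0 = [1.0, 2.0]

# Outer reverse-mode gradient of a scalar built from the inner Jacobian.
g = only(Zygote.gradient(p -> sum(param_jacobian(f, p)), p0))
```

Here sum(J) = 2p₁p₂ + p₁² + cos(p₂), so `g` should equal [2p₂ + 2p₁, 2p₁ - sin(p₂)] at `p0`, i.e. [6.0, 2 - sin(2)]. Forward-over-forward is only reasonable for a small number of parameters; a production rule would more likely use a forward-over-reverse scheme.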
