Differentiating Zygote.pullback #621

Closed
prbzrg opened this issue May 4, 2024 · 9 comments · Fixed by #623

Comments

@prbzrg
Contributor

prbzrg commented May 4, 2024

Error:

ERROR: LoadError: Mutating arrays is not supported -- called setindex!(Vector{Float32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\array.jl:70
  [3] (::Zygote.var"#539#540"{Vector{Float32}})(::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\array.jl:82
  [4] (::Zygote.var"#2623#back#541"{Zygote.var"#539#540"{Vector{Float32}}})(Δ::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
  [5] _mapreducedim!
    @ .\reducedim.jl:317 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(Base._mapreducedim!), typeof(identity), typeof(Base.add_sum), Vector{Float32}, Matrix{Float32}}, Any})(Δ::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
  [7] mapreducedim!
    @ .\reducedim.jl:324 [inlined]
  [8] #sum!#852
    @ .\reducedim.jl:1034 [inlined]
  [9] (::Zygote.Pullback{Tuple{Base.var"##sum!#852", Bool, typeof(sum!), typeof(identity), Vector{Float32}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{}, Tuple{}}, Zygote.Pullback{Tuple{}, Any}}})(Δ::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [10] sum!
    @ .\reducedim.jl:1034 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{init::Bool}, typeof(sum!), typeof(identity), Vector{Float32}, Matrix{Float32}}, Any})(Δ::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [12] #sum!#853
    @ .\reducedim.jl:1036 [inlined]
 [13] (::Zygote.Pullback{Tuple{Base.var"##sum!#853", Bool, typeof(sum!), Vector{…}, Matrix{…}}, Tuple{Zygote.Pullback{…}, Zygote.Pullback{…}, Zygote.var"#2013#back#204"{…}}})(Δ::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [14] sum!
    @ .\reducedim.jl:1036 [inlined]
 [15] __added_bias_gradient
    @ C:\Users\prbzr\.julia\packages\LuxLib\VDD3J\src\utils.jl:184 [inlined]
 [16] __matmul_bias_partials
    @ C:\Users\prbzr\.julia\packages\LuxLib\VDD3J\src\impl\fused_dense.jl:78 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [18] #46
    @ C:\Users\prbzr\.julia\packages\LuxLib\VDD3J\src\impl\fused_dense.jl:47 [inlined]
 [19] (::Zygote.Pullback{Tuple{LuxLib.var"#46#49"{…}, Matrix{…}}, Any})(Δ::Tuple{Nothing, Nothing, Nothing, FillArrays.Fill{Float32, 2, Tuple{…}}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [20] ZBack
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [22] Pullback
    @ C:\Users\prbzr\.julia\packages\LuxLib\VDD3J\src\api\dense.jl:46 [inlined]
 [23] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, FillArrays.Fill{…}, Nothing, Nothing, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [24] Pullback
    @ C:\Users\prbzr\.julia\packages\LuxLib\VDD3J\src\api\dense.jl:38 [inlined]
 [25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [26] Pullback
    @ C:\Users\prbzr\.julia\packages\Lux\ErEns\src\layers\basic.jl:218 [inlined]
 [27] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}, Nothing, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [28] Pullback
    @ C:\Users\prbzr\.julia\packages\LuxCore\8lRV2\src\LuxCore.jl:180 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, FillArrays.Fill{…}, Nothing, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [30] Pullback
    @ C:\Users\prbzr\.julia\packages\Lux\ErEns\src\helpers\stateful.jl:83 [inlined]
 [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [32] Pullback
    @ D:\Codes\Mine\bug-report\br-3\br-3-3.jl:10 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [34] #75
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91 [inlined]
 [35] (::Zygote.Pullback{Tuple{Zygote.var"#75#76"{…}, Matrix{…}}, Tuple{Zygote.Pullback{…}, Zygote.var"#2180#back#303"{…}, Zygote.Pullback{…}}})(Δ::Tuple{FillArrays.Fill{Float32, 2, Tuple{…}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [36] fn1
    @ D:\Codes\Mine\bug-report\br-3\br-3-3.jl:11 [inlined]
 [37] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [38] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91
 [39] gradient(::Function, ::Matrix{Float32}, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:148
 [40] top-level scope
    @ D:\Codes\Mine\bug-report\br-3\br-3-3.jl:15
 [41] include(fname::String)
    @ Base.MainInclude .\client.jl:489
 [42] top-level scope
    @ REPL[1]:1
in expression starting at D:\Codes\Mine\bug-report\br-3\br-3-3.jl:15
Some type information was truncated. Use `show(err)` to see complete types.
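Here is a minimal sketch of the limitation named in this error, independent of Lux. The stack trace shows Zygote tracing Base's in-place `sum!` (called from LuxLib's `__added_bias_gradient`), and Zygote's rule for `setindex!` throws during the backward pass. The function names below are hypothetical. The in-place version is expected to fail in the same way, while the out-of-place `sum(A; dims=2)` has a rule and differentiates fine.

```julia
using Zygote

# Out-of-place reduction: Zygote has a rule for `sum(A; dims)`, so this works.
rowsums_total_pure(A) = sum(sum(A; dims=2))

# In-place reduction into a preallocated buffer (hypothetical example).
# Zygote traces into Base's `sum!`, which writes into `buf` element by element,
# and Zygote's mutation rule throws when the backward pass reaches those writes.
function rowsums_total_mut(A)
    buf = zeros(eltype(A), size(A, 1))
    sum!(buf, A)
    return sum(buf)
end

A = rand(Float32, 2, 4)

g_pure = only(Zygote.gradient(rowsums_total_pure, A))
println(g_pure)   # every entry is 1

mutation_failed = try
    Zygote.gradient(rowsums_total_mut, A)
    false
catch
    true   # expected: Zygote reports that mutating arrays is not supported
end
```

In this issue the `sum!` is not in user code. Zygote reaches it only because it is differentiating through the pullback of the dense layer, whose gradient code uses the in-place reduction.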

MRE:

using ComponentArrays, Lux, Random, Zygote, ForwardDiff

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
snn = StatefulLuxLayer(nn, ps, st)

function fn1(u, p)
    z, back = Zygote.pullback(x -> snn(x, p), u)
    sum(only(back(z))) + sum(z)
end

fn1(r, ps)
Zygote.gradient(fn1, r, ps)

Environment:

Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
  [b0b7db55] ComponentArrays v0.15.11
  [f6369f11] ForwardDiff v0.10.36
  [b2108857] Lux v0.5.41
  [e88e6eb3] Zygote v0.6.69
  [9a3f8284] Random
@prbzrg
Contributor Author

prbzrg commented May 4, 2024

It works with Lux v0.5.37.
How can I do it with Lux v0.5.41?
Is this a bug? Did I miss something?

@prbzrg
Contributor Author

prbzrg commented May 4, 2024

Maybe related to #286

@avik-pal
Member

avik-pal commented May 4, 2024

Pullbacks are hard to differentiate directly; see #610 (comment). We just need some rrules for DifferentiationInterface.pullback. Differentiating Zygote.pullback is almost never going to work, unless someone finds a clever trick for writing the tangent of the pullback function. DI.pullback, on the other hand, is quite simple to handle; DEQs.jl already does that.
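To make this concrete, here is what `fn1` in the MRE computes. The notation is introduced here for illustration: $f$ is the layer, $\mathbf{1}$ is an all-ones vector of the appropriate length, and `back(z)` is treated as returning $J_u^\top \operatorname{vec}(z)$ reshaped to the shape of `u`.

$$
\begin{aligned}
z &= f(u;\,p), \qquad
J_u = \frac{\partial \operatorname{vec}(z)}{\partial \operatorname{vec}(u)}, \qquad
J_p = \frac{\partial \operatorname{vec}(z)}{\partial p}, \\
\mathrm{fn1}(u, p) &= \mathbf{1}^\top J_u^\top \operatorname{vec}(z) + \mathbf{1}^\top \operatorname{vec}(z), \\
\nabla_p\, \mathrm{fn1} &= J_p^\top \left( J_u \mathbf{1} + \mathbf{1} \right)
  + \left( \frac{\partial (J_u \mathbf{1})}{\partial p} \right)^{\!\top} \operatorname{vec}(z).
\end{aligned}
$$

The last term contains derivatives of $J_u$ with respect to $p$, which are mixed second derivatives of $f$. An outer reverse pass therefore has to differentiate through the pullback closure itself, including whatever gradient code runs inside it. In the stack trace above, that includes LuxLib's bias-gradient `sum!`.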

@avik-pal avik-pal changed the title ERROR: LoadError: Mutating arrays is not supported -- called setindex!(Vector{Float32}, ...) Differentiating Zygote.pullback May 4, 2024
@prbzrg
Contributor Author

prbzrg commented May 6, 2024

I'm still getting the error with

(br-3) pkg> st
Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
  [b0b7db55] ComponentArrays v0.15.11
  [a0c0ee7d] DifferentiationInterface v0.3.3
  [f6369f11] ForwardDiff v0.10.36
  [b2108857] Lux v0.5.42
  [e88e6eb3] Zygote v0.6.69
  [9a3f8284] Random

(br-3) pkg> st --outdated
Status `D:\Codes\Mine\bug-report\br-3\Project.toml`

(br-3) pkg> st --outdated -m
Status `D:\Codes\Mine\bug-report\br-3\Manifest.toml`

Even after using DI:

using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
snn = StatefulLuxLayer(nn, ps, st)

function fn1(u, p)
    z, uJ = DifferentiationInterface.value_and_pullback(x -> snn(x, p), AutoZygote(), u, u)
    sum(uJ) + sum(z)
end

fn1(r, ps)
DifferentiationInterface.gradient(x -> fn1(r, x), AutoZygote(), ps)

@avik-pal
Member

avik-pal commented May 6, 2024

Use Lux.vector_jacobian_product instead.
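For reference, a vector-Jacobian product (VJP) is $J^\top v$, which is what a reverse-mode pullback returns when fed a cotangent $v$. The sketch below uses plain Zygote rather than Lux and checks that a pullback-based VJP matches an explicit Jacobian. The helper `vjp_via_pullback`, the weights `W`, and `toy_layer` are hypothetical and illustrative only. This is not `Lux.vector_jacobian_product` itself.

```julia
using Zygote

# Hypothetical helper: a vector-Jacobian product J' * v obtained from a Zygote pullback.
vjp_via_pullback(f, x, v) = only(Zygote.pullback(f, x)[2](v))

# Toy stand-in for Dense(2, 2, tanh) without a bias (illustrative weights).
const W = Float32[1 2; 3 4] ./ 10
toy_layer(x) = tanh.(W * x)

x = rand(Float32, 2, 4)   # batch of 4 inputs, like `r` in the MRE
v = rand(Float32, 2, 4)   # cotangent with the same shape as toy_layer(x)

g = vjp_via_pullback(toy_layer, x, v)    # same shape as x
J = only(Zygote.jacobian(toy_layer, x))  # 8×8 Jacobian of vec(toy_layer(x)) w.r.t. vec(x)

println(size(g))                # (2, 4)
println(vec(g) ≈ J' * vec(v))   # true
```

Wrapping a helper like this in an outer `Zygote.gradient` means differentiating a `Zygote.pullback`, which is the pattern that fails in this issue. That is why the thread moves to `Lux.vector_jacobian_product`.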

@prbzrg
Copy link
Contributor Author

prbzrg commented May 6, 2024

I tried:

using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
snn = StatefulLuxLayer(nn, ps, st)

function fn1(u, p)
    z, uJ = Lux.vector_jacobian_product(x -> snn(x, p), AutoZygote(), u, u)
    sum(uJ) + sum(z)
end

fn1(r, ps)
# DifferentiationInterface.gradient(x -> fn1(r, x), AutoZygote(), ps)
Zygote.gradient(fn1, r, ps)

But the problem is still there. 😬

@avik-pal
Member

avik-pal commented May 6, 2024

Closures don't work; see the first part of https://lux.csail.mit.edu/stable/manual/nested_autodiff. Also, Lux.vector_jacobian_product (https://lux.csail.mit.edu/stable/api/Lux/utilities#Lux.vector_jacobian_product) returns only the VJP, not the value and the VJP (returning both could be added later, but it doesn't change the code much).

@prbzrg
Copy link
Contributor Author

prbzrg commented May 6, 2024

Thanks, the problem is resolved.
Final code:

using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)

function fn1(u, ps, st)
    snn = StatefulLuxLayer(nn, ps, st)
    z = snn(u)
    uJ = Lux.vector_jacobian_product(snn, AutoZygote(), u, u)
    sum(uJ) + sum(z)
end

fn1(r, ps, st)
# DifferentiationInterface.gradient(x -> fn1(r, x, st), AutoZygote(), ps)
Zygote.gradient(fn1, r, ps, st)

@avik-pal
Member

avik-pal commented May 6, 2024

Capturing DI calls would have been the ideal solution, but it causes method ambiguities, and I would have to define the functions manually for every possible combination, which gets messy; see #600 (comment).
