From e4e4f9f2acce94bedd04b28587951450c6cb1eaa Mon Sep 17 00:00:00 2001 From: apkille Date: Fri, 16 Aug 2024 13:41:03 -0400 Subject: [PATCH 1/2] add support for non-AbstractArrays --- src/concrete_solve.jl | 4 ++ src/gauss_adjoint.jl | 6 ++- test/noindex_tests.jl | 97 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 test/noindex_tests.jl diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index d5bdc78a5..5caf71642 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -946,6 +946,8 @@ function DiffEqBase._concrete_solve_adjoint( if !(Δ isa NoTangent || v isa ZeroTangent) if u0 isa Number ForwardDiff.value.(J'v) + elseif v isa Tangent + ForwardDiff.value.(J'vec(v.x)) else ForwardDiff.value.(J'vec(v)) end @@ -1110,6 +1112,8 @@ function DiffEqBase._concrete_solve_adjoint( if !(Δ isa NoTangent || v isa ZeroTangent) if u0 isa Number ForwardDiff.value.(J'v) + elseif v isa Tangent + ForwardDiff.value.(J'vec(v.x)) else ForwardDiff.value.(J'vec(v)) end diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 656556ad3..742fb4679 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -436,7 +436,11 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) pJ = nothing else pf = DiffEqBase.ParamJacobianWrapper(unwrappedf, tspan[1], y) - pJ = similar(u0, length(u0), numparams) + if isa(u0, AbstractArray) + pJ = similar(u0, length(u0), numparams) + else + pJ = similar(Array{eltype(u0)}, length(u0), numparams) + end paramjac_config = build_param_jac_config(sensealg, pf, y, p) end diff --git a/test/noindex_tests.jl b/test/noindex_tests.jl new file mode 100644 index 000000000..fa9ca9a6d --- /dev/null +++ b/test/noindex_tests.jl @@ -0,0 +1,97 @@ +using OrdinaryDiffEq, RecursiveArrayTools, LinearAlgebra +using Zygote, SciMLSensitivity, Random, Test + +struct CustomArray{T, N} + x::Array{T, N} +end +Base.size(x::CustomArray) = size(x.x) +Base.axes(x::CustomArray) = axes(x.x) +Base.ndims(x::CustomArray) = ndims(x.x) +Base.ndims(::Type{<:CustomArray{T,N}}) where {T,N} = N +Base.zero(x::CustomArray) = CustomArray(zero(x.x)) +Base.zero(::Type{<:CustomArray{T,N}}) where {T,N} = CustomArray(zero(Array{T,N})) +Base.similar(x::CustomArray, dims::Union{Integer, AbstractUnitRange}...) = CustomArray(similar(x.x, dims...)) +Base.copyto!(x::CustomArray, y::CustomArray) = CustomArray(copyto!(x.x, y.x)) +Base.copy(x::CustomArray) = CustomArray(copy(x.x)) +Base.length(x::CustomArray) = length(x.x) +Base.isempty(x::CustomArray) = isempty(x.x) +Base.eltype(x::CustomArray) = eltype(x.x) +Base.zero(x::CustomArray) = CustomArray(zero(x.x)) +Base.fill!(x::CustomArray, y) = CustomArray(fill!(x.x, y)) +Base.getindex(x::CustomArray, i) = getindex(x.x, i) +Base.setindex!(x::CustomArray, v, idx) = setindex!(x.x, v, idx) +Base.firstindex(x::CustomArray) = firstindex(x.x) +Base.lastindex(x::CustomArray) = lastindex(x.x) +Base.eachindex(x::CustomArray) = eachindex(x.x) +Base.mapreduce(f, op, x::CustomArray; kwargs...) = mapreduce(f, op, x.x; kwargs...) +Base.any(f::Function, x::CustomArray; kwargs...) = any(f, x.x; kwargs...) +Base.all(f::Function, x::CustomArray; kwargs...) = all(f, x.x; kwargs...) +Base.similar(x::CustomArray, t) = CustomArray(similar(x.x, t)) +Base.:(+)(x::CustomArray, y::CustomArray) = CustomArray(x.x + y.x) +Base.:(==)(x::CustomArray, y::CustomArray) = x.x == y.x +Base.:(*)(x::Number, y::CustomArray) = CustomArray(x*y.x) +Base.:(/)(x::CustomArray, y::Number) = CustomArray(x.x/y) +LinearAlgebra.norm(x::CustomArray) = norm(x.x) +LinearAlgebra.vec(x::CustomArray) = CustomArray(vec(x.x)) + +struct CustomStyle{N} <: Broadcast.BroadcastStyle where {N} end +CustomStyle(::Val{N}) where N = CustomStyle{N}() +CustomStyle{M}(::Val{N}) where {N,M} = NoIndexStyle{N}() +Base.BroadcastStyle(::Type{<:CustomArray{T,N}}) where {T,N} = CustomStyle{N}() +Broadcast.BroadcastStyle(::CustomStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where {N} = CustomStyle{N}() +Base.similar(bc::Base.Broadcast.Broadcasted{CustomStyle{N}}, ::Type{ElType}) where {N, ElType} = CustomArray(similar(Array{ElType, N}, axes(bc))) +Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::CustomArray, i) = x.x[i] +Base.Broadcast.extrude(x::CustomArray) = x +Base.Broadcast.broadcastable(x::CustomArray) = x + +@inline function Base.copyto!(dest::CustomArray, bc::Base.Broadcast.Broadcasted{<:Union{Base.Broadcast.AbstractArrayStyle,CustomStyle}}) + axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) + bc′ = Base.Broadcast.preprocess(dest, bc) + dest′ = dest.x + @simd for I in 1:length(dest′) + @inbounds dest′[I] = bc′[I] + end + return dest +end +@inline function Base.copy(bc::Base.Broadcast.Broadcasted{<:CustomStyle}) + bcf = Broadcast.flatten(bc) + x = find_x(bcf) + data = similar(x, eltype(bcf[1])) + @inbounds @simd for I in 1:length(x) + data[I] = bcf[I] + end + return CustomArray(data) +end +find_x(bc::Broadcast.Broadcasted) = find_x(bc.args) +find_x(args::Tuple) = find_x(find_x(args[1]), Base.tail(args)) +find_x(x) = x +find_x(::Any, rest) = find_x(rest) +find_x(x::CustomArray, rest) = x.x + +RecursiveArrayTools.recursive_unitless_bottom_eltype(x::CustomArray) = eltype(x) +RecursiveArrayTools.recursivecopy!(dest::CustomArray, src::CustomArray) = copyto!(dest, src) +RecursiveArrayTools.recursivecopy(x::CustomArray) = copy(x) +RecursiveArrayTools.recursivefill!(x::CustomArray, a) = fill!(x, a) + +Base.show_vector(io::IO, x::CustomArray) = Base.show_vector(io, x.x) + +Base.show(io::IO, x::CustomArray) = (print(io, "CustomArray");show(io, x.x)) +function Base.show(io::IO, ::MIME"text/plain", x::CustomArray) + println(io, Base.summary(x), ":") + Base.print_array(io, x.x) +end + +ca0 = CustomArray(ones(2)) +tspan = (0.0, 1.0) +par = [rand(), rand()] + +algs = [Tsit5(), BS3(), Vern9(), DP5()] + +for alg in algs + function cost(p) + prob = ODEProblem((du, u, p, t) -> (du[1] = p[1]*u[1] + p[2]*u[2]; du[2] = p[2]*u[1]), ca0, tspan, p) + sol = solve(prob, alg; save_everystep = false) + return 1 - norm(sol[end])^2 + end + @test_nowarn Zygote.gradient(cost, par) +end \ No newline at end of file From 5824180df9633a5a995cce23a6794f0013f83937 Mon Sep 17 00:00:00 2001 From: apkille Date: Mon, 2 Sep 2024 05:28:42 -0400 Subject: [PATCH 2/2] rm gauss_adjoint.jl changes --- src/gauss_adjoint.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 742fb4679..656556ad3 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -436,11 +436,7 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) pJ = nothing else pf = DiffEqBase.ParamJacobianWrapper(unwrappedf, tspan[1], y) - if isa(u0, AbstractArray) - pJ = similar(u0, length(u0), numparams) - else - pJ = similar(Array{eltype(u0)}, length(u0), numparams) - end + pJ = similar(u0, length(u0), numparams) paramjac_config = build_param_jac_config(sensealg, pf, y, p) end