Skip to content

Commit

Permalink
feat: purge SparseDiffTools in favor of SciMLJacobianOperators
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 30, 2024
1 parent 5d0c695 commit 6f80c63
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 14 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand Down Expand Up @@ -67,7 +68,7 @@ Functors = "0.4"
GPUArraysCore = "0.1"
LinearAlgebra = "1.10"
LinearSolve = "2"
Lux = "0.5.51"
Lux = "1"
Markdown = "1.10"
ModelingToolkit = "9"
NLsolve = "4.5.1"
Expand All @@ -85,6 +86,7 @@ Reexport = "1.0"
ReverseDiff = "1.15.1"
SafeTestsets = "0.1.0"
SciMLBase = "2.51.4"
SciMLJacobianOperators = "0.1"
SciMLOperators = "0.3"
SciMLStructures = "1.3"
SparseArrays = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Enzyme = "0.12, 0.13"
Flux = "0.14"
ForwardDiff = "0.10"
IterTools = "1"
Lux = "0.5.7, 1"
Lux = "1"
LuxCUDA = "0.3"
Optimization = "3.9, 4"
OptimizationOptimJL = "0.2, 0.3, 0.4"
Expand Down
4 changes: 2 additions & 2 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using RandomNumbers: Xorshifts
using RecursiveArrayTools: RecursiveArrayTools, AbstractDiffEqArray,
AbstractVectorOfArray, ArrayPartition, DiffEqArray,
VectorOfArray
# using SciMLJacobianOperators: VecJacOperator # TODO: Replace uses of VecJac
using SciMLJacobianOperators: VecJacOperator, StatefulJacobianOperator
using SciMLStructures: SciMLStructures, canonicalize, Tunable, isscimlstructure
using SymbolicIndexingInterface: SymbolicIndexingInterface, current_time, getu,
parameter_values, state_values
Expand All @@ -32,7 +32,7 @@ using SciMLBase: SciMLBase, AbstractOverloadingSensitivityAlgorithm,
AbstractShadowingSensitivityAlgorithm, AbstractTimeseriesSolution,
AbstractNonlinearProblem, AbstractSensitivityAlgorithm,
AbstractDiffEqFunction, AbstractODEFunction, unwrapped_f, CallbackSet,
ContinuousCallback, DESolution,
ContinuousCallback, DESolution, NonlinearFunction, NonlinearProblem,
DiscreteCallback, LinearProblem, ODEFunction, ODEProblem,
RODEFunction, RODEProblem, ReturnCode, SDEFunction,
SDEProblem, VectorContinuousCallback, deleteat!,
Expand Down
4 changes: 2 additions & 2 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1311,8 +1311,8 @@ struct ForwardDiffOverAdjoint{A} <:
adjalg::A
end

function get_autodiff_from_vjp(vjp::ReverseDiffVJP{compile}) where {compile}
AutoReverseDiff(; compile)
function get_autodiff_from_vjp(::ReverseDiffVJP{compile}) where {compile}
return AutoReverseDiff(; compile)
end
get_autodiff_from_vjp(::ZygoteVJP) = AutoZygote()
get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme()
Expand Down
19 changes: 14 additions & 5 deletions src/steadystate_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,21 @@ end
end

if !needs_jac
# Current SciMLJacobianOperators requires specifying the problem as a NonlinearProblem
usize = size(y)
__f = y -> vec(f(reshape(y, usize), p, nothing))
operator = VecJac(__f, vec(y);
autodiff = get_autodiff_from_vjp(sensealg.autojacvec))
linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ))
solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ)
if SciMLBase.isinplace(f)
nlfunc = NonlinearFunction{true}((du, u, p) -> unwrapped_f(f)(
reshape(u, usize), reshape(u, usize), p, nothing))
else
nlfunc = NonlinearFunction{false}((u, p) -> unwrapped_f(f)(
reshape(u, usize), p, nothing))
end
nlprob = NonlinearProblem(nlfunc, vec(λ), p)
operator = VecJacOperator(
nlprob, vec(y), (λ); autodiff = get_autodiff_from_vjp(sensealg.autojacvec))
soperator = StatefulJacobianOperator(operator, vec(λ), p)
linear_problem = LinearProblem(soperator, vec(dgdu_val); u0 = vec(λ))
solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...)
else
if linsolve === nothing && isempty(sensealg.linsolve_kwargs)
# For the default case use `\` to avoid any form of unnecessary cache allocation
Expand Down
6 changes: 3 additions & 3 deletions test/gpu/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"

[compat]
CUDA = "3.12, 4, 5"
DiffEqCallbacks = "2.24, 3, 4"
DiffEqFlux = "3, 4"
CUDA = "5"
DiffEqCallbacks = "4"
DiffEqFlux = "4"
LuxCUDA = "0.3.1"

0 comments on commit 6f80c63

Please sign in to comment.