Skip to content

Commit

Permalink
Merge pull request #2208 from SciML/myb/sdde
Browse files Browse the repository at this point in the history
SDDE support
  • Loading branch information
YingboMa authored Jul 13, 2023
2 parents e63aad0 + 46c8a2f commit 24d0d7c
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 11 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
StochasticDelayDiffEq = "29a0d76e-afc8-11e9-03a4-eda52ae4b960"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq"]
105 changes: 100 additions & 5 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,14 @@ function isdelay(var, iv)
return false
end
const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___)
function delay_to_function(sys::AbstractODESystem)
delay_to_function(full_equations(sys),
function delay_to_function(sys::AbstractODESystem, eqs = full_equations(sys))
delay_to_function(eqs,
get_iv(sys),
Dict{Any, Int}(operation(s) => i for (i, s) in enumerate(states(sys))),
parameters(sys),
DDE_HISTORY_FUN)
end
function delay_to_function(eqs::Vector{<:Equation}, iv, sts, ps, h)
function delay_to_function(eqs::Vector, iv, sts, ps, h)
delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,))
end
function delay_to_function(eq::Equation, iv, sts, ps, h)
Expand Down Expand Up @@ -548,8 +548,8 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
f(u, p, h, t) = f_oop(u, p, h, t)
f(du, u, p, h, t) = f_iip(du, u, p, h, t)
f(u, h, p, t) = f_oop(u, h, p, t)
f(du, u, h, p, t) = f_iip(du, u, h, p, t)

DDEFunction{iip}(f,
sys = sys,
Expand All @@ -558,6 +558,35 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
paramsyms = Symbol.(ps))
end

function DiffEqBase.SDDEFunction(sys::AbstractODESystem, args...; kwargs...)
SDDEFunction{true}(sys, args...; kwargs...)
end

function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
ps = parameters(sys), u0 = nothing;
eval_module = @__MODULE__,
checkbounds = false,
kwargs...) where {iip}
f_gen = generate_function(sys, dvs, ps; isdde = true,
expression = Val{true},
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
isdde = true, kwargs...)
g_oop, g_iip = (drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen)
f(u, h, p, t) = f_oop(u, h, p, t)
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
g(u, h, p, t) = g_oop(u, h, p, t)
g(du, u, h, p, t) = g_iip(du, u, h, p, t)

SDDEFunction{iip}(f, g,
sys = sys,
syms = Symbol.(dvs),
indepsym = Symbol(get_iv(sys)),
paramsyms = Symbol.(ps))
end

"""
```julia
ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
Expand Down Expand Up @@ -941,6 +970,72 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
end

function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
SDDEProblem{true}(sys, args...; kwargs...)
end
function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
callback = nothing,
check_length = true,
sparsenoise = nothing,
kwargs...) where {iip}
has_difference = any(isdifferenceeq, equations(sys))
f, u0, p = process_DEProblem(SDDEFunction{iip}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
has_difference = has_difference,
symbolic_u0 = true,
check_length, kwargs...)
h_oop, h_iip = generate_history(sys, u0)
h(out, p, t) = h_iip(out, p, t)
h(p, t) = h_oop(p, t)
u0 = h(p, tspan[1])
cbs = process_events(sys; callback, has_difference, kwargs...)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
if clock isa Clock
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
else
error("$clock is not a supported clock type.")
end
end
if cbs === nothing
if length(discrete_cbs) == 1
cbs = only(discrete_cbs)
else
cbs = CallbackSet(discrete_cbs...)
end
else
cbs = CallbackSet(cbs, discrete_cbs)
end
else
svs = nothing
end
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end

noiseeqs = get_noiseeqs(sys)
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
if noiseeqs isa AbstractVector
noise_rate_prototype = nothing
elseif sparsenoise
I, J, V = findnz(SparseArrays.sparse(noiseeqs))
noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0)))
else
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
end
SDDEProblem{iip}(f, f.g, u0, h, tspan, p; noise_rate_prototype =
noise_rate_prototype, kwargs1..., kwargs...)
end

"""
```julia
ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,
Expand Down
17 changes: 12 additions & 5 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,18 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
end

function generate_diffusion_function(sys::SDESystem, dvs = states(sys),
ps = parameters(sys); kwargs...)
return build_function(get_noiseeqs(sys),
map(x -> time_varying_as_func(value(x), sys), dvs),
map(x -> time_varying_as_func(value(x), sys), ps),
get_iv(sys); kwargs...)
ps = parameters(sys); isdde = false, kwargs...)
eqs = get_noiseeqs(sys)
if isdde
eqs = delay_to_function(sys, eqs)
end
u = map(x -> time_varying_as_func(value(x), sys), dvs)
p = map(x -> time_varying_as_func(value(x), sys), ps)
if isdde
return build_function(eqs, u, DDE_HISTORY_FUN, p, get_iv(sys); kwargs...)
else
return build_function(eqs, u, p, get_iv(sys); kwargs...)
end
end

"""
Expand Down
33 changes: 33 additions & 0 deletions test/dde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,36 @@ prob2 = DDEProblem(sys,
constant_lags = [tau])
sol2_mtk = solve(prob2, alg, reltol = 1e-7, abstol = 1e-10)
@test sol2_mtk.u[end] sol2.u[end]

using StochasticDelayDiffEq
function hayes_modelf(du, u, h, p, t)
τ, a, b, c, α, β, γ = p
du .= a .* u .+ b .* h(p, t - τ) .+ c
end
function hayes_modelg(du, u, h, p, t)
τ, a, b, c, α, β, γ = p
du .= α .* u .+ γ
end
h(p, t) = (ones(1) .+ t);
tspan = (0.0, 10.0)

pmul = [1.0,
-4.0, -2.0, 10.0,
-1.3, -1.2, 1.1]

prob = SDDEProblem(hayes_modelf, hayes_modelg, [1.0], h, tspan, pmul;
constant_lags = (pmul[1],));
sol = solve(prob, RKMil())

@variables t x(..)
@parameters a=-4.0 b=-2.0 c=10.0 α=-1.3 β=-1.2 γ=1.1
D = Differential(t)
@brownian η
τ = 1.0
eqs = [D(x(t)) ~ a * x(t) + b * x(t - τ) + c +* x(t) + γ) * η]
@named sys = System(eqs)
sys = structural_simplify(sys)
@test equations(sys) == [D(x(t)) ~ a * x(t) + b * x(t - τ) + c]
@test isequal(ModelingToolkit.get_noiseeqs(sys), [α * x(t) + γ;;])
prob_mtk = SDDEProblem(sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,));
@test_nowarn sol_mtk = solve(prob_mtk, RKMil())

0 comments on commit 24d0d7c

Please sign in to comment.