Skip to content

Commit

Permalink
Merge pull request #2207 from SciML/myb/dde
Browse files Browse the repository at this point in the history
Add DDE support in `System`
  • Loading branch information
ChrisRackauckas authored Jul 12, 2023
2 parents 47d8f05 + d76d73f commit e63aad0
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 13 deletions.
163 changes: 151 additions & 12 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,14 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
implicit_dae = false,
ddvs = implicit_dae ? map(Differential(get_iv(sys)), dvs) :
nothing,
isdde = false,
has_difference = false,
kwargs...)
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
if isdde
eqs = delay_to_function(sys)
else
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
end
if !implicit_dae
check_operator_variables(eqs, Differential)
check_lhs(eqs, Differential, Set(dvs))
Expand All @@ -136,15 +141,59 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
p = map(x -> time_varying_as_func(value(x), sys), ps)
t = get_iv(sys)

pre, sol_states = get_substitutions_and_solved_states(sys,
no_postprocess = has_difference)
if isdde
build_function(rhss, u, DDE_HISTORY_FUN, p, t; kwargs...)
else
pre, sol_states = get_substitutions_and_solved_states(sys,
no_postprocess = has_difference)

if implicit_dae
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
if implicit_dae
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre,
states = sol_states,
kwargs...)
else
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
end
end
end

function isdelay(var, iv)
iv === nothing && return false
isvariable(var) || return false
if istree(var) && !ModelingToolkit.isoperator(var, Symbolics.Operator)
args = arguments(var)
length(args) == 1 || return false
isequal(args[1], iv) || return true
end
return false
end
const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___)
function delay_to_function(sys::AbstractODESystem)
delay_to_function(full_equations(sys),
get_iv(sys),
Dict{Any, Int}(operation(s) => i for (i, s) in enumerate(states(sys))),
parameters(sys),
DDE_HISTORY_FUN)
end
function delay_to_function(eqs::Vector{<:Equation}, iv, sts, ps, h)
delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,))
end
function delay_to_function(eq::Equation, iv, sts, ps, h)
delay_to_function(eq.lhs, iv, sts, ps, h) ~ delay_to_function(eq.rhs, iv, sts, ps, h)
end
function delay_to_function(expr, iv, sts, ps, h)
if isdelay(expr, iv)
v = operation(expr)
time = arguments(expr)[1]
idx = sts[v]
return term(getindex, h(Sym{Any}(:ˍ₋arg3), time), idx, type = Real) # BIG BIG HACK
elseif istree(expr)
return similarterm(expr,
operation(expr),
map(x -> delay_to_function(x, iv, sts, ps, h), arguments(expr)))
else
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
return expr
end
end

Expand Down Expand Up @@ -485,6 +534,30 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
observed = observedfun)
end

function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
DDEFunction{true}(sys, args...; kwargs...)
end

function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
ps = parameters(sys), u0 = nothing;
eval_module = @__MODULE__,
checkbounds = false,
kwargs...) where {iip}
f_gen = generate_function(sys, dvs, ps; isdde = true,
expression = Val{true},
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
f(u, p, h, t) = f_oop(u, p, h, t)
f(du, u, p, h, t) = f_iip(du, u, p, h, t)

DDEFunction{iip}(f,
sys = sys,
syms = Symbol.(dvs),
indepsym = Symbol(get_iv(sys)),
paramsyms = Symbol.(ps))
end

"""
```julia
ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
Expand Down Expand Up @@ -577,9 +650,14 @@ end
"""
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union)
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
"""
function get_u0_p(sys, u0map, parammap; use_union = false, tofloat = !use_union)
function get_u0_p(sys,
u0map,
parammap;
use_union = false,
tofloat = !use_union,
symbolic_u0 = false)
eqs = equations(sys)
dvs = states(sys)
ps = parameters(sys)
Expand All @@ -588,7 +666,11 @@ function get_u0_p(sys, u0map, parammap; use_union = false, tofloat = !use_union)
defs = mergedefaults(defs, parammap, ps)
defs = mergedefaults(defs, u0map, dvs)

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
if symbolic_u0
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
else
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
end
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
p = p === nothing ? SciMLBase.NullParameters() : p
u0, p, defs
Expand All @@ -604,13 +686,14 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
eval_expression = true,
use_union = false,
tofloat = !use_union,
symbolic_u0 = false,
kwargs...)
eqs = equations(sys)
dvs = states(sys)
ps = parameters(sys)
iv = get_iv(sys)

u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0)

if implicit_dae && du0map !== nothing
ddvs = map(Differential(iv), dvs)
Expand Down Expand Up @@ -802,6 +885,62 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
end
end

function generate_history(sys::AbstractODESystem, u0; kwargs...)
build_function(u0, parameters(sys), get_iv(sys); expression = Val{false}, kwargs...)
end

function DiffEqBase.DDEProblem(sys::AbstractODESystem, args...; kwargs...)
DDEProblem{true}(sys, args...; kwargs...)
end
function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
callback = nothing,
check_length = true,
kwargs...) where {iip}
has_difference = any(isdifferenceeq, equations(sys))
f, u0, p = process_DEProblem(DDEFunction{iip}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
has_difference = has_difference,
symbolic_u0 = true,
check_length, kwargs...)
h_oop, h_iip = generate_history(sys, u0)
h = h_oop
u0 = h(p, tspan[1])
cbs = process_events(sys; callback, has_difference, kwargs...)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
if clock isa Clock
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
else
error("$clock is not a supported clock type.")
end
end
if cbs === nothing
if length(discrete_cbs) == 1
cbs = only(discrete_cbs)
else
cbs = CallbackSet(discrete_cbs...)
end
else
cbs = CallbackSet(cbs, discrete_cbs)
end
else
svs = nothing
end
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
end

"""
```julia
ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,
Expand Down
7 changes: 6 additions & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,10 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."

iv′ = value(scalarize(iv))
dvs′ = value.(scalarize(dvs))
ps′ = value.(scalarize(ps))
ctrl′ = value.(scalarize(controls))
dvs′ = value.(scalarize(dvs))
dvs′ = filter(x -> !isdelay(x, iv), dvs′)

if !(isempty(default_u0) && isempty(default_p))
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
Expand Down Expand Up @@ -258,6 +259,10 @@ function ODESystem(eqs, iv = nothing; kwargs...)
push!(algeeq, eq)
end
end
for v in allstates
isdelay(v, iv) || continue
collect_vars!(allstates, ps, arguments(v)[1], iv)
end
algevars = setdiff(allstates, diffvars)
# the orders here are very important!
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
Expand Down
2 changes: 2 additions & 0 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ end
function TearingState(sys; quick_cancel = false, check = true)
sys = flatten(sys)
ivs = independent_variables(sys)
iv = length(ivs) == 1 ? ivs[1] : nothing
eqs = copy(equations(sys))
neqs = length(eqs)
dervaridxs = OrderedSet{Int}()
Expand Down Expand Up @@ -287,6 +288,7 @@ function TearingState(sys; quick_cancel = false, check = true)
isalgeq = true
statevars = []
for var in vars
ModelingToolkit.isdelay(var, iv) && continue
set_incidence = true
@label ANOTHER_VAR
_var, _ = var_from_nested_derivative(var)
Expand Down
51 changes: 51 additions & 0 deletions test/dde.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using ModelingToolkit, DelayDiffEq, Test
p0 = 0.2;
q0 = 0.3;
v0 = 1;
d0 = 5;
p1 = 0.2;
q1 = 0.3;
v1 = 1;
d1 = 1;
d2 = 1;
beta0 = 1;
beta1 = 1;
tau = 1;
function bc_model(du, u, h, p, t)
du[1] = (v0 / (1 + beta0 * (h(p, t - tau)[3]^2))) * (p0 - q0) * u[1] - d0 * u[1]
du[2] = (v0 / (1 + beta0 * (h(p, t - tau)[3]^2))) * (1 - p0 + q0) * u[1] +
(v1 / (1 + beta1 * (h(p, t - tau)[3]^2))) * (p1 - q1) * u[2] - d1 * u[2]
du[3] = (v1 / (1 + beta1 * (h(p, t - tau)[3]^2))) * (1 - p1 + q1) * u[2] - d2 * u[3]
end
lags = [tau]
h(p, t) = ones(3)
h2(p, t) = ones(3) .- t * q1 * 10
tspan = (0.0, 10.0)
u0 = [1.0, 1.0, 1.0]
prob = DDEProblem(bc_model, u0, h, tspan, constant_lags = lags)
alg = MethodOfSteps(Vern9())
sol = solve(prob, alg, reltol = 1e-7, abstol = 1e-10)
prob2 = DDEProblem(bc_model, u0, h2, tspan, constant_lags = lags)
sol2 = solve(prob2, alg, reltol = 1e-7, abstol = 1e-10)

@parameters p0=0.2 p1=0.2 q0=0.3 q1=0.3 v0=1 v1=1 d0=5 d1=1 d2=1 beta0=1 beta1=1
@variables t x₀(t) x₁(t) x₂(..)
tau = 1
D = Differential(t)
eqs = [D(x₀) ~ (v0 / (1 + beta0 * (x₂(t - tau)^2))) * (p0 - q0) * x₀ - d0 * x₀
D(x₁) ~ (v0 / (1 + beta0 * (x₂(t - tau)^2))) * (1 - p0 + q0) * x₀ +
(v1 / (1 + beta1 * (x₂(t - tau)^2))) * (p1 - q1) * x₁ - d1 * x₁
D(x₂(t)) ~ (v1 / (1 + beta1 * (x₂(t - tau)^2))) * (1 - p1 + q1) * x₁ - d2 * x₂(t)]
@named sys = System(eqs)
prob = DDEProblem(sys,
[x₀ => 1.0, x₁ => 1.0, x₂(t) => 1.0],
tspan,
constant_lags = [tau])
sol_mtk = solve(prob, alg, reltol = 1e-7, abstol = 1e-10)
@test sol_mtk.u[end] sol.u[end]
prob2 = DDEProblem(sys,
[x₀ => 1.0 - t * q1 * 10, x₁ => 1.0 - t * q1 * 10, x₂(t) => 1.0 - t * q1 * 10],
tspan,
constant_lags = [tau])
sol2_mtk = solve(prob2, alg, reltol = 1e-7, abstol = 1e-10)
@test sol2_mtk.u[end] sol2.u[end]

0 comments on commit e63aad0

Please sign in to comment.