From 008071d59a855ceab9a29241342fc67e72dd514d Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 12 Jul 2023 19:45:23 -0400 Subject: [PATCH 1/2] SDDE support --- src/systems/diffeqs/abstractodesystem.jl | 106 +++++++++++++++++++++-- src/systems/diffeqs/sdesystem.jl | 17 ++-- test/dde.jl | 32 +++++++ 3 files changed, 145 insertions(+), 10 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 392aa0e1c7..e436dde725 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -169,14 +169,14 @@ function isdelay(var, iv) return false end const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___) -function delay_to_function(sys::AbstractODESystem) - delay_to_function(full_equations(sys), +function delay_to_function(sys::AbstractODESystem, eqs = full_equations(sys)) + delay_to_function(eqs, get_iv(sys), Dict{Any, Int}(operation(s) => i for (i, s) in enumerate(states(sys))), parameters(sys), DDE_HISTORY_FUN) end -function delay_to_function(eqs::Vector{<:Equation}, iv, sts, ps, h) +function delay_to_function(eqs::Vector, iv, sts, ps, h) delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,)) end function delay_to_function(eq::Equation, iv, sts, ps, h) @@ -548,8 +548,8 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), expression_module = eval_module, checkbounds = checkbounds, kwargs...) f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) - f(u, p, h, t) = f_oop(u, p, h, t) - f(du, u, p, h, t) = f_iip(du, u, p, h, t) + f(u, h, p, t) = f_oop(u, h, p, t) + f(du, u, h, p, t) = f_iip(du, u, h, p, t) DDEFunction{iip}(f, sys = sys, @@ -558,6 +558,36 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), paramsyms = Symbol.(ps)) end +function DiffEqBase.SDDEFunction(sys::AbstractODESystem, args...; kwargs...) + SDDEFunction{true}(sys, args...; kwargs...) +end + +function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), + ps = parameters(sys), u0 = nothing; + eval_module = @__MODULE__, + checkbounds = false, + kwargs...) where {iip} + f_gen = generate_function(sys, dvs, ps; isdde = true, + expression = Val{true}, + expression_module = eval_module, checkbounds = checkbounds, + kwargs...) + f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) + g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true}, + isdde = true, kwargs...) + @show g_gen[2] + g_oop, g_iip = (drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen) + f(u, h, p, t) = f_oop(u, h, p, t) + f(du, u, h, p, t) = f_iip(du, u, h, p, t) + g(u, h, p, t) = g_oop(u, h, p, t) + g(du, u, h, p, t) = g_iip(du, u, h, p, t) + + SDDEFunction{iip}(f, g, + sys = sys, + syms = Symbol.(dvs), + indepsym = Symbol(get_iv(sys)), + paramsyms = Symbol.(ps)) +end + """ ```julia ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys), @@ -941,6 +971,72 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...) end +function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...) + SDDEProblem{true}(sys, args...; kwargs...) +end +function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], + tspan = get_tspan(sys), + parammap = DiffEqBase.NullParameters(); + callback = nothing, + check_length = true, + sparsenoise = nothing, + kwargs...) where {iip} + has_difference = any(isdifferenceeq, equations(sys)) + f, u0, p = process_DEProblem(SDDEFunction{iip}, sys, u0map, parammap; + t = tspan !== nothing ? tspan[1] : tspan, + has_difference = has_difference, + symbolic_u0 = true, + check_length, kwargs...) + h_oop, h_iip = generate_history(sys, u0) + h(out, p, t) = h_iip(out, p, t) + h(p, t) = h_oop(p, t) + u0 = h(p, tspan[1]) + cbs = process_events(sys; callback, has_difference, kwargs...) + if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing + affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...) + discrete_cbs = map(affects, clocks, svs) do affect, clock, sv + if clock isa Clock + PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt) + else + error("$clock is not a supported clock type.") + end + end + if cbs === nothing + if length(discrete_cbs) == 1 + cbs = only(discrete_cbs) + else + cbs = CallbackSet(discrete_cbs...) + end + else + cbs = CallbackSet(cbs, discrete_cbs) + end + else + svs = nothing + end + kwargs = filter_kwargs(kwargs) + + kwargs1 = (;) + if cbs !== nothing + kwargs1 = merge(kwargs1, (callback = cbs,)) + end + if svs !== nothing + kwargs1 = merge(kwargs1, (disc_saved_values = svs,)) + end + + noiseeqs = get_noiseeqs(sys) + sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false)) + if noiseeqs isa AbstractVector + noise_rate_prototype = nothing + elseif sparsenoise + I, J, V = findnz(SparseArrays.sparse(noiseeqs)) + noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0))) + else + noise_rate_prototype = zeros(eltype(u0), size(noiseeqs)) + end + SDDEProblem{iip}(f, f.g, u0, h, tspan, p; noise_rate_prototype = + noise_rate_prototype, kwargs1..., kwargs...) +end + """ ```julia ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan, diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index b3cdef8d23..b1f9b945b2 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -210,11 +210,18 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem) end function generate_diffusion_function(sys::SDESystem, dvs = states(sys), - ps = parameters(sys); kwargs...) - return build_function(get_noiseeqs(sys), - map(x -> time_varying_as_func(value(x), sys), dvs), - map(x -> time_varying_as_func(value(x), sys), ps), - get_iv(sys); kwargs...) + ps = parameters(sys); isdde = false, kwargs...) + eqs = get_noiseeqs(sys) + if isdde + eqs = delay_to_function(sys, eqs) + end + u = map(x -> time_varying_as_func(value(x), sys), dvs) + p = map(x -> time_varying_as_func(value(x), sys), ps) + if isdde + return build_function(eqs, u, DDE_HISTORY_FUN, p, get_iv(sys); kwargs...) + else + return build_function(eqs, u, p, get_iv(sys); kwargs...) + end end """ diff --git a/test/dde.jl b/test/dde.jl index aad39fe27f..b62c6eb214 100644 --- a/test/dde.jl +++ b/test/dde.jl @@ -49,3 +49,35 @@ prob2 = DDEProblem(sys, constant_lags = [tau]) sol2_mtk = solve(prob2, alg, reltol = 1e-7, abstol = 1e-10) @test sol2_mtk.u[end] ≈ sol2.u[end] + +using StochasticDelayDiffEq +function hayes_modelf(du, u, h, p, t) + τ, a, b, c, α, β, γ = p + du .= a .* u .+ b .* h(p, t - τ) .+ c +end +function hayes_modelg(du, u, h, p, t) + τ, a, b, c, α, β, γ = p + du .= α .* u .+ γ +end +h(p, t) = (ones(1) .+ t); +tspan = (0.0, 10.0) + +pmul = [1.0, + -4.0, -2.0, 10.0, + -1.3, -1.2, 1.1] + +prob = SDDEProblem(hayes_modelf, hayes_modelg, [1.0], h, tspan, pmul; + constant_lags = (pmul[1],)); +sol = solve(prob, RKMil()) + +@variables t x(..) +@parameters a=-4.0 b=-2.0 c=10.0 α=-1.3 β=-1.2 γ=1.1 +D = Differential(t) +@brownian η +τ = 1.0 +eqs = [D(x(t)) ~ a * x(t) + b * x(t - τ) + c + (α * x(t) + γ) * η] +@named sys = System(eqs) +sys = structural_simplify(sys) +prob_mtk = SDDEProblem(sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,)); +sol_mtk = solve(prob_mtk, RKMil()) +@test sol.u[end] ≈ sol_mtk.u[end] From 46c8a2f2c6c3e9f29cec39b53a7186bb9d5201ba Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 12 Jul 2023 19:51:14 -0400 Subject: [PATCH 2/2] Working tests --- Project.toml | 3 ++- src/systems/diffeqs/abstractodesystem.jl | 1 - test/dde.jl | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index bd22241d02..d0e08ca890 100644 --- a/Project.toml +++ b/Project.toml @@ -114,9 +114,10 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" +StochasticDelayDiffEq = "29a0d76e-afc8-11e9-03a4-eda52ae4b960" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"] +test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq"] diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index e436dde725..30b237c807 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -574,7 +574,6 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true}, isdde = true, kwargs...) - @show g_gen[2] g_oop, g_iip = (drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen) f(u, h, p, t) = f_oop(u, h, p, t) f(du, u, h, p, t) = f_iip(du, u, h, p, t) diff --git a/test/dde.jl b/test/dde.jl index b62c6eb214..cd111c21af 100644 --- a/test/dde.jl +++ b/test/dde.jl @@ -78,6 +78,7 @@ D = Differential(t) eqs = [D(x(t)) ~ a * x(t) + b * x(t - τ) + c + (α * x(t) + γ) * η] @named sys = System(eqs) sys = structural_simplify(sys) +@test equations(sys) == [D(x(t)) ~ a * x(t) + b * x(t - τ) + c] +@test isequal(ModelingToolkit.get_noiseeqs(sys), [α * x(t) + γ;;]) prob_mtk = SDDEProblem(sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,)); -sol_mtk = solve(prob_mtk, RKMil()) -@test sol.u[end] ≈ sol_mtk.u[end] +@test_nowarn sol_mtk = solve(prob_mtk, RKMil())