Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle inhomogeneous parameters using a Tuple of Vectors #2231

Merged
merged 19 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ $(DocStringExtensions.README)
"""
module ModelingToolkit
using PrecompileTools, Reexport
@recompile_invalidations begin
@recompile_invalidations begin
using DocStringExtensions
using Compat
using AbstractTrees
Expand Down
46 changes: 46 additions & 0 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,49 @@
xs,
toparam) |> esc
end

function find_types(array)
by = let set = Dict{Any, Int}(), counter = Ref(0)
x -> begin

Check warning on line 67 in src/parameters.jl

View check run for this annotation

Codecov / codecov/patch

src/parameters.jl#L65-L67

Added lines #L65 - L67 were not covered by tests
# t = typeof(x)

get!(set, typeof(x)) do

Check warning on line 70 in src/parameters.jl

View check run for this annotation

Codecov / codecov/patch

src/parameters.jl#L70

Added line #L70 was not covered by tests
# if t == Float64
# 1
# else
counter[] += 1

Check warning on line 74 in src/parameters.jl

View check run for this annotation

Codecov / codecov/patch

src/parameters.jl#L74

Added line #L74 was not covered by tests
# end
end
end
end
return by.(array)

Check warning on line 79 in src/parameters.jl

View check run for this annotation

Codecov / codecov/patch

src/parameters.jl#L79

Added line #L79 was not covered by tests
end

function split_parameters_by_type(ps)
if ps === SciMLBase.NullParameters()
return Float64[], [] #use Float64 to avoid Any type warning
else
by = let set = Dict{Any, Int}(), counter = Ref(0)
x -> begin
get!(set, typeof(x)) do
counter[] += 1
end
end
end
idxs = by.(ps)
split_idxs = [Int[]]
for (i, idx) in enumerate(idxs)
if idx > length(split_idxs)
push!(split_idxs, Int[])
end
push!(split_idxs[idx], i)
end
tighten_types = x -> identity.(x)
split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs))
if length(split_ps) == 1 #Tuple not needed, only 1 type
return split_ps[1], split_idxs
else
return (split_ps...,), split_idxs
end
end
end
6 changes: 3 additions & 3 deletions src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@ function ODAEProblem{iip}(sys,
tspan,
parammap = DiffEqBase.NullParameters();
callback = nothing,
use_union = false,
use_union = true,
tofloat = true,
check = true,
kwargs...) where {iip}
eqs = equations(sys)
Expand All @@ -540,8 +541,7 @@ function ODAEProblem{iip}(sys,
defs = ModelingToolkit.mergedefaults(defs, parammap, ps)
defs = ModelingToolkit.mergedefaults(defs, u0map, dvs)
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat = !use_union,
use_union)
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)

has_difference = any(isdifferenceeq, eqs)
cbs = process_events(sys; callback, has_difference, kwargs...)
Expand Down
93 changes: 75 additions & 18 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,14 @@
states = sol_states,
kwargs...)
else
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
if p isa Tuple
build_function(rhss, u, p..., t; postprocess_fbody = pre,
states = sol_states,
kwargs...)
else
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
end
end
end
end
Expand Down Expand Up @@ -332,8 +338,15 @@
f_oop, f_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
f_gen
f(u, p, t) = f_oop(u, p, t)
f(du, u, p, t) = f_iip(du, u, p, t)
if p isa Tuple
g(u, p, t) = f_oop(u, p..., t)
g(du, u, p, t) = f_iip(du, u, p..., t)
f = g
else
k(u, p, t) = f_oop(u, p, t)
k(du, u, p, t) = f_iip(du, u, p, t)
f = k
end

if specialize === SciMLBase.FunctionWrapperSpecialize && iip
if u0 === nothing || p === nothing || t === nothing
Expand Down Expand Up @@ -384,32 +397,64 @@

obs = observed(sys)
observedfun = if steady_state
let sys = sys, dict = Dict()
let sys = sys, dict = Dict(), ps = ps
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar)
end
if args === ()
let obs = obs
(u, p, t = Inf) -> obs(u, p, t)
(u, p, t = Inf) -> if ps isa Tuple
obs(u, p..., t)

Check warning on line 408 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L407-L408

Added lines #L407 - L408 were not covered by tests
else
obs(u, p, t)

Check warning on line 410 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L410

Added line #L410 was not covered by tests
end
end
else
length(args) == 2 ? obs(args..., Inf) : obs(args...)
if ps isa Tuple
if length(args) == 2
u, p = args
obs(u, p..., Inf)

Check warning on line 417 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L415-L417

Added lines #L415 - L417 were not covered by tests
else
u, p, t = args
obs(u, p..., t)

Check warning on line 420 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L419-L420

Added lines #L419 - L420 were not covered by tests
end
else
if length(args) == 2
u, p = args
obs(u, p, Inf)
else
u, p, t = args
obs(u, p, t)

Check warning on line 428 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L427-L428

Added lines #L427 - L428 were not covered by tests
end
end
end
end
end
else
let sys = sys, dict = Dict()
let sys = sys, dict = Dict(), ps = ps
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
build_explicit_observed_function(sys,
obsvar;
checkbounds = checkbounds,
ps)
end
if args === ()
let obs = obs
(u, p, t) -> obs(u, p, t)
(u, p, t) -> if ps isa Tuple
obs(u, p..., t)

Check warning on line 446 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L446

Added line #L446 was not covered by tests
else
obs(u, p, t)
end
end
else
obs(args...)
if ps isa Tuple # split parameters
u, p, t = args
obs(u, p..., t)
else
obs(args...)
end
end
end
end
Expand Down Expand Up @@ -677,15 +722,15 @@
end

"""
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union)
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)

Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
"""
function get_u0_p(sys,
u0map,
parammap;
use_union = false,
tofloat = !use_union,
use_union = true,
tofloat = true,
symbolic_u0 = false)
dvs = states(sys)
ps = parameters(sys)
Expand All @@ -712,16 +757,27 @@
simplify = false,
linenumbers = true, parallel = SerialForm(),
eval_expression = true,
use_union = false,
tofloat = !use_union,
use_union = true,
tofloat = true,
symbolic_u0 = false,
kwargs...)
eqs = equations(sys)
dvs = states(sys)
ps = parameters(sys)
iv = get_iv(sys)

u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0)
u0, p, defs = get_u0_p(sys,
u0map,
parammap;
tofloat,
use_union,
symbolic_u0)

p, split_idxs = split_parameters_by_type(p)
if p isa Tuple
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple
end

if implicit_dae && du0map !== nothing
ddvs = map(Differential(iv), dvs)
Expand All @@ -738,7 +794,8 @@
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
checkbounds = checkbounds, p = p,
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
sparse = sparse, eval_expression = eval_expression, kwargs...)
sparse = sparse, eval_expression = eval_expression,
kwargs...)
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
end

Expand Down
14 changes: 9 additions & 5 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ function build_explicit_observed_function(sys, ts;
output_type = Array,
checkbounds = true,
drop_expr = drop_expr,
ps = parameters(sys),
throw = true)
if (isscalar = !(ts isa AbstractVector))
ts = [ts]
Expand Down Expand Up @@ -385,17 +386,20 @@ function build_explicit_observed_function(sys, ts;
push!(obsexprs, lhs ← rhs)
end

pars = parameters(sys)
if inputs !== nothing
pars = setdiff(pars, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
end
if ps isa Tuple
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
else
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)
end
ps = DestructuredArgs(pars, inbounds = !checkbounds)
dvs = DestructuredArgs(states(sys), inbounds = !checkbounds)
if inputs === nothing
args = [dvs, ps, ivs...]
args = [dvs, ps..., ivs...]
else
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
args = [dvs, ipts, ps, ivs...]
args = [dvs, ipts, ps..., ivs...]
end
pre = get_postprocess_fbody(sys)

Expand Down
18 changes: 17 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@

hasdefault(v) = hasmetadata(v, Symbolics.VariableDefaultValue)
getdefault(v) = value(getmetadata(v, Symbolics.VariableDefaultValue))
function getdefaulttype(v)
def = value(getmetadata(unwrap(v), Symbolics.VariableDefaultValue, nothing))
def === nothing ? Float64 : typeof(def)

Check warning on line 224 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L222-L224

Added lines #L222 - L224 were not covered by tests
end
function setdefault(v, val)
val === nothing ? v : setmetadata(v, Symbolics.VariableDefaultValue, value(val))
end
Expand Down Expand Up @@ -642,10 +646,15 @@
throw(ArgumentError("$vars are either missing from the variable map or missing from the system's states/parameters list."))
end

function promote_to_concrete(vs; tofloat = true, use_union = false)
function promote_to_concrete(vs; tofloat = true, use_union = true)
if isempty(vs)
return vs
end
if vs isa Tuple #special rule, if vs is a Tuple, preserve types, container converted to Array
tofloat = false
use_union = true
vs = Any[vs...]
end
T = eltype(vs)
if Base.isconcretetype(T) && (!tofloat || T === float(T)) # nothing to do
vs
Expand All @@ -656,6 +665,7 @@
I = Int8
has_int = false
has_array = false
has_bool = false
array_T = nothing
for v in vs
if v isa AbstractArray
Expand All @@ -668,6 +678,9 @@
has_int = true
I = promote_type(I, E)
end
if E <: Bool
has_bool = true
end
end
if tofloat && !has_array
C = float(C)
Expand All @@ -678,6 +691,9 @@
if has_int
C = Union{C, I}
end
if has_bool
C = Union{C, Bool}

Check warning on line 695 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L695

Added line #L695 was not covered by tests
end
return copyto!(similar(vs, C), vs)
end
convert.(C, vs)
Expand Down
9 changes: 5 additions & 4 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ applicable.
"""
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
toterm = default_toterm, promotetoconcrete = nothing,
tofloat = true, use_union = false)
tofloat = true, use_union = true)
varlist = collect(map(unwrap, varlist))

# Edge cases where one of the arguments is effectively empty.
Expand All @@ -75,9 +75,10 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
end
end

T = typeof(varmap)
# We respect the input type
container_type = T <: Dict ? Array : T
# T = typeof(varmap)
# We respect the input type (feature removed, not needed with Tuple support)
# container_type = T <: Union{Dict,Tuple} ? Array : T
container_type = Array

vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs
varmap = todict(varmap)
Expand Down
20 changes: 15 additions & 5 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -734,18 +734,28 @@ let
u0map = [A => 1.0]
pmap = (k1 => 1.0, k2 => 1)
tspan = (0.0, 1.0)
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false)
@test prob.p == ([1], [1.0]) #Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])

prob = ODEProblem(sys, u0map, tspan, pmap)
@test prob.p === Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])
@test prob.p isa Vector{Float64}

pmap = [k1 => 1, k2 => 1]
tspan = (0.0, 1.0)
prob = ODEProblem(sys, u0map, tspan, pmap)
@test eltype(prob.p) === Float64

pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
tspan = (0.0, 1.0)
prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true)
@test eltype(prob.p) === Union{Float64, Int}
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false)
@test eltype(prob.p) === Int

prob = ODEProblem(sys, u0map, tspan, pmap)
@test prob.p isa Vector{Float64}

# No longer supported, Tuple used instead
# pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
# tspan = (0.0, 1.0)
# prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true)
# @test eltype(prob.p) === Union{Float64, Int}
end

let
Expand Down
Loading
Loading