diff --git a/src/parameters.jl b/src/parameters.jl index 9174ac454f..4339cb7acf 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -61,3 +61,49 @@ macro parameters(xs...) xs, toparam) |> esc end + +function find_types(array) + by = let set = Dict{Any, Int}(), counter = Ref(0) + x -> begin + # t = typeof(x) + + get!(set, typeof(x)) do + # if t == Float64 + # 1 + # else + counter[] += 1 + # end + end + end + end + return by.(array) +end + +function split_parameters_by_type(ps) + if ps === SciMLBase.NullParameters() + return Float64[], [] #use Float64 to avoid Any type warning + else + by = let set = Dict{Any, Int}(), counter = Ref(0) + x -> begin + get!(set, typeof(x)) do + counter[] += 1 + end + end + end + idxs = by.(ps) + split_idxs = [Int[]] + for (i, idx) in enumerate(idxs) + if idx > length(split_idxs) + push!(split_idxs, Int[]) + end + push!(split_idxs[idx], i) + end + tighten_types = x -> identity.(x) + split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs)) + if length(split_ps) == 1 #Tuple not needed, only 1 type + return split_ps[1], split_idxs + else + return (split_ps...,), split_idxs + end + end +end diff --git a/src/structural_transformation/codegen.jl b/src/structural_transformation/codegen.jl index 65939859c9..c6a957a6e5 100644 --- a/src/structural_transformation/codegen.jl +++ b/src/structural_transformation/codegen.jl @@ -528,7 +528,8 @@ function ODAEProblem{iip}(sys, tspan, parammap = DiffEqBase.NullParameters(); callback = nothing, - use_union = false, + use_union = true, + tofloat = true, check = true, kwargs...) where {iip} eqs = equations(sys) @@ -540,8 +541,7 @@ function ODAEProblem{iip}(sys, defs = ModelingToolkit.mergedefaults(defs, parammap, ps) defs = ModelingToolkit.mergedefaults(defs, u0map, dvs) u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true) - p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat = !use_union, - use_union) + p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union) has_difference = any(isdifferenceeq, eqs) cbs = process_events(sys; callback, has_difference, kwargs...) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index ae04b67ce5..ebc612a774 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -152,8 +152,14 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param states = sol_states, kwargs...) else - build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states, - kwargs...) + if p isa Tuple + build_function(rhss, u, p..., t; postprocess_fbody = pre, + states = sol_states, + kwargs...) + else + build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states, + kwargs...) + end end end end @@ -332,8 +338,15 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s f_oop, f_iip = eval_expression ? (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) : f_gen - f(u, p, t) = f_oop(u, p, t) - f(du, u, p, t) = f_iip(du, u, p, t) + if p isa Tuple + g(u, p, t) = f_oop(u, p..., t) + g(du, u, p, t) = f_iip(du, u, p..., t) + f = g + else + k(u, p, t) = f_oop(u, p, t) + k(du, u, p, t) = f_iip(du, u, p, t) + f = k + end if specialize === SciMLBase.FunctionWrapperSpecialize && iip if u0 === nothing || p === nothing || t === nothing @@ -384,32 +397,64 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s obs = observed(sys) observedfun = if steady_state - let sys = sys, dict = Dict() + let sys = sys, dict = Dict(), ps = ps function generated_observed(obsvar, args...) obs = get!(dict, value(obsvar)) do build_explicit_observed_function(sys, obsvar) end if args === () let obs = obs - (u, p, t = Inf) -> obs(u, p, t) + (u, p, t = Inf) -> if ps isa Tuple + obs(u, p..., t) + else + obs(u, p, t) + end end else - length(args) == 2 ? obs(args..., Inf) : obs(args...) + if ps isa Tuple + if length(args) == 2 + u, p = args + obs(u, p..., Inf) + else + u, p, t = args + obs(u, p..., t) + end + else + if length(args) == 2 + u, p = args + obs(u, p, Inf) + else + u, p, t = args + obs(u, p, t) + end + end end end end else - let sys = sys, dict = Dict() + let sys = sys, dict = Dict(), ps = ps function generated_observed(obsvar, args...) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds) + build_explicit_observed_function(sys, + obsvar; + checkbounds = checkbounds, + ps) end if args === () let obs = obs - (u, p, t) -> obs(u, p, t) + (u, p, t) -> if ps isa Tuple + obs(u, p..., t) + else + obs(u, p, t) + end end else - obs(args...) + if ps isa Tuple # split parameters + u, p, t = args + obs(u, p..., t) + else + obs(args...) + end end end end @@ -677,15 +722,15 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys), end """ - u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union) + u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true) Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point. """ function get_u0_p(sys, u0map, parammap; - use_union = false, - tofloat = !use_union, + use_union = true, + tofloat = true, symbolic_u0 = false) dvs = states(sys) ps = parameters(sys) @@ -712,8 +757,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; simplify = false, linenumbers = true, parallel = SerialForm(), eval_expression = true, - use_union = false, - tofloat = !use_union, + use_union = true, + tofloat = true, symbolic_u0 = false, kwargs...) eqs = equations(sys) @@ -721,7 +766,18 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; ps = parameters(sys) iv = get_iv(sys) - u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0) + u0, p, defs = get_u0_p(sys, + u0map, + parammap; + tofloat, + use_union, + symbolic_u0) + + p, split_idxs = split_parameters_by_type(p) + if p isa Tuple + ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) + ps = (ps...,) #if p is Tuple, ps should be Tuple + end if implicit_dae && du0map !== nothing ddvs = map(Differential(iv), dvs) @@ -738,7 +794,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac, checkbounds = checkbounds, p = p, linenumbers = linenumbers, parallel = parallel, simplify = simplify, - sparse = sparse, eval_expression = eval_expression, kwargs...) + sparse = sparse, eval_expression = eval_expression, + kwargs...) implicit_dae ? (f, du0, u0, p) : (f, u0, p) end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index c123f7b15c..dcbecfcfb0 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -314,6 +314,7 @@ function build_explicit_observed_function(sys, ts; output_type = Array, checkbounds = true, drop_expr = drop_expr, + ps = parameters(sys), throw = true) if (isscalar = !(ts isa AbstractVector)) ts = [ts] @@ -385,17 +386,20 @@ function build_explicit_observed_function(sys, ts; push!(obsexprs, lhs ← rhs) end - pars = parameters(sys) if inputs !== nothing - pars = setdiff(pars, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list + ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list + end + if ps isa Tuple + ps = DestructuredArgs.(ps, inbounds = !checkbounds) + else + ps = (DestructuredArgs(ps, inbounds = !checkbounds),) end - ps = DestructuredArgs(pars, inbounds = !checkbounds) dvs = DestructuredArgs(states(sys), inbounds = !checkbounds) if inputs === nothing - args = [dvs, ps, ivs...] + args = [dvs, ps..., ivs...] else ipts = DestructuredArgs(inputs, inbounds = !checkbounds) - args = [dvs, ipts, ps, ivs...] + args = [dvs, ipts, ps..., ivs...] end pre = get_postprocess_fbody(sys) diff --git a/src/utils.jl b/src/utils.jl index 29ba9b19ab..abea65ab21 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -219,6 +219,10 @@ end hasdefault(v) = hasmetadata(v, Symbolics.VariableDefaultValue) getdefault(v) = value(getmetadata(v, Symbolics.VariableDefaultValue)) +function getdefaulttype(v) + def = value(getmetadata(unwrap(v), Symbolics.VariableDefaultValue, nothing)) + def === nothing ? Float64 : typeof(def) +end function setdefault(v, val) val === nothing ? v : setmetadata(v, Symbolics.VariableDefaultValue, value(val)) end @@ -642,10 +646,15 @@ end throw(ArgumentError("$vars are either missing from the variable map or missing from the system's states/parameters list.")) end -function promote_to_concrete(vs; tofloat = true, use_union = false) +function promote_to_concrete(vs; tofloat = true, use_union = true) if isempty(vs) return vs end + if vs isa Tuple #special rule, if vs is a Tuple, preserve types, container converted to Array + tofloat = false + use_union = true + vs = Any[vs...] + end T = eltype(vs) if Base.isconcretetype(T) && (!tofloat || T === float(T)) # nothing to do vs @@ -656,6 +665,7 @@ function promote_to_concrete(vs; tofloat = true, use_union = false) I = Int8 has_int = false has_array = false + has_bool = false array_T = nothing for v in vs if v isa AbstractArray @@ -668,6 +678,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = false) has_int = true I = promote_type(I, E) end + if E <: Bool + has_bool = true + end end if tofloat && !has_array C = float(C) @@ -678,6 +691,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = false) if has_int C = Union{C, I} end + if has_bool + C = Union{C, Bool} + end return copyto!(similar(vs, C), vs) end convert.(C, vs) diff --git a/src/variables.jl b/src/variables.jl index f1f7edb1df..4cffd3fdeb 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -58,7 +58,7 @@ applicable. """ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true, toterm = default_toterm, promotetoconcrete = nothing, - tofloat = true, use_union = false) + tofloat = true, use_union = true) varlist = collect(map(unwrap, varlist)) # Edge cases where one of the arguments is effectively empty. @@ -75,9 +75,10 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true, end end - T = typeof(varmap) - # We respect the input type - container_type = T <: Dict ? Array : T + # T = typeof(varmap) + # We respect the input type (feature removed, not needed with Tuple support) + # container_type = T <: Union{Dict,Tuple} ? Array : T + container_type = Array vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs varmap = todict(varmap) diff --git a/test/odesystem.jl b/test/odesystem.jl index e8b08fe4a2..6d82e97d46 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -734,18 +734,28 @@ let u0map = [A => 1.0] pmap = (k1 => 1.0, k2 => 1) tspan = (0.0, 1.0) + prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false) + @test prob.p == ([1], [1.0]) #Tuple([(Dict(pmap))[k] for k in values(parameters(sys))]) + prob = ODEProblem(sys, u0map, tspan, pmap) - @test prob.p === Tuple([(Dict(pmap))[k] for k in values(parameters(sys))]) + @test prob.p isa Vector{Float64} pmap = [k1 => 1, k2 => 1] tspan = (0.0, 1.0) prob = ODEProblem(sys, u0map, tspan, pmap) @test eltype(prob.p) === Float64 - pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0] - tspan = (0.0, 1.0) - prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true) - @test eltype(prob.p) === Union{Float64, Int} + prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false) + @test eltype(prob.p) === Int + + prob = ODEProblem(sys, u0map, tspan, pmap) + @test prob.p isa Vector{Float64} + + # No longer supported, Tuple used instead + # pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0] + # tspan = (0.0, 1.0) + # prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true) + # @test eltype(prob.p) === Union{Float64, Int} end let diff --git a/test/runtests.jl b/test/runtests.jl index 79321c977f..ab847bb9a6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,7 @@ using SafeTestsets, Test @safetestset "JumpSystem Test" include("jumpsystem.jl") @safetestset "Constraints Test" include("constraints.jl") @safetestset "Reduction Test" include("reduction.jl") +@safetestset "Split Parameters Test" include("split_parameters.jl") @safetestset "ODAEProblem Test" include("odaeproblem.jl") @safetestset "Components Test" include("components.jl") @safetestset "Model Parsing Test" include("model_parsing.jl") diff --git a/test/split_parameters.jl b/test/split_parameters.jl new file mode 100644 index 0000000000..ef8f434dca --- /dev/null +++ b/test/split_parameters.jl @@ -0,0 +1,79 @@ +using ModelingToolkit, Test +using ModelingToolkitStandardLibrary.Blocks +using OrdinaryDiffEq + +# ------------------------ Mixed Single Values and Vector + +dt = 4e-4 +t_end = 10.0 +time = 0:dt:t_end +x = @. time^2 + 1.0 + +@parameters t +D = Differential(t) + +get_value(data, t, dt) = data[round(Int, t / dt + 1)] +@register_symbolic get_value(data, t, dt) + +function Sampled(; name, data = Float64[], dt = 0.0) + pars = @parameters begin + data = data + dt = dt + end + + vars = [] + systems = @named begin + output = RealOutput() + end + + eqs = [ + output.u ~ get_value(data, t, dt), + ] + + return ODESystem(eqs, t, vars, pars; name, systems, + defaults = [output.u => data[1]]) +end + +vars = @variables y(t)=1 dy(t)=0 ddy(t)=0 +@named src = Sampled(; data = Float64[], dt) +@named int = Integrator() + +eqs = [y ~ src.output.u + D(y) ~ dy + D(dy) ~ ddy + connect(src.output, int.input)] + +@named sys = ODESystem(eqs, t, vars, []; systems = [int, src]) +s = complete(sys) +sys = structural_simplify(sys) +prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]) +@test prob.p isa Tuple{Vector{Float64}, Vector{Int}, Vector{Vector{Float64}}} +sol = solve(prob, ImplicitEuler()); +@test sol.retcode == ReturnCode.Success +@test sol[y][end] == x[end] + +# ------------------------ Mixed Type Converted to float (default behavior) + +vars = @variables y(t)=1 dy(t)=0 ddy(t)=0 +pars = @parameters a=1.0 b=2.0 c=3 +eqs = [D(y) ~ dy * a + D(dy) ~ ddy * b + ddy ~ sin(t) * c] + +@named model = ODESystem(eqs, t, vars, pars) +sys = structural_simplify(model) + +tspan = (0.0, t_end) +prob = ODEProblem(sys, [], tspan, []) + +@test prob.p isa Vector{Float64} +sol = solve(prob, ImplicitEuler()); +@test sol.retcode == ReturnCode.Success + +# ------------------------ Mixed Type Conserved + +prob = ODEProblem(sys, [], tspan, []; tofloat = false) + +@test prob.p isa Tuple{Vector{Float64}, Vector{Int64}} +sol = solve(prob, ImplicitEuler()); +@test sol.retcode == ReturnCode.Success