Skip to content

Commit

Permalink
Merge branch 'master' into cleanup_initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
hersle authored Oct 4, 2024
2 parents 09fda2f + 40b1f7c commit 01a7cf9
Show file tree
Hide file tree
Showing 16 changed files with 369 additions and 272 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ PrecompileTools = "1"
RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.52.1"
SciMLBase = "2.55"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0, 1"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.29"
SymbolicIndexingInterface = "0.3.31"
SymbolicUtils = "3.7"
Symbolics = "6.12"
URIs = "1"
Expand Down
17 changes: 8 additions & 9 deletions docs/src/basics/MTKLanguage.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ end
@structural_parameters begin
f = sin
N = 2
M = 3
end
begin
v_var = 1.0
end
@variables begin
v(t) = v_var
v_array(t)[1:N, 1:M]
v_array(t)[1:2, 1:3]
v_for_defaults(t)
end
@extend ModelB(; p1)
Expand Down Expand Up @@ -311,10 +310,10 @@ end
- `:defaults`: Dictionary of variables and default values specified in the `@defaults`.
- `:extend`: The list of extended unknowns, name given to the base system, and name of the base system.
- `:structural_parameters`: Dictionary of structural parameters mapped to their metadata.
- `:parameters`: Dictionary of symbolic parameters mapped to their metadata. Metadata of
the parameter arrays is, for now, omitted.
- `:variables`: Dictionary of symbolic variables mapped to their metadata. Metadata of
the variable arrays is, for now, omitted.
- `:parameters`: Dictionary of symbolic parameters mapped to their metadata. For
parameter arrays, length is added to the metadata as `:size`.
- `:variables`: Dictionary of symbolic variables mapped to their metadata. For
variable arrays, length is added to the metadata as `:size`.
- `:kwargs`: Dictionary of keyword arguments mapped to their metadata.
- `:independent_variable`: Independent variable, which is added while generating the Model.
- `:equations`: List of equations (represented as strings).
Expand All @@ -325,10 +324,10 @@ For example, the structure of `ModelC` is:
julia> ModelC.structure
Dict{Symbol, Any} with 10 entries:
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA], Union{Expr, Symbol}[:model_array_a, :ModelA, :(1:N)], Union{Expr, Symbol}[:model_array_b, :ModelA, :(1:N)]]
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_for_defaults=>Dict(:type=>Real))
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:type=>Real, :size=>(2, 3)), :v_for_defaults=>Dict(:type=>Real))
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
:kwargs => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3), :v => Dict{Symbol, Any}(:value => :v_var, :type => Real), :v_for_defaults => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real), :p1 => Dict(:value => nothing)),
:structural_parameters => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3))
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2), :v=>Dict{Symbol, Any}(:value=>:v_var, :type=>Real), :v_array=>Dict{Symbol, Union{Nothing, UnionAll}}(:value=>nothing, :type=>AbstractArray{Real}), :v_for_defaults=>Dict{Symbol, Union{Nothing, DataType}}(:value=>nothing, :type=>Real), :p1=>Dict(:value=>nothing))
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2))
:independent_variable => t
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]
Expand Down
146 changes: 86 additions & 60 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
end

function wrap_array_vars(
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing)
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys),
inputs = nothing, history = false)
isscalar = !(exprs isa AbstractArray)
array_vars = Dict{Any, AbstractArray{Int}}()
if dvs !== nothing
Expand Down Expand Up @@ -328,6 +329,19 @@ function wrap_array_vars(
array_parameters[p] = (idxs, buffer_idx, sz)
end
end

inputind = if history
uind + 2
else
uind + 1
end
params_offset = if history && hasinputs
uind + 2
elseif history || hasinputs
uind + 1
else
uind
end
if isscalar
function (expr)
Func(
Expand All @@ -336,10 +350,10 @@ function wrap_array_vars(
Let(
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind + hasinputs].name), $v))
[k :(view($(expr.args[inputind].name), $v))
for (k, v) in input_vars],
[k :(reshape(
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
view($(expr.args[params_offset + buffer_idx].name), $idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
[k Code.MakeArray(v, symtype(k))
Expand All @@ -358,10 +372,10 @@ function wrap_array_vars(
Let(
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind + hasinputs].name), $v))
[k :(view($(expr.args[inputind].name), $v))
for (k, v) in input_vars],
[k :(reshape(
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
view($(expr.args[params_offset + buffer_idx].name), $idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
[k Code.MakeArray(v, symtype(k))
Expand All @@ -380,10 +394,10 @@ function wrap_array_vars(
vcat(
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_vars],
[k :(view($(expr.args[uind + hasinputs + 1].name), $v))
[k :(view($(expr.args[inputind + 1].name), $v))
for (k, v) in input_vars],
[k :(reshape(
view($(expr.args[uind + hasinputs + buffer_idx + 1].name),
view($(expr.args[params_offset + buffer_idx + 1].name),
$idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
Expand All @@ -398,50 +412,76 @@ function wrap_array_vars(
end
end

function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool)
const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___)

"""
wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
Return function(s) to be passed to the `wrap_code` keyword of `build_function` which
allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters`
instead of `f(u, p..., t)`. `isscalar` denotes whether the function expression being
wrapped is for a scalar value. `p_start` is the index of the argument containing
the first parameter vector in the out-of-place version of the function. For example,
if a history function (DDEs) was passed before `p`, then the function before wrapping
would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`.
The returned function is `identity` if the system does not have an `IndexCache`.
"""
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
offset = Int(is_time_dependent(sys))

if isscalar
function (expr)
p = gensym(:p)
param_args = expr.args[p_start:(end - offset)]
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
param_buffer_args = param_args[param_buffer_idxs]
destructured_mtkparams = DestructuredArgs(
[x.name for x in param_buffer_args],
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
Func(
[
expr.args[1],
DestructuredArgs(
[arg.name for arg in expr.args[2:(end - offset)]], p),
(isone(offset) ? (expr.args[end],) : ())...
expr.args[begin:(p_start - 1)]...,
destructured_mtkparams,
expr.args[(end - offset + 1):end]...
],
[],
Let(expr.args[2:(end - offset)], expr.body, false)
Let(param_buffer_args, expr.body, false)
)
end
else
function (expr)
p = gensym(:p)
param_args = expr.args[p_start:(end - offset)]
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
param_buffer_args = param_args[param_buffer_idxs]
destructured_mtkparams = DestructuredArgs(
[x.name for x in param_buffer_args],
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
Func(
[
expr.args[1],
DestructuredArgs(
[arg.name for arg in expr.args[2:(end - offset)]], p),
(isone(offset) ? (expr.args[end],) : ())...
expr.args[begin:(p_start - 1)]...,
destructured_mtkparams,
expr.args[(end - offset + 1):end]...
],
[],
Let(expr.args[2:(end - offset)], expr.body, false)
Let(param_buffer_args, expr.body, false)
)
end,
function (expr)
p = gensym(:p)
param_args = expr.args[(p_start + 1):(end - offset)]
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
param_buffer_args = param_args[param_buffer_idxs]
destructured_mtkparams = DestructuredArgs(
[x.name for x in param_buffer_args],
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
Func(
[
expr.args[1],
expr.args[2],
DestructuredArgs(
[arg.name for arg in expr.args[3:(end - offset)]], p),
(isone(offset) ? (expr.args[end],) : ())...
expr.args[begin:p_start]...,
destructured_mtkparams,
expr.args[(end - offset + 1):end]...
],
[],
Let(expr.args[3:(end - offset)], expr.body, false)
Let(param_buffer_args, expr.body, false)
)
end
end
Expand Down Expand Up @@ -669,25 +709,17 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
if rawobs isa Tuple
if is_time_dependent(sys)
obsfn = let oop = rawobs[1], iip = rawobs[2]
f1a(p::MTKParameters, t) = oop(p..., t)
f1a(out, p::MTKParameters, t) = iip(out, p..., t)
f1a(p, t) = oop(p, t)
f1a(out, p, t) = iip(out, p, t)
end
else
obsfn = let oop = rawobs[1], iip = rawobs[2]
f1b(p::MTKParameters) = oop(p...)
f1b(out, p::MTKParameters) = iip(out, p...)
f1b(p) = oop(p)
f1b(out, p) = iip(out, p)
end
end
else
if is_time_dependent(sys)
obsfn = let rawobs = rawobs
f2a(p::MTKParameters, t) = rawobs(p..., t)
end
else
obsfn = let rawobs = rawobs
f2b(p::MTKParameters) = rawobs(p...)
end
end
obsfn = rawobs
end
else
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
Expand Down Expand Up @@ -802,17 +834,11 @@ function SymbolicIndexingInterface.observed(
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)

if is_time_dependent(sys)
return let _fn = _fn
fn1(u, p, t) = _fn(u, p, t)
fn1(u, p::MTKParameters, t) = _fn(u, p..., t)
fn1
end
return _fn
else
return let _fn = _fn
fn2(u, p) = _fn(u, p)
fn2(u, p::MTKParameters) = _fn(u, p...)
fn2(::Nothing, p) = _fn([], p)
fn2(::Nothing, p::MTKParameters) = _fn([], p...)
fn2
end
end
Expand All @@ -828,6 +854,8 @@ end
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeDependentSystem) = true
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeIndependentSystem) = false

SymbolicIndexingInterface.is_markovian(sys::AbstractSystem) = !is_dde(sys)

SymbolicIndexingInterface.constant_structure(::AbstractSystem) = true

function SymbolicIndexingInterface.all_variable_symbols(sys::AbstractSystem)
Expand Down Expand Up @@ -971,6 +999,7 @@ for prop in [:eqs
:solved_unknowns
:split_idxs
:parent
:is_dde
:index_cache
:is_scalar_noise
:isscheduled]
Expand Down Expand Up @@ -2349,8 +2378,8 @@ function linearization_function(sys::AbstractSystem, inputs,
u_getter = u_getter

function (u, p, t)
p_setter!(oldps, p_getter(u, p..., t))
newu = u_getter(u, p..., t)
p_setter!(oldps, p_getter(u, p, t))
newu = u_getter(u, p, t)
return newu, oldps
end
end
Expand All @@ -2361,20 +2390,15 @@ function linearization_function(sys::AbstractSystem, inputs,

function (u, p, t)
state = ProblemState(; u, p, t)
return u_getter(state), p_getter(state)
return u_getter(
state_values(state), parameter_values(state), current_time(state)),
p_getter(state)
end
end
end
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
initprobmap = build_explicit_observed_function(
initsys, unknowns(sys); eval_expression, eval_module)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
initprobmap = let inner = initprobmap
fn(u, p::MTKParameters) = inner(u, p...)
fn(u, p) = inner(u, p)
fn
end
end
ps = parameters(sys)
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
lin_fun = let diff_idxs = diff_idxs,
Expand Down Expand Up @@ -2421,7 +2445,7 @@ function linearization_function(sys::AbstractSystem, inputs,
fg_xz = ForwardDiff.jacobian(uf, u)
h_xz = ForwardDiff.jacobian(
let p = p, t = t
xz -> p isa MTKParameters ? h(xz, p..., t) : h(xz, p, t)
xz -> h(xz, p, t)
end, u)
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
Expand All @@ -2433,7 +2457,6 @@ function linearization_function(sys::AbstractSystem, inputs,
end
hp = let u = u, t = t
_hp(p) = h(u, p, t)
_hp(p::MTKParameters) = h(u, p..., t)
_hp
end
h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk)
Expand Down Expand Up @@ -2486,7 +2509,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
dx = fun(sts, p..., t)

h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
y = h(sts, p..., t)
y = h(sts, p, t)

fg_xz = Symbolics.jacobian(dx, sts)
fg_u = Symbolics.jacobian(dx, inputs)
Expand Down Expand Up @@ -2955,6 +2978,9 @@ function compose(sys::AbstractSystem, systems::AbstractArray; name = nameof(sys)
nsys == 0 && return sys
@set! sys.name = name
@set! sys.systems = [get_systems(sys); systems]
if has_is_dde(sys)
@set! sys.is_dde = _check_if_dde(equations(sys), get_iv(sys), get_systems(sys))
end
return sys
end
function compose(syss...; name = nameof(first(syss)))
Expand Down
Loading

0 comments on commit 01a7cf9

Please sign in to comment.