Skip to content

Commit

Permalink
Remove sim.u
Browse files Browse the repository at this point in the history
  • Loading branch information
ctessum committed Aug 8, 2024
1 parent 0c80db0 commit db5c13a
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 97 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EarthSciMLBase"
uuid = "e53f1632-a13c-4728-9402-0c66d48804b0"
authors = ["EarthSciML Authors and Contributors"]
version = "0.13.0"
version = "0.14.0"

[deps]
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
Expand Down
28 changes: 16 additions & 12 deletions docs/src/simulator.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ Next, we need to define a method of `EarthSciMLBase.get_scimlop` for our operato
function EarthSciMLBase.get_scimlop(op::ExampleOp, s::Simulator)
obs_f = s.obs_fs[s.obs_fs_idx[op.α]]
function run(du, u, p, t)
u = reshape(u, size(s.u)...)
du = reshape(du, size(s.u)...)
for ix ∈ 1:size(s.u, 1)
u = reshape(u, size(s)...)
du = reshape(du, size(s)...)
for ix ∈ 1:size(u, 1)
for (i, c1) ∈ enumerate(s.grid[1])
for (j, c2) ∈ enumerate(s.grid[2])
for (k, c3) ∈ enumerate(s.grid[3])
Expand All @@ -70,7 +70,8 @@ function EarthSciMLBase.get_scimlop(op::ExampleOp, s::Simulator)
end
nothing
end
FunctionOperator(run, s.u[:], p=s.p)
indata = zeros(EarthSciMLBase.utype(s.domaininfo), size(s))
FunctionOperator(run, indata[:], p=s.p)
end
```
The function above also doesn't have any physical meaning, but it demonstrates some functionality of the `Simulator` "`s`".
Expand Down Expand Up @@ -118,7 +119,6 @@ Next, initialize our operator, giving the the `windspeed` observed variable, and
op = ExampleOp(sys.windspeed)
csys = couple(sys, op, domain)
nothing #hide
```

...and then create a Simulator.
Expand All @@ -127,7 +127,6 @@ coordinates, which we set as 0.1π, 0.1π, and 1, respectively.

```@example sim
sim = Simulator(csys, [0.1π, 0.1π, 1])
nothing #hide
```

Finally, we can choose a [`EarthSciMLBase.SimulatorStrategy`](@ref) and run the simulation.
Expand All @@ -139,14 +138,19 @@ We also choose a time step of 1.0 seconds:
```@example sim
st = SimulatorStrangThreads(Tsit5(), Euler(), 1.0)
@time run!(sim, st)
sol = run!(sim, st)
nothing #hide
```

After the simulation finishes, we can plot the result at the end of the simulation:
After the simulation finishes, we can plot the result:

```@example sim
plot(
heatmap(sim.u[1, :, :, 1]),
heatmap(sim.u[1, :, :, 1]),
)
anim = @animate for i ∈ 1:length(sol.u)
u = reshape(sol.u[i], size(sim)...)
plot(
heatmap(u[1, :, :, 1]),
heatmap(u[1, :, :, 1]),
)
end
gif(anim, fps = 15)
```
30 changes: 12 additions & 18 deletions src/simulator.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export Simulator
export Simulator, init_u

"""
$(TYPEDSIGNATURES)
Expand All @@ -19,8 +19,6 @@ struct Simulator{T,FT1,FT2,TG}
sys_mtk::ODESystem
"Information about the spatiotemporal simulation domain"
domaininfo::DomainInfo{T}
"The system state"
u::Array{T,4}
"The system parameter values"
p::Vector{T}
"The initial values of the system state variables"
Expand Down Expand Up @@ -78,32 +76,28 @@ struct Simulator{T,FT1,FT2,TG}
grd = grid(sys.domaininfo, Δs)
TG = typeof(grd)

u = Array{T}(undef, length(uvals), length(grd[1]), length(grd[2]), length(grd[3]))

new{T,typeof(obs_fs),typeof(tf_fs),TG}(sys, mtk_sys, sys.domaininfo, u, pvals, uvals, pvidx, grd, tuple(Δs...), obs_fs, obs_fs_idx, tf_fs)
new{T,typeof(obs_fs),typeof(tf_fs),TG}(sys, mtk_sys, sys.domaininfo, pvals, uvals, pvidx, grd, tuple(Δs...), obs_fs, obs_fs_idx, tf_fs)
end
end

function Base.show(io::IO, s::Simulator)
print(io, "Simulator{$(eltype(s.u))} with $(length(equations(s.sys_mtk))) equation(s), $(length(s.sys.ops)) operator(s), and $(length(s.u)) grid cells.")
print(io, "Simulator{$(utype(s.domaininfo))} with $(length(equations(s.sys_mtk))) equation(s), $(length(s.sys.ops)) operator(s), and $(*([length(g) for g in s.grid]...)) grid cells.")
end

"Initialize the state variables."
function init_u!(s::Simulator)
function init_u(s::Simulator{T}) where T
u = Array{T}(undef, size(s)...)
# Set initial conditions
for i eachindex(s.u_init)
for j eachindex(s.grid[1])
for k eachindex(s.grid[2])
for l eachindex(s.grid[3])
s.u[i, j, k, l] = s.u_init[i]
end
end
end
for i eachindex(s.u_init), j eachindex(s.grid[1]), k eachindex(s.grid[2]), l eachindex(s.grid[3])
u[i, j, k, l] = s.u_init[i]
end
nothing
u
end

function get_callbacks(s::Simulator)
extra_cb = [init_callback(c, s) for c s.sys.init_callbacks]
[s.sys.callbacks; extra_cb]
end
end

Base.size(s::Simulator) = (length(states(s.sys_mtk)), [length(g) for g s.grid]...)
Base.length(s::Simulator) = *(size(s)...)
20 changes: 11 additions & 9 deletions src/simulator_strategies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ SimulatorStrategy is an abstract type that defines the strategy for running a si
Each SimulatorStrategy should implement a method of:
```julia
run!(st::SimulatorStrategy, s::Simulator{T}) where T
run!(st::SimulatorStrategy, s::Simulator{T}, u0) where T
```
where u0 is the initial conditions for the system state.
"""
abstract type SimulatorStrategy end

Expand All @@ -22,7 +24,7 @@ end
"Return a SciMLOperator to apply the MTK system to each column of s.u after reshaping to a matrix."
function mtk_op(s::Simulator)
mtkf = ODEFunction(s.sys_mtk)
II = CartesianIndices(size(s.u)[2:4])
II = CartesianIndices(size(s)[2:4])
function setp!(p, j) # Set the parameters for the jth grid cell.
ii = II[j]
for (jj, g) enumerate(s.grid) # Set the coordinates of this grid cell.
Expand All @@ -45,18 +47,19 @@ function mtk_op(s::Simulator)
@inbounds @views mapreduce(jcol -> ff(jcol[2], p, t, jcol[1]), hcat, enumerate(eachcol(u)))
end

indata = reshape(s.u, size(s.u, 1), :)
u = zeros(utype(s.domaininfo), size(s)...)
indata = reshape(u, size(s)[1], :)
fo = FunctionOperator(f, indata, batch=true, p=s.p)

ncols = size(indata, 2)
# Rehape the input vector to a matrix, then apply the FunctionOperator.
#op = ScalarOperator(1.0) * TensorProductOperator(I(ncols), fo)
op = TensorProductOperator(I(ncols), fo)
cache_operator(op, s.u[:])
cache_operator(op, u[:])
end

function mtk_func(s::Simulator)
b = repeat([length(states(s.sys_mtk))], length(s.u) ÷ size(s.u, 1))
b = repeat([length(states(s.sys_mtk))], length(s) ÷ size(s)[1])
j = BlockBandedMatrix{Float64}(undef, b, b, (0,0)) # Jacobian prototype
ODEFunction(mtk_op(s); jac_prototype=j)
end
Expand Down Expand Up @@ -87,7 +90,7 @@ $(TYPEDSIGNATURES)
Run the simulation.
`kwargs` are passed to the ODEProblem and ODE solver constructors.
"""
function run!(s::Simulator, st::SimulatorIMEX; kwargs...)
function run!(s::Simulator, st::SimulatorIMEX, u=init_u(s); kwargs...)
f1 = mtk_func(s)

@assert length(s.sys.ops) > 0 "Operators must be defined to use the `SimulatorIMEX` strategy. For no operators, try `SimulatorFused` instead."
Expand All @@ -96,7 +99,6 @@ function run!(s::Simulator, st::SimulatorIMEX; kwargs...)
f2 = sum([get_scimlop(op, s) for op s.sys.ops])

start, finish = time_range(s.domaininfo)
prob = SplitODEProblem(f1, f2, s.u, (start, finish), s.p, callback=CallbackSet(get_callbacks(s)), kwargs...)
solve(prob, st.alg, save_on=false, save_start=false, save_end=false,
initialize_save=false; kwargs...)
prob = SplitODEProblem(f1, f2, u, (start, finish), s.p, callback=CallbackSet(get_callbacks(s)), kwargs...)
solve(prob, st.alg; kwargs...)
end
15 changes: 6 additions & 9 deletions src/simulator_strategy_strang.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ $(TYPEDSIGNATURES)
Run the simualation.
`kwargs` are passed to the ODEProblem and ODE solver constructors.
"""
function run!(s::Simulator{T}, st::SimulatorStrang; kwargs...) where {T}
II = CartesianIndices(size(s.u)[2:4])
function run!(s::Simulator{T}, st::SimulatorStrang, u=init_u(s); kwargs...) where {T}
II = CartesianIndices(size(u)[2:4])
IIchunks = collect(Iterators.partition(II, length(II) ÷ nthreads(st)))
start, finish = time_range(s.domaininfo)
prob = ODEProblem(s.sys_mtk, [], (start, finish), []; kwargs...)
Expand All @@ -81,18 +81,15 @@ function run!(s::Simulator{T}, st::SimulatorStrang; kwargs...) where {T}

# Combine the non-stiff operators into a single operator.
# This works because SciMLOperators can be added together.
nonstiff_op = length(s.sys.ops) > 0 ? sum([get_scimlop(op, s) for op s.sys.ops]) : NullOperator(length(s.u))
nonstiff_op = cache_operator(nonstiff_op, s.u)

init_u!(s)
nonstiff_op = length(s.sys.ops) > 0 ? sum([get_scimlop(op, s) for op s.sys.ops]) : NullOperator(length(u))
nonstiff_op = cache_operator(nonstiff_op, u)

cb = CallbackSet(
stiff_callback(s, st, IIchunks, stiff_integrators),
get_callbacks(s)...,
)
@views nonstiff_prob = ODEProblem(nonstiff_op, s.u[:], (start, finish), s.p, callback=cb; kwargs...)
solve(nonstiff_prob, st.nonstiffalg, save_on=false, save_start=false, save_end=false,
initialize_save=false, dt=st.timestep; kwargs...)
@views nonstiff_prob = ODEProblem(nonstiff_op, u[:], (start, finish), s.p, callback=cb; kwargs...)
solve(nonstiff_prob, st.nonstiffalg, dt=st.timestep; kwargs...)
end

"A callback to periodically run the stiff solver."
Expand Down
81 changes: 33 additions & 48 deletions test/simulator_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ end
function EarthSciMLBase.get_scimlop(op::ExampleOp, s::Simulator)
obs_f = s.obs_fs[s.obs_fs_idx[op.α]]
function run(du, u, p, t)
u = reshape(u, size(s.u)...)
du = reshape(du, size(s.u)...)
for ix 1:size(s.u, 1)
u = reshape(u, size(s)...)
du = reshape(du, size(s)...)
for ix 1:size(u, 1)
for (i, c1) enumerate(s.grid[1])
for (j, c2) enumerate(s.grid[2])
for (k, c3) enumerate(s.grid[3])
Expand All @@ -32,7 +32,8 @@ function EarthSciMLBase.get_scimlop(op::ExampleOp, s::Simulator)
end
nothing
end
FunctionOperator(run, s.u[:], p=s.p)
indata = zeros(EarthSciMLBase.utype(s.domaininfo), size(s))
FunctionOperator(run, indata[:], p=s.p)
end

t_min = 0.0
Expand Down Expand Up @@ -65,20 +66,7 @@ sys = ODESystem(eqs, t, name=:Test₊sys)

op = ExampleOp(sys.windspeed)

# Callback for saving the end result for testing.
mutable struct SaveEndCB
u
end
function cb(s::SaveEndCB)
DiscreteCallback(
(u, t, integrator) -> false,
(integrator) -> nothing,
finalize = (c, u, t, integrator) -> s.u = u
)
end

result = SaveEndCB(nothing)
csys = couple(sys, op, domain, cb(result))
csys = couple(sys, op, domain)

sim = Simulator(csys, [0.1, 0.1, 1])
st = SimulatorStrangThreads(Tsit5(), Euler(), 1.0)
Expand All @@ -91,9 +79,10 @@ st = SimulatorStrangThreads(Tsit5(), Euler(), 1.0)
@test sim.obs_fs[sim.obs_fs_idx[op.α]](0.0, 1.0, 3.0, 2.0) == 6.0

scimlop = EarthSciMLBase.get_scimlop(op, sim)
du = similar(sim.u)
u = init_u(sim)
du = similar(u)
du .= 0
@views scimlop(du[:], sim.u[:], sim.p, 0.0)
@views scimlop(du[:], u[:], sim.p, 0.0)

@test sum(abs.(du)) 26094.203039436292

Expand All @@ -103,10 +92,10 @@ prob = ODEProblem(structural_simplify(sys), [], (0.0, 1.0), [
sol1 = solve(prob, Tsit5(), abstol=1e-12, reltol=1e-12)
@test sol1.u[end] [-27.15156429366082, -26.264264199779465]

EarthSciMLBase.init_u!(sim)
u = init_u(sim)

IIchunks, integrators = let
II = CartesianIndices(size(sim.u)[2:4])
II = CartesianIndices(size(u)[2:4])
IIchunks = collect(Iterators.partition(II, length(II) ÷ st.threads))
start, finish = EarthSciMLBase.time_range(sim.domaininfo)
prob = ODEProblem(sim.sys_mtk, [], (start, finish), [])
Expand All @@ -116,42 +105,42 @@ IIchunks, integrators = let
(IIchunks, integrators)
end

EarthSciMLBase.threaded_ode_step!(sim, sim.u, IIchunks, integrators, 0.0, 1.0)
EarthSciMLBase.threaded_ode_step!(sim, u, IIchunks, integrators, 0.0, 1.0)

@test sim.u[1, 1, 1, 1] sol1.u[end][1]
@test sim.u[2, 1, 1, 1] sol1.u[end][2]
@test u[1, 1, 1, 1] sol1.u[end][1]
@test u[2, 1, 1, 1] sol1.u[end][2]

@test sum(abs.(sim.u)) 212733.04492722102
@test sum(abs.(u)) 212733.04492722102

@testset "mtk_func" begin
ucopy = copy(sim.u)
#@testset "mtk_func" begin
begin
ucopy = copy(u)
f = EarthSciMLBase.mtk_func(sim)
EarthSciMLBase.init_u!(sim)
du = similar(sim.u)
prob = ODEProblem(f, sim.u[:], (0.0, 1.0), sim.p)
u = EarthSciMLBase.init_u(sim)
du = similar(u)
prob = ODEProblem(f, u[:], (0.0, 1.0), sim.p)
sol = solve(prob, KenCarp47(linsolve=KrylovJL_GMRES(), autodiff=false))
uu = reshape(sol.u[end], size(ucopy)...)
@test uu[:] ucopy[:] rtol = 0.01
end

run!(sim, st; abstol=1e-12, reltol=1e-12)
sol = run!(sim, st; abstol=1e-12, reltol=1e-12)

@test sum(abs.(result.u)) 3.77224671877136e7 rtol = 1e-3
@test sum(abs.(sol.u[end])) 3.77224671877136e7 rtol = 1e-3

@testset "Float32" begin
domain = DomainInfo(
partialderivatives_δxyδlonlat,
constIC(16.0, indepdomain), constBC(16.0, partialdomains...);
dtype=Float32)

result.u = nothing
csys = couple(sys, op, domain, cb(result))
csys = couple(sys, op, domain)

sim = Simulator(csys, [0.1, 0.1, 1])

run!(sim, st)
sol = run!(sim, st)

@test sum(abs.(result.u)) 3.77224671877136e7
@test sum(abs.(sol.u[end])) 3.77224671877136e7
end

@testset "No operator" begin
Expand All @@ -160,29 +149,25 @@ end
constIC(16.0, indepdomain), constBC(16.0, partialdomains...);
dtype=Float32)

result.u = nothing
csys = couple(sys, domain, cb(result))
csys = couple(sys, domain)

sim = Simulator(csys, [0.1, 0.1, 1])

run!(sim, st; abstol=1e-6, reltol=1e-6)
sol = run!(sim, st; abstol=1e-6, reltol=1e-6)

@test sum(abs.(result.u)) 3.8660308f7
@test sum(abs.(sol.u[end])) 3.8660308f7
end

@testset "SimulatorStrategies" begin
st = SimulatorStrangThreads(Tsit5(), Euler(), 1.0)
result.u = nothing
run!(sim, st; abstol=1e-12, reltol=1e-12)
@test sum(abs.(result.u)) 3.77224671877136e7 rtol = 1e-3
sol = run!(sim, st; abstol=1e-12, reltol=1e-12)
@test sum(abs.(sol.u[end])) 3.77224671877136e7 rtol = 1e-3

st = SimulatorStrangSerial(Tsit5(), Euler(), 1.0)
result.u = nothing
run!(sim, st; abstol=1e-12, reltol=1e-12)
@test sum(abs.(result.u)) 3.77224671877136e7 rtol = 1e-3
sol = run!(sim, st; abstol=1e-12, reltol=1e-12)
@test sum(abs.(sol.u[end])) 3.77224671877136e7 rtol = 1e-3

st = SimulatorIMEX(KenCarp47(linsolve=KrylovJL_GMRES(), autodiff=false))
result.u = nothing
@test_broken run!(sim, st)
end

Expand Down

2 comments on commit db5c13a

@ctessum
Copy link
Member Author

@ctessum ctessum commented on db5c13a Aug 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/112675

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.0 -m "<description of version>" db5c13ae74ba55f3554f91029c5ef3811c2f913e
git push origin v0.14.0

Please sign in to comment.