Skip to content

Commit

Permalink
Merge pull request #47 from ctessum/patch-1
Browse files Browse the repository at this point in the history
Attempt to improve type stability
  • Loading branch information
zsunberg authored Jul 11, 2024
2 parents a955ae2 + 84e0e0b commit 07bfa02
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 185 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
julia = "1"
StaticArrays = "0.9, 0.10, 0.11, 0.12, 1"
julia = "1"

[extras]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Expand Down
125 changes: 66 additions & 59 deletions src/GridInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ mutable struct RectangleGrid{D} <: AbstractGrid{D}
cut_counts::Vector{Int}
cuts::Vector{Float64}

function RectangleGrid{D}(cutPoints...) where D
function RectangleGrid{D}(cutPoints...) where {D}
cut_counts = Int[length(cutPoints[i]) for i = 1:length(cutPoints)]
cuts = vcat(cutPoints...)
myCutPoints = Array{Vector{Float64}}(undef, length(cutPoints))
numDims = length(cutPoints)
@assert numDims == D
for i = 1:numDims
if length(Set(cutPoints[i])) != length(cutPoints[i])
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)",i))
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)", i))
end
if !issorted(cutPoints[i])
error("Cut points must be sorted")
Expand All @@ -43,15 +43,15 @@ mutable struct SimplexGrid{D} <: AbstractGrid{D}
ilo::Vector{Int} # indices of cuts below point
n_ind::Vector{Int}

function SimplexGrid{D}(cutPoints...) where D
function SimplexGrid{D}(cutPoints...) where {D}
cut_counts = Int[length(cutPoints[i]) for i = 1:length(cutPoints)]
cuts = vcat(cutPoints...)
myCutPoints = Array{Vector{Float64}}(undef, length(cutPoints))
numDims = length(cutPoints)
@assert numDims == D
for i = 1:numDims
if length(Set(cutPoints[i])) != length(cutPoints[i])
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)",i))
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)", i))
end
if !issorted(cutPoints[i])
error("Cut points must be sorted")
Expand All @@ -72,8 +72,8 @@ Base.length(grid::RectangleGrid) = prod(grid.cut_counts)
Base.size(grid::RectangleGrid) = Tuple(grid.cut_counts)
Base.length(grid::SimplexGrid) = prod(grid.cut_counts)

dimensions(grid::AbstractGrid{D}) where D = D
Base.ndims(grid::AbstractGrid{D}) where D = D
dimensions(grid::AbstractGrid{D}) where {D} = D
Base.ndims(grid::AbstractGrid{D}) where {D} = D

label(grid::RectangleGrid) = "multilinear interpolation grid"
label(grid::SimplexGrid) = "simplex interpolation grid"
Expand Down Expand Up @@ -101,21 +101,21 @@ end

function ind2x!(grid::AbstractGrid, ind::Int, x::AbstractArray)
# Populates x with the value at ind.
# In-place version of ind2x.
# Example:
# rgrid = RectangleGrid([2,5],[20,50])
# x = [0,0]
# ind2x!(rgrid,4,x) # x now contains [5,50]
# In-place version of ind2x.
# Example:
# rgrid = RectangleGrid([2,5],[20,50])
# x = [0,0]
# ind2x!(rgrid,4,x) # x now contains [5,50]
# @show x # displays [5,50]
ndims = dimensions(grid)
stride = grid.cut_counts[1]
for i=2:ndims-1
for i = 2:ndims-1
stride *= grid.cut_counts[i]
end

for i=(ndims-1):-1:1
rest = rem(ind-1, stride) + 1
x[i + 1] = grid.cutPoints[i + 1][div(ind - rest, stride) + 1]
for i = (ndims-1):-1:1
rest = rem(ind - 1, stride) + 1
x[i+1] = grid.cutPoints[i+1][div(ind - rest, stride)+1]
ind = rest
stride = div(stride, grid.cut_counts[i])
end
Expand All @@ -127,8 +127,8 @@ end
# masked interpolation ignores points that are masked
function maskedInterpolate(grid::AbstractGrid, data::DenseArray, x::AbstractVector, mask::BitArray{1})
index, weight = interpolants(grid, x)
val = 0
totalWeight = 0
val = zero(eltype(data))
totalWeight = zero(eltype(data))
for i = 1:length(index)
if mask[index[i]]
continue
Expand All @@ -139,13 +139,13 @@ function maskedInterpolate(grid::AbstractGrid, data::DenseArray, x::AbstractVect
return val / totalWeight
end

interpolate(grid::AbstractGrid, data::Matrix, x::AbstractVector) = interpolate(grid, map(Float64, data[:]), x)
interpolate(grid::AbstractGrid, data::Matrix, x::AbstractVector) = interpolate(grid, map(eltype(data), data[:]), x)

function interpolate(grid::AbstractGrid, data::DenseArray, x::AbstractVector)
index, weight = interpolants(grid, x)
v = 0.0
for (i,data_ind) in enumerate(index)
v += data[data_ind]*weight[i]
v = zero(eltype(data))
for (i, data_ind) in enumerate(index)
v += data[data_ind] * weight[i]
end
return v
end
Expand All @@ -157,14 +157,19 @@ function interpolants(grid::RectangleGrid, x::AbstractVector)
cut_counts = grid.cut_counts
cuts = grid.cuts


# Reset the values in index and weight:
index = @MVector(ones(Int, 2^dimensions(grid)))
index2 = @MVector(ones(Int, 2^dimensions(grid)))
weight = @MVector(zeros(eltype(x), 2^dimensions(grid)))
weight2 = @MVector(zeros(eltype(x), 2^dimensions(grid)))
index[1] = 1
index2[1] = 1
num_points = 2^dimensions(grid)
index = MVector{num_points, Int}(undef)
index2 = MVector{num_points, Int}(undef)
weight = MVector{num_points, eltype(x)}(undef)
weight2 = MVector{num_points, eltype(x)}(undef)

# Note: these values are set explicitly because we have not verified that the logic below is independent of the initial values. See discussion in PR #47. These can be removed if it can be proved that the logic is independent of the initial values.
index .= 1
index2 .= 1
weight .= zero(eltype(weight))
weight2 .= zero(eltype(weight2))

weight[1] = one(eltype(weight))
weight2[1] = one(eltype(weight2))

Expand All @@ -174,7 +179,7 @@ function interpolants(grid::RectangleGrid, x::AbstractVector)
n = 1
for d = 1:length(x)
coord = x[d]
lasti = cut_counts[d]+cut_i-1
lasti = cut_counts[d] + cut_i - 1
ii = cut_i

if coord <= cuts[ii]
Expand All @@ -188,42 +193,42 @@ function interpolants(grid::RectangleGrid, x::AbstractVector)
if cuts[ii] == coord
i_lo, i_hi = ii, ii
else
i_lo, i_hi = (ii-1), ii
i_lo, i_hi = (ii - 1), ii
end
end

# the @inbounds are needed below to prevent allocation
if i_lo == i_hi
for i = 1:l
@inbounds index[i] += (i_lo - cut_i)*subblock_size
@inbounds index[i] += (i_lo - cut_i) * subblock_size
end
else
low = (1 - (coord - cuts[i_lo])/(cuts[i_hi]-cuts[i_lo]))
low = (1 - (coord - cuts[i_lo]) / (cuts[i_hi] - cuts[i_lo]))
for i = 1:l
@inbounds index2[i ] = index[i] + (i_lo-cut_i)*subblock_size
@inbounds index2[i+l] = index[i] + (i_hi-cut_i)*subblock_size
@inbounds index2[i] = index[i] + (i_lo - cut_i) * subblock_size
@inbounds index2[i+l] = index[i] + (i_hi - cut_i) * subblock_size
end
@inbounds index[:] = index2
for i = 1:l
@inbounds weight2[i ] = weight[i]*low
@inbounds weight2[i+l] = weight[i]*(1-low)
@inbounds weight2[i] = weight[i] * low
@inbounds weight2[i+l] = weight[i] * (1 - low)
end
@inbounds weight[:] = weight2
l = l*2
n = n*2
l = l * 2
n = n * 2
end
cut_i = cut_i + cut_counts[d]
subblock_size = subblock_size*(cut_counts[d])
subblock_size = subblock_size * (cut_counts[d])
end

v = min(l,length(index))
return view(SVector(index),1:v), view(SVector(weight),1:v)
v = min(l, length(index))
return view(SVector(index), 1:v), view(SVector(weight), 1:v)
end

function interpolants(grid::SimplexGrid, x::AbstractVector)

weight = MVector{dimensions(grid)+1, eltype(x)}(undef)
index = MVector{dimensions(grid)+1, Int}(undef)
weight = MVector{dimensions(grid) + 1,eltype(x)}(undef)
index = MVector{dimensions(grid) + 1,Int}(undef)

x_p = grid.x_p # residuals
ihi = grid.ihi # indicies of cuts above point
Expand All @@ -238,17 +243,17 @@ function interpolants(grid::SimplexGrid, x::AbstractVector)
for i = 1:dimensions(grid)
# find indicies of coords if match
coord = x[i]
lasti = cut_counts[i]+cut_i-1
lasti = cut_counts[i] + cut_i - 1
ii = cut_i
# check bounds, snap to closest if out
if coord <= cuts[ii]
ihi[i] = ii
ilo[i] = ii
x_p[i] = 0.0
x_p[i] = zero(eltype(x))
elseif coord >= cuts[lasti]
ihi[i] = lasti
ilo[i] = lasti
x_p[i] = 0.0
x_p[i] = zero(eltype(x))
else
# increment through cut points if in bounds
while cuts[ii] < coord
Expand All @@ -258,10 +263,10 @@ function interpolants(grid::SimplexGrid, x::AbstractVector)
if cuts[ii] == coord
ilo[i] = ii
ihi[i] = ii
x_p[i] = 0.0
x_p[i] = zero(eltype(x))
else
# if between cuts assign lo and high indecies and translate
ilo[i] = ii-1
ilo[i] = ii - 1
ihi[i] = ii
lo = cuts[ilo[i]]
hi = cuts[ihi[i]]
Expand All @@ -272,7 +277,9 @@ function interpolants(grid::SimplexGrid, x::AbstractVector)
end

# initialize sort indecies
for i = 1:length(n_ind); n_ind[i] = i; end
for i = 1:length(n_ind)
n_ind[i] = i
end
# sort translated and scaled x values
sortperm!(n_ind, x_p, rev=true) ############################################# killer of speed
x_p = x_p[n_ind]
Expand All @@ -282,7 +289,7 @@ function interpolants(grid::SimplexGrid, x::AbstractVector)
for i = 1:(length(x_p)+1)
if i == 1
weight[i] = 1 - x_p[i]
elseif i == length(x_p)+1
elseif i == length(x_p) + 1
weight[i] = x_p[i-1]
else
weight[i] = x_p[i-1] - x_p[i]
Expand All @@ -307,7 +314,7 @@ function interpolants(grid::SimplexGrid, x::AbstractVector)
else
index[i] += (ilo[k] - 1 - ct) * siz
end
siz = siz*cut_counts[k]
siz = siz * cut_counts[k]
ct += cut_counts[k]
end
index[i] += 1
Expand All @@ -323,14 +330,14 @@ function vertices(grid::AbstractGrid)
n_dims = dimensions(grid)
mem = Array{Float64,2}(undef, n_dims, length(grid))

for idx = 1 : length(grid)
this_idx::Int = idx-1
for idx = 1:length(grid)
this_idx::Int = idx - 1

# Get the correct index into each dimension
# and populate vertex index with corresponding cut point
for j = 1 : n_dims
for j = 1:n_dims
cut_idx::Int = this_idx % grid.cut_counts[j]
this_idx = div(this_idx,grid.cut_counts[j])
this_idx = div(this_idx, grid.cut_counts[j])
mem[j, idx] = grid.cutPoints[j][cut_idx+1]
end
end
Expand All @@ -341,7 +348,7 @@ function vertices(grid::AbstractGrid)
(http://juliaarrays.github.io/StaticArrays.jl/stable/pages/
api.html#Arrays-of-static-arrays-1), and tests should catch these errors.
=#
return reshape(reinterpret(SVector{n_dims, Float64}, mem), (length(grid),))
return reshape(reinterpret(SVector{n_dims,Float64}, mem), (length(grid),))
end


Expand All @@ -353,12 +360,12 @@ using Base.Sort # for sortperm!
const DEFAULT_UNSTABLE = QuickSort

function sortperm!(x::Vector{I}, v::AbstractVector; alg::Algorithm=DEFAULT_UNSTABLE,
lt::Function=isless, by::Function=identity, rev::Bool=false, order::Ordering=Forward) where {I<:Integer}
sort!(x, alg, Perm(ord(lt,by,rev,order),v))
lt::Function=isless, by::Function=identity, rev::Bool=false, order::Ordering=Forward) where {I<:Integer}
sort!(x, alg, Perm(ord(lt, by, rev, order), v))
end

function Base.iterate(iter::RectangleGrid, state::Int64=1)
return state<=length(iter) ? (ind2x(iter, state), state+1) : nothing
return state <= length(iter) ? (ind2x(iter, state), state + 1) : nothing
end

Base.getindex(grid::RectangleGrid, key::CartesianIndex) = ind2x(grid, LinearIndices(Dims((grid.cut_counts...,)))[key])
Expand Down
Loading

0 comments on commit 07bfa02

Please sign in to comment.