Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt to improve type stability #47

Merged
merged 8 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
julia = "1"
StaticArrays = "0.9, 0.10, 0.11, 0.12, 1"
julia = "1"

[extras]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Expand Down
125 changes: 66 additions & 59 deletions src/GridInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ mutable struct RectangleGrid{D} <: AbstractGrid{D}
cut_counts::Vector{Int}
cuts::Vector{Float64}

function RectangleGrid{D}(cutPoints...) where D
function RectangleGrid{D}(cutPoints...) where {D}
cut_counts = Int[length(cutPoints[i]) for i = 1:length(cutPoints)]
cuts = vcat(cutPoints...)
myCutPoints = Array{Vector{Float64}}(undef, length(cutPoints))
numDims = length(cutPoints)
@assert numDims == D
for i = 1:numDims
if length(Set(cutPoints[i])) != length(cutPoints[i])
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)",i))
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)", i))
end
if !issorted(cutPoints[i])
error("Cut points must be sorted")
Expand All @@ -43,15 +43,15 @@ mutable struct SimplexGrid{D} <: AbstractGrid{D}
ilo::Vector{Int} # indices of cuts below point
n_ind::Vector{Int}

function SimplexGrid{D}(cutPoints...) where D
function SimplexGrid{D}(cutPoints...) where {D}
cut_counts = Int[length(cutPoints[i]) for i = 1:length(cutPoints)]
cuts = vcat(cutPoints...)
myCutPoints = Array{Vector{Float64}}(undef, length(cutPoints))
numDims = length(cutPoints)
@assert numDims == D
for i = 1:numDims
if length(Set(cutPoints[i])) != length(cutPoints[i])
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)",i))
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)", i))
end
if !issorted(cutPoints[i])
error("Cut points must be sorted")
Expand All @@ -72,8 +72,8 @@ Base.length(grid::RectangleGrid) = prod(grid.cut_counts)
Base.size(grid::RectangleGrid) = Tuple(grid.cut_counts)
Base.length(grid::SimplexGrid) = prod(grid.cut_counts)

dimensions(grid::AbstractGrid{D}) where D = D
Base.ndims(grid::AbstractGrid{D}) where D = D
dimensions(grid::AbstractGrid{D}) where {D} = D
Base.ndims(grid::AbstractGrid{D}) where {D} = D

label(grid::RectangleGrid) = "multilinear interpolation grid"
label(grid::SimplexGrid) = "simplex interpolation grid"
Expand Down Expand Up @@ -101,21 +101,21 @@ end

function ind2x!(grid::AbstractGrid, ind::Int, x::AbstractArray)
# Populates x with the value at ind.
# In-place version of ind2x.
# Example:
# rgrid = RectangleGrid([2,5],[20,50])
# x = [0,0]
# ind2x!(rgrid,4,x) # x now contains [5,50]
# In-place version of ind2x.
# Example:
# rgrid = RectangleGrid([2,5],[20,50])
# x = [0,0]
# ind2x!(rgrid,4,x) # x now contains [5,50]
# @show x # displays [5,50]
ndims = dimensions(grid)
stride = grid.cut_counts[1]
for i=2:ndims-1
for i = 2:ndims-1
stride *= grid.cut_counts[i]
end

for i=(ndims-1):-1:1
rest = rem(ind-1, stride) + 1
x[i + 1] = grid.cutPoints[i + 1][div(ind - rest, stride) + 1]
for i = (ndims-1):-1:1
rest = rem(ind - 1, stride) + 1
x[i+1] = grid.cutPoints[i+1][div(ind - rest, stride)+1]
ind = rest
stride = div(stride, grid.cut_counts[i])
end
Expand All @@ -127,8 +127,8 @@ end
# masked interpolation ignores points that are masked
function maskedInterpolate(grid::AbstractGrid, data::DenseArray, x::AbstractVector, mask::BitArray{1})
index, weight = interpolants(grid, x)
val = 0
totalWeight = 0
val = zero(eltype(data))
totalWeight = zero(eltype(data))
for i = 1:length(index)
if mask[index[i]]
continue
Expand All @@ -139,13 +139,13 @@ function maskedInterpolate(grid::AbstractGrid, data::DenseArray, x::AbstractVect
return val / totalWeight
end

interpolate(grid::AbstractGrid, data::Matrix, x::AbstractVector) = interpolate(grid, map(Float64, data[:]), x)
interpolate(grid::AbstractGrid, data::Matrix, x::AbstractVector) = interpolate(grid, map(eltype(data), data[:]), x)

function interpolate(grid::AbstractGrid, data::DenseArray, x::AbstractVector)
index, weight = interpolants(grid, x)
v = 0.0
for (i,data_ind) in enumerate(index)
v += data[data_ind]*weight[i]
v = zero(eltype(data))
for (i, data_ind) in enumerate(index)
v += data[data_ind] * weight[i]
end
return v
end
Expand All @@ -157,14 +157,19 @@ function interpolants(grid::RectangleGrid, x::AbstractVector)
cut_counts = grid.cut_counts
cuts = grid.cuts


# Reset the values in index and weight:
index = @MVector(ones(Int, 2^dimensions(grid)))
index2 = @MVector(ones(Int, 2^dimensions(grid)))
weight = @MVector(zeros(eltype(x), 2^dimensions(grid)))
weight2 = @MVector(zeros(eltype(x), 2^dimensions(grid)))
index[1] = 1
index2[1] = 1
num_points = 2^dimensions(grid)
index = MVector{num_points, Int}(undef)
index2 = MVector{num_points, Int}(undef)
weight = MVector{num_points, eltype(x)}(undef)
weight2 = MVector{num_points, eltype(x)}(undef)
Comment on lines +162 to +165
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit concerned about this change. How can we be completely sure that it does not change the code logic to initialize these to undef rather than ones or zeros? (note that the tests may pass most of the time even if the code relies on zeros since memory is often initialized to zeros or small numbers.)

I think we should either leave this as is (with ones or zeros) or carefully go through the code below looking at all of the assignments to these variables and write out a brief proof that the initialization does not matter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Zeroing memory typically does not have an impact on runtime.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original goal for that change was to reduce allocations, because the ones() and zeros() calls would theoretically be allocating arrays, although I guess the compiler could be factoring that out. From looking at the code I don't believe the arrays need to be explicitly initialized, but just in case I've added it back in, just using broadcasting instead of allocating new arrays for the initialization.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ctessum ! Yeah, If I remember correctly, we verified that @MVector zeros(...) etc. does not allocate as long as the MVector does not leave the scope of the function. I added an explanatory note.

If we wanted to avoid this and be explicit about proving that the initialization doesn't matter, we could use PushVectors backed with MVectors, but that might require some work.


# Note: these values are set explicitly because we have not verified that the logic below is independent of the initial values. See discussion in PR #47. These can be removed if it can be proved that the logic is independent of the initial values.
index .= 1
index2 .= 1
weight .= zero(eltype(weight))
weight2 .= zero(eltype(weight2))

weight[1] = one(eltype(weight))
weight2[1] = one(eltype(weight2))

Expand All @@ -174,7 +179,7 @@ function interpolants(grid::RectangleGrid, x::AbstractVector)
n = 1
for d = 1:length(x)
coord = x[d]
lasti = cut_counts[d]+cut_i-1
lasti = cut_counts[d] + cut_i - 1
ii = cut_i

if coord <= cuts[ii]
Expand All @@ -188,42 +193,42 @@ function interpolants(grid::RectangleGrid, x::AbstractVector)
if cuts[ii] == coord
i_lo, i_hi = ii, ii
else
i_lo, i_hi = (ii-1), ii
i_lo, i_hi = (ii - 1), ii
end
end

# the @inbounds are needed below to prevent allocation
if i_lo == i_hi
for i = 1:l
@inbounds index[i] += (i_lo - cut_i)*subblock_size
@inbounds index[i] += (i_lo - cut_i) * subblock_size
end
else
low = (1 - (coord - cuts[i_lo])/(cuts[i_hi]-cuts[i_lo]))
low = (1 - (coord - cuts[i_lo]) / (cuts[i_hi] - cuts[i_lo]))
for i = 1:l
@inbounds index2[i ] = index[i] + (i_lo-cut_i)*subblock_size
@inbounds index2[i+l] = index[i] + (i_hi-cut_i)*subblock_size
@inbounds index2[i] = index[i] + (i_lo - cut_i) * subblock_size
@inbounds index2[i+l] = index[i] + (i_hi - cut_i) * subblock_size
end
@inbounds index[:] = index2
for i = 1:l
@inbounds weight2[i ] = weight[i]*low
@inbounds weight2[i+l] = weight[i]*(1-low)
@inbounds weight2[i] = weight[i] * low
@inbounds weight2[i+l] = weight[i] * (1 - low)
end
@inbounds weight[:] = weight2
l = l*2
n = n*2
l = l * 2
n = n * 2
end
cut_i = cut_i + cut_counts[d]
subblock_size = subblock_size*(cut_counts[d])
subblock_size = subblock_size * (cut_counts[d])
end

v = min(l,length(index))
return view(SVector(index),1:v), view(SVector(weight),1:v)
v = min(l, length(index))
return view(SVector(index), 1:v), view(SVector(weight), 1:v)
end

function interpolants(grid::SimplexGrid, x::AbstractVector)

weight = MVector{dimensions(grid)+1, eltype(x)}(undef)
index = MVector{dimensions(grid)+1, Int}(undef)
weight = MVector{dimensions(grid) + 1,eltype(x)}(undef)
index = MVector{dimensions(grid) + 1,Int}(undef)

x_p = grid.x_p # residuals
ihi = grid.ihi # indicies of cuts above point
Expand All @@ -238,17 +243,17 @@ function interpolants(grid::SimplexGrid, x::AbstractVector)
for i = 1:dimensions(grid)
# find indicies of coords if match
coord = x[i]
lasti = cut_counts[i]+cut_i-1
lasti = cut_counts[i] + cut_i - 1
ii = cut_i
# check bounds, snap to closest if out
if coord <= cuts[ii]
ihi[i] = ii
ilo[i] = ii
x_p[i] = 0.0
x_p[i] = zero(eltype(x))
elseif coord >= cuts[lasti]
ihi[i] = lasti
ilo[i] = lasti
x_p[i] = 0.0
x_p[i] = zero(eltype(x))
else
# increment through cut points if in bounds
while cuts[ii] < coord
Expand All @@ -258,10 +263,10 @@ function interpolants(grid::SimplexGrid, x::AbstractVector)
if cuts[ii] == coord
ilo[i] = ii
ihi[i] = ii
x_p[i] = 0.0
x_p[i] = zero(eltype(x))
else
# if between cuts assign lo and high indecies and translate
ilo[i] = ii-1
ilo[i] = ii - 1
ihi[i] = ii
lo = cuts[ilo[i]]
hi = cuts[ihi[i]]
Expand All @@ -272,7 +277,9 @@ function interpolants(grid::SimplexGrid, x::AbstractVector)
end

# initialize sort indecies
for i = 1:length(n_ind); n_ind[i] = i; end
for i = 1:length(n_ind)
n_ind[i] = i
end
# sort translated and scaled x values
sortperm!(n_ind, x_p, rev=true) ############################################# killer of speed
x_p = x_p[n_ind]
Expand All @@ -282,7 +289,7 @@ function interpolants(grid::SimplexGrid, x::AbstractVector)
for i = 1:(length(x_p)+1)
if i == 1
weight[i] = 1 - x_p[i]
elseif i == length(x_p)+1
elseif i == length(x_p) + 1
weight[i] = x_p[i-1]
else
weight[i] = x_p[i-1] - x_p[i]
Expand All @@ -307,7 +314,7 @@ function interpolants(grid::SimplexGrid, x::AbstractVector)
else
index[i] += (ilo[k] - 1 - ct) * siz
end
siz = siz*cut_counts[k]
siz = siz * cut_counts[k]
ct += cut_counts[k]
end
index[i] += 1
Expand All @@ -323,14 +330,14 @@ function vertices(grid::AbstractGrid)
n_dims = dimensions(grid)
mem = Array{Float64,2}(undef, n_dims, length(grid))

for idx = 1 : length(grid)
this_idx::Int = idx-1
for idx = 1:length(grid)
this_idx::Int = idx - 1

# Get the correct index into each dimension
# and populate vertex index with corresponding cut point
for j = 1 : n_dims
for j = 1:n_dims
cut_idx::Int = this_idx % grid.cut_counts[j]
this_idx = div(this_idx,grid.cut_counts[j])
this_idx = div(this_idx, grid.cut_counts[j])
mem[j, idx] = grid.cutPoints[j][cut_idx+1]
end
end
Expand All @@ -341,7 +348,7 @@ function vertices(grid::AbstractGrid)
(http://juliaarrays.github.io/StaticArrays.jl/stable/pages/
api.html#Arrays-of-static-arrays-1), and tests should catch these errors.
=#
return reshape(reinterpret(SVector{n_dims, Float64}, mem), (length(grid),))
return reshape(reinterpret(SVector{n_dims,Float64}, mem), (length(grid),))
end


Expand All @@ -353,12 +360,12 @@ using Base.Sort # for sortperm!
const DEFAULT_UNSTABLE = QuickSort

function sortperm!(x::Vector{I}, v::AbstractVector; alg::Algorithm=DEFAULT_UNSTABLE,
lt::Function=isless, by::Function=identity, rev::Bool=false, order::Ordering=Forward) where {I<:Integer}
sort!(x, alg, Perm(ord(lt,by,rev,order),v))
lt::Function=isless, by::Function=identity, rev::Bool=false, order::Ordering=Forward) where {I<:Integer}
sort!(x, alg, Perm(ord(lt, by, rev, order), v))
end

function Base.iterate(iter::RectangleGrid, state::Int64=1)
return state<=length(iter) ? (ind2x(iter, state), state+1) : nothing
return state <= length(iter) ? (ind2x(iter, state), state + 1) : nothing
end

Base.getindex(grid::RectangleGrid, key::CartesianIndex) = ind2x(grid, LinearIndices(Dims((grid.cut_counts...,)))[key])
Expand Down
Loading
Loading