Skip to content

Commit

Permalink
Minimize amount of copies/tmps in aarch64_mpi.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
abussy committed Sep 20, 2024
1 parent 1547839 commit 5118a22
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/workarounds/aarch64_mpi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,29 @@ end
# Vec3{T} must be cast to Vector{T} before MPI reduction
function mpi_sum!(arr::Vector{Vec3{T}}, comm::MPI.Comm) where{T}
n = length(arr)
new_arr = Vector{T}([])
new_arr = zeros(T, 3n)
for i in 1:n
append!(new_arr, arr[i][1:3])
new_arr[3(i-1)+1:3(i-1)+3] = @view arr[i][1:3]
end
mpi_sum!(new_arr, comm)
for i in 1:n
arr[i] = Vec3{T}(new_arr[3(i-1)+1:3(i-1)+3])
arr[i] = Vec3{T}(@view new_arr[3(i-1)+1:3(i-1)+3])
end
arr

Check warning on line 31 in src/workarounds/aarch64_mpi.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/aarch64_mpi.jl#L21-L31

Added lines #L21 - L31 were not covered by tests
end

# ForwardDiff.Dual{T, U, V} and arrays of it must be cast to Vector{U} as well
# utility function to cast a Dual type to an array containing a value and the partial diffs
function dual_array(dual::ForwardDiff.Dual{T, U, V}) where{T, U, V}
dual_array = [ForwardDiff.value(dual)]
append!(dual_array, collect(ForwardDiff.partials(dual)))
dual_array = zeros(U, ForwardDiff.npartials(dual)+1)
dual_array[1] = ForwardDiff.value(dual)
dual_array[2:end] = @view dual.partials[1:end]
dual_array

Check warning on line 40 in src/workarounds/aarch64_mpi.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/aarch64_mpi.jl#L36-L40

Added lines #L36 - L40 were not covered by tests
end

# utility function that casts back an array to a Dual type, based on a template Dual
function new_dual(dual_array, template)
DualType = typeof(template)
PartialsType = typeof(ForwardDiff.partials(template))
DualType(dual_array[1], PartialsType(Tuple(dual_array[2:end])))
function new_dual(dual_array, template::ForwardDiff.Dual{T, U, V}) where{T, U, V}
ForwardDiff.Dual{T}(dual_array[1], Tuple(@view dual_array[2:end]))

Check warning on line 45 in src/workarounds/aarch64_mpi.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/aarch64_mpi.jl#L44-L45

Added lines #L44 - L45 were not covered by tests
end

# MPI reductions of single ForwardDiff.Dual types
Expand Down Expand Up @@ -83,8 +82,9 @@ function mpi_sum!(dual::Array{ForwardDiff.Dual{T, U, V}, N}, comm::MPI.Comm) whe
mpi_sum!(array, comm)
offset = 0
for i in 1:length(dual)
dual[i] = new_dual(array[offset+1:offset+lengths[i]], dual[i])
view = @view array[offset+1:offset+lengths[i]]
dual[i] = new_dual(view, dual[i])
offset += lengths[i]
end
dual

Check warning on line 89 in src/workarounds/aarch64_mpi.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/aarch64_mpi.jl#L74-L89

Added lines #L74 - L89 were not covered by tests
end
end

0 comments on commit 5118a22

Please sign in to comment.