MPI workarounds for aarch64 #999

Open · wants to merge 3 commits into master
8 changes: 0 additions & 8 deletions src/PlaneWaveBasis.jl
@@ -267,14 +267,6 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Real, fft_size::Tuple{Int, Int, I
n_kpt = length(kcoords_global)
n_procs = mpi_nprocs(comm_kpts)

# Custom reduction operators for MPI are currently not working on aarch64, so
# fallbacks are defined in common/mpi.jl. For them to work, there cannot be more
# than 1 MPI process.
if Base.Sys.ARCH == :aarch64 && n_procs > 1
error("MPI not supported on aarch64 " *
"(see https://github.com/JuliaParallel/MPI.jl/issues/404)")
end

if n_procs > n_kpt
# XXX Supporting more processors than kpoints would require
# fixing a bunch of "reducing over empty collections" errors
30 changes: 14 additions & 16 deletions src/common/mpi.jl
@@ -7,21 +7,19 @@
mpi_nprocs(comm=MPI.COMM_WORLD) = (MPI.Init(); MPI.Comm_size(comm))
mpi_master(comm=MPI.COMM_WORLD) = (MPI.Init(); MPI.Comm_rank(comm) == 0)

mpi_sum( arr, comm::MPI.Comm) = MPI.Allreduce( arr, +, comm)
mpi_sum!( arr, comm::MPI.Comm) = MPI.Allreduce!(arr, +, comm)
mpi_min( arr, comm::MPI.Comm) = MPI.Allreduce( arr, min, comm)
mpi_min!( arr, comm::MPI.Comm) = MPI.Allreduce!(arr, min, comm)

mpi_max( arr, comm::MPI.Comm) = MPI.Allreduce( arr, max, comm)
mpi_max!( arr, comm::MPI.Comm) = MPI.Allreduce!(arr, max, comm)

mpi_mean( arr, comm::MPI.Comm) = mpi_sum(arr, comm) ./ mpi_nprocs(comm)
mpi_mean!(arr, comm::MPI.Comm) = (mpi_sum!(arr, comm); arr ./= mpi_nprocs(comm))

@static if Base.Sys.ARCH == :aarch64
# Custom reduction operators are not supported on aarch64 (see
# https://github.com/JuliaParallel/MPI.jl/issues/404), so we define fallback no-op
# mpi_* functions to get things working while waiting for an upstream solution.
for fun in (:mpi_sum, :mpi_sum!, :mpi_min, :mpi_min!, :mpi_max, :mpi_max!,
:mpi_mean, :mpi_mean!)
@eval $fun(arr, ::MPI.Comm) = arr
end
else
mpi_sum( arr, comm::MPI.Comm) = MPI.Allreduce( arr, +, comm)
mpi_sum!( arr, comm::MPI.Comm) = MPI.Allreduce!(arr, +, comm)
mpi_min( arr, comm::MPI.Comm) = MPI.Allreduce( arr, min, comm)
mpi_min!( arr, comm::MPI.Comm) = MPI.Allreduce!(arr, min, comm)
mpi_max( arr, comm::MPI.Comm) = MPI.Allreduce( arr, max, comm)
mpi_max!( arr, comm::MPI.Comm) = MPI.Allreduce!(arr, max, comm)
mpi_mean( arr, comm::MPI.Comm) = mpi_sum(arr, comm) ./ mpi_nprocs(comm)
mpi_mean!(arr, comm::MPI.Comm) = (mpi_sum!(arr, comm); arr ./= mpi_nprocs(comm))
end
# https://github.com/JuliaParallel/MPI.jl/issues/404). We define
# temporary workarounds in order to be able to run MPI on aarch64
# anyways. These should be removed as soon as there is an upstream fix
include("../workarounds/aarch64_mpi.jl")
end
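
As an illustration (not part of this PR) of the distinction these workarounds rely on: reductions over plain bits-type buffers map onto MPI's built-in operations and therefore work on aarch64, whereas element types without an MPI datatype equivalent would need a user-defined reduction operator, which is exactly what fails (JuliaParallel/MPI.jl#404). A minimal sketch:

using MPI
MPI.Init()
comm = MPI.COMM_WORLD
x = [1.0, 2.0, 3.0]
MPI.Allreduce!(x, +, comm)  # + on a Float64 buffer maps to the built-in MPI.SUM, so no custom operator is created
# A buffer of a custom struct (e.g. ForwardDiff.Dual) would instead require a custom MPI.Op,
# which src/workarounds/aarch64_mpi.jl avoids by flattening to plain arrays first.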
90 changes: 90 additions & 0 deletions src/workarounds/aarch64_mpi.jl
@@ -0,0 +1,90 @@
"""
MPI reduction operations with custom types (i.e. anything that does not have an MPI datatype equivalent)
are not available on aarch64. These are temporary workarounds, where variables with custom types
are broken down into standard types before communication and recast to the initial types afterwards.
This file was created by fixing all MPI errors encountered when running the tests on an ARM machine:
all sensible MPI reduction routines are implemented for each custom type causing an error.
"""

# Julia's Bool type has no direct equivalent MPI datatype => need integer conversion
function mpi_min(bool::Bool, comm::MPI.Comm)
int = Int(bool)
Bool(mpi_min(int, comm))

end

function mpi_max(bool::Bool, comm::MPI.Comm)
int = Int(bool)
Bool(mpi_max(int, comm))

end
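
A hypothetical usage sketch (not from the PR) of the Bool fallbacks above, assuming comm = MPI.COMM_WORLD: with the 0/1 encoding, min acts as a logical AND across ranks and max as a logical OR.

comm = MPI.COMM_WORLD
converged_locally = true                                  # hypothetical per-rank flag
converged_everywhere = mpi_min(converged_locally, comm)   # logical AND: true only if every rank converged
any_rank_converged   = mpi_max(converged_locally, comm)   # logical OR: true if at least one rank converged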

# Vec3{T} must be cast to Vector{T} before MPI reduction
function mpi_sum!(arr::Vector{Vec3{T}}, comm::MPI.Comm) where{T}
n = length(arr)
Member: Can this not be solved by a reinterpret(reshape, ...)? Then there is no copy.

Contributor (author): I think this one is tricky to do without copies, because SVectors are immutable. So even if I use something like new_arr = reshape(reinterpret(T, arr), 3, :), I still could not call mpi_sum!(). And assuming I call mpi_sum() instead, I also end up with a copy.

new_arr = zeros(T, 3n)
for i in 1:n
new_arr[3(i-1)+1:3(i-1)+3] = @view arr[i][1:3]
end
mpi_sum!(new_arr, comm)
for i in 1:n
arr[i] = Vec3{T}(@view new_arr[3(i-1)+1:3(i-1)+3])
end
arr

end
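
For reference, a minimal sketch (not part of the PR) of the reinterpret(reshape, ...) variant discussed in the review comments above, assuming an out-of-place mpi_sum is acceptable; mpi_sum_vec3 is a hypothetical name, and the copy on the way back remains unavoidable because Vec3 (an SVector) is immutable.

function mpi_sum_vec3(arr::Vector{Vec3{T}}, comm::MPI.Comm) where {T}
    flat = reinterpret(reshape, T, arr)               # 3×n view of the data, no copy
    summed = MPI.Allreduce(collect(flat), +, comm)    # built-in MPI.SUM on a plain T matrix
    [Vec3{T}(@view summed[:, i]) for i in eachindex(arr)]  # copy back: SVector is immutable
end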

# ForwardDiff.Dual{T, U, V} and arrays of it must be cast to Vector{U} as well
# utility function to cast a Dual type to an array containing a value and the partial diffs
function dual_array(dual::ForwardDiff.Dual{T, U, V}) where{T, U, V}
dual_array = zeros(U, ForwardDiff.npartials(dual)+1)
dual_array[1] = ForwardDiff.value(dual)
dual_array[2:end] = @view dual.partials[1:end]
dual_array

end

# utility function that casts back an array to a Dual type, based on a template Dual
function new_dual(dual_array, template::ForwardDiff.Dual{T, U, V}) where{T, U, V}
ForwardDiff.Dual{T}(dual_array[1], Tuple(@view dual_array[2:end]))

end
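
A hypothetical round-trip check (not part of the PR) of the two helpers above, using a Dual with the default Nothing tag:

using ForwardDiff
d  = ForwardDiff.Dual(1.0, 2.0, 3.0)              # value 1.0, partials (2.0, 3.0)
v  = dual_array(d)                                # [1.0, 2.0, 3.0]
d2 = new_dual(v, d)
ForwardDiff.value(d2), ForwardDiff.partials(d2)   # value and partials survive the round trip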

# MPI reductions of single ForwardDiff.Dual types
function mpi_sum(dual::ForwardDiff.Dual{T, U, V}, comm::MPI.Comm) where{T, U, V}
arr = dual_array(dual)
mpi_sum!(arr, comm)
new_dual(arr, dual)

end

function mpi_min(dual::ForwardDiff.Dual{T, U, V}, comm::MPI.Comm) where{T, U, V}
arr = dual_array(dual)
mpi_min!(arr, comm)
new_dual(arr, dual)

end

function mpi_max(dual::ForwardDiff.Dual{T, U, V}, comm::MPI.Comm) where{T, U, V}
arr = dual_array(dual)
mpi_max!(arr, comm)
new_dual(arr, dual)

end

function mpi_mean(dual::ForwardDiff.Dual{T, U, V}, comm::MPI.Comm) where{T, U, V}
arr = dual_array(dual)
mpi_mean!(arr, comm)
new_dual(arr, dual)

end

# MPI reductions of arrays of ForwardDiff.Dual types
function mpi_sum!(dual::Array{ForwardDiff.Dual{T, U, V}, N}, comm::MPI.Comm) where{T, U, V, N}
array = Vector{U}([])
lengths = []
for i in 1:length(dual)
tmp = dual_array(dual[i])
append!(array, tmp)
append!(lengths, length(tmp))
end
mpi_sum!(array, comm)
offset = 0
for i in 1:length(dual)
view = @view array[offset+1:offset+lengths[i]]
dual[i] = new_dual(view, dual[i])
offset += lengths[i]
end
dual

end
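
A hypothetical usage sketch (not from the PR): after this flattened in-place reduction, every rank holds the elementwise sum of both the values and the partials.

using MPI, ForwardDiff
MPI.Init()
comm  = MPI.COMM_WORLD
duals = [ForwardDiff.Dual(Float64(i), 1.0) for i in 1:4]
mpi_sum!(duals, comm)   # on n ranks: values become n*i, each partial becomes n*1.0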