Skip to content

Commit

Permalink
Mesh Distribution: determinism fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Jun 25, 2024
1 parent 6bb638f commit 0052e21
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 70 deletions.
3 changes: 1 addition & 2 deletions examples/parallel-vtkhdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def main(*, ambient_dim: int) -> None:
parts = [part_id_to_part[i] for i in range(comm.size)]
local_mesh = comm.scatter(parts)
else:
# Reason for type-ignore: presumed faulty type annotation in mpi4py
local_mesh = comm.scatter(None) # type: ignore[arg-type]
local_mesh = comm.scatter(None)

logger.info("[%4d] distributing mesh: finished", comm.rank)

Expand Down
8 changes: 6 additions & 2 deletions meshmode/discretization/connection/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ class DiscretizationConnectionElementGroup:
def __init__(self, batches):
self.batches = batches

def __repr__(self):
return f"{type(self).__name__}({self.batches})"

# }}}


Expand Down Expand Up @@ -488,9 +491,10 @@ def _per_target_group_pick_info(
if batch.from_group_index == source_group_index]

# {{{ find and weed out duplicate dof pick lists
from pytools import unique

dof_pick_lists = list({tuple(batch_dof_pick_lists[bi])
for bi in batch_indices_for_this_source_group})
dof_pick_lists = list(unique(tuple(batch_dof_pick_lists[bi])
for bi in batch_indices_for_this_source_group))
dof_pick_list_to_index = {
p_ind: i for i, p_ind in enumerate(dof_pick_lists)}
# shape: (number of pick lists, nunit_dofs_tgt)
Expand Down
97 changes: 38 additions & 59 deletions meshmode/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Mapping, Sequence, Set, Union, cast
from typing import TYPE_CHECKING, Any, Hashable, List, Mapping, Sequence
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -231,50 +231,17 @@ class MPIBoundaryCommSetupHelper:
def __init__(self,
mpi_comm: "mpi4py.MPI.Intracomm",
actx: ArrayContext,
inter_rank_bdry_info: Union[
# new-timey
Sequence[InterRankBoundaryInfo],
# old-timey, for compatibility
Mapping[int, DirectDiscretizationConnection],
],
inter_rank_bdry_info: Sequence[InterRankBoundaryInfo],
bdry_grp_factory: ElementGroupFactory):
"""
:arg local_bdry_conns: A :class:`dict` mapping remote part to
`local_bdry_conn`, where `local_bdry_conn` is a
:class:`~meshmode.discretization.connection.DirectDiscretizationConnection`
that performs data exchange from the volume to the faces adjacent to
part `i_remote_part`.
:arg bdry_grp_factory: Group factory to use when creating the remote-to-local
boundary connections
"""
self.mpi_comm = mpi_comm
self.array_context = actx
self.i_local_rank = mpi_comm.Get_rank()

# {{{ normalize inter_rank_bdry_info

self._using_old_timey_interface = False

if isinstance(inter_rank_bdry_info, dict):
self._using_old_timey_interface = True
warn("Using the old-timey interface of MPIBoundaryCommSetupHelper. "
"That's deprecated and will stop working in July 2022. "
"Use the currently documented interface instead.",
DeprecationWarning, stacklevel=2)

inter_rank_bdry_info = [
InterRankBoundaryInfo(
local_part_id=self.i_local_rank,
remote_part_id=remote_rank,
remote_rank=remote_rank,
local_boundary_connection=conn
)
for remote_rank, conn in inter_rank_bdry_info.items()]

# }}}

self.inter_rank_bdry_info = cast(
Sequence[InterRankBoundaryInfo], inter_rank_bdry_info)
self.inter_rank_bdry_info = inter_rank_bdry_info

self.bdry_grp_factory = bdry_grp_factory

Expand All @@ -289,9 +256,13 @@ def __enter__(self):
# the pickling ourselves.

# to know when we're done
self.pending_recv_identifiers = {
self.pending_recv_identifiers = [
(irbi.local_part_id, irbi.remote_part_id)
for irbi in self.inter_rank_bdry_info}
for irbi in self.inter_rank_bdry_info]

assert len(self.pending_recv_identifiers) \
== len(self.inter_rank_bdry_info) \
== len(set(self.pending_recv_identifiers))

self.send_reqs = [
self._internal_mpi_comm.isend(
Expand Down Expand Up @@ -327,14 +298,20 @@ def complete_some(self):

status = MPI.Status()

# Wait for any receive
data = [self._internal_mpi_comm.recv(status=status)]
source_ranks = [status.source]

# Complete any other available receives while we're at it
while self._internal_mpi_comm.iprobe():
data.append(self._internal_mpi_comm.recv(status=status))
source_ranks.append(status.source)
# Wait for all receives
nrecvs = len(self.pending_recv_identifiers)
data = [None] * nrecvs
source_ranks = [None] * nrecvs

while nrecvs > 0:
r = self._internal_mpi_comm.recv(status=status)
key = (r[1], r[0])
loc = self.pending_recv_identifiers.index(key)
assert data[loc] is None
assert source_ranks[loc] is None
data[loc] = r
source_ranks[loc] = status.source
nrecvs -= 1

remote_to_local_bdry_conns = {}

Expand All @@ -357,10 +334,7 @@ def complete_some(self):
irbi = part_ids_to_irbi[local_part_id, remote_part_id]
assert i_src_rank == irbi.remote_rank

if self._using_old_timey_interface:
key = remote_part_id
else:
key = (remote_part_id, local_part_id)
key = (remote_part_id, local_part_id)

remote_to_local_bdry_conns[key] = (
make_partition_connection(
Expand All @@ -374,9 +348,9 @@ def complete_some(self):

self.pending_recv_identifiers.remove((local_part_id, remote_part_id))

if not self.pending_recv_identifiers:
MPI.Request.waitall(self.send_reqs)
logger.info("bdry comm rank %d comm end", self.i_local_rank)
assert not self.pending_recv_identifiers
MPI.Request.waitall(self.send_reqs)
logger.info("bdry comm rank %d comm end", self.i_local_rank)

return remote_to_local_bdry_conns

Expand Down Expand Up @@ -432,30 +406,35 @@ def get_partition_by_pymetis(mesh, num_parts, *, connectivity="facial", **kwargs
return np.array(p)


def membership_list_to_map(membership_list):
def membership_list_to_map(
membership_list: np.ndarray[Any, Any]
) -> Mapping[Hashable, np.ndarray]:
"""
Convert a :class:`numpy.ndarray` that maps an index to a key into a
:class:`dict` that maps a key to a set of indices (with each set of indices
stored as a sorted :class:`numpy.ndarray`).
"""
from pytools import unique
return {
entry: np.where(membership_list == entry)[0]
for entry in set(membership_list)}
for entry in unique(list(membership_list))}


# FIXME: Move somewhere else, since it's not strictly limited to distributed?
def get_connected_parts(mesh: Mesh) -> "Set[PartID]":
def get_connected_parts(mesh: Mesh) -> "Sequence[PartID]":
"""For a local mesh part in *mesh*, determine the set of connected parts."""
assert mesh.facial_adjacency_groups is not None

return {
from pytools import unique

return tuple(unique(
grp.part_id
for fagrp_list in mesh.facial_adjacency_groups
for grp in fagrp_list
if isinstance(grp, InterPartAdjacencyGroup)}
if isinstance(grp, InterPartAdjacencyGroup)))


def get_connected_partitions(mesh: Mesh) -> "Set[PartID]":
def get_connected_partitions(mesh: Mesh) -> "Sequence[PartID]":
warn(
"get_connected_partitions is deprecated and will stop working in June 2023. "
"Use get_connected_parts instead.", DeprecationWarning, stacklevel=2)
Expand Down
15 changes: 9 additions & 6 deletions meshmode/mesh/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from dataclasses import dataclass, replace
from functools import reduce
from typing import (
Callable, Dict, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Union)
Callable, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union)

import numpy as np
import numpy.linalg as la
Expand Down Expand Up @@ -184,7 +184,7 @@ def _get_connected_parts(
mesh: Mesh,
part_id_to_part_index: Mapping[PartID, int],
global_elem_to_part_elem: np.ndarray,
self_part_id: PartID) -> Set[PartID]:
self_part_id: PartID) -> Sequence[PartID]:
"""
Find the parts that are connected to the current part.
Expand All @@ -196,10 +196,11 @@ def _get_connected_parts(
:func:`_compute_global_elem_to_part_elem`` for details.
:arg self_part_id: The identifier of the part currently being created.
:returns: A :class:`set` of identifiers of the neighboring parts.
:returns: A sequence of identifiers of the neighboring parts.
"""
self_part_index = part_id_to_part_index[self_part_id]

# This set is not used in a way that will cause nondeterminism.
connected_part_indices = set()

for igrp, facial_adj_list in enumerate(mesh.facial_adjacency_groups):
Expand All @@ -223,10 +224,12 @@ def _get_connected_parts(
elements_are_self & neighbors_are_other]
+ elem_base_j, 0])

return {
result = tuple(
part_id
for part_id, part_index in part_id_to_part_index.items()
if part_index in connected_part_indices}
if part_index in connected_part_indices)
assert len(set(result)) == len(result)
return result


def _create_self_to_self_adjacency_groups(
Expand Down Expand Up @@ -305,7 +308,7 @@ def _create_self_to_other_adjacency_groups(
self_part_id: PartID,
self_mesh_groups: List[MeshElementGroup],
self_mesh_group_elem_base: List[int],
connected_parts: Set[PartID]) -> List[List[InterPartAdjacencyGroup]]:
connected_parts: Sequence[PartID]) -> List[List[InterPartAdjacencyGroup]]:
"""
Create self-to-other adjacency groups for the partitioned mesh.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def main():
"numpy",
"modepy>=2020.2",
"gmsh_interop",
"pytools>=2020.4.1",
"pytools>=2024.1.1",

# 2019.1 is required for the Firedrake CIs, which use an very specific
# version of Loopy.
Expand Down

0 comments on commit 0052e21

Please sign in to comment.