From 0052e2157aed00e0a17d1af14e1789e6d02823d7 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Mon, 24 Jun 2024 16:45:37 -0500
Subject: [PATCH 1/7] Mesh Distribution: determinism fixes

---
 examples/parallel-vtkhdf.py                  |  3 +-
 meshmode/discretization/connection/direct.py |  8 +-
 meshmode/distributed.py                      | 97 ++++++++------------
 meshmode/mesh/processing.py                  | 15 +--
 setup.py                                     |  2 +-
 5 files changed, 55 insertions(+), 70 deletions(-)

diff --git a/examples/parallel-vtkhdf.py b/examples/parallel-vtkhdf.py
index 0d93ac7f..a928176f 100644
--- a/examples/parallel-vtkhdf.py
+++ b/examples/parallel-vtkhdf.py
@@ -56,8 +56,7 @@ def main(*, ambient_dim: int) -> None:
         parts = [part_id_to_part[i] for i in range(comm.size)]
         local_mesh = comm.scatter(parts)
     else:
-        # Reason for type-ignore: presumed faulty type annotation in mpi4py
-        local_mesh = comm.scatter(None)  # type: ignore[arg-type]
+        local_mesh = comm.scatter(None)
 
     logger.info("[%4d] distributing mesh: finished", comm.rank)
 
diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py
index 49ff8c97..d466c0d5 100644
--- a/meshmode/discretization/connection/direct.py
+++ b/meshmode/discretization/connection/direct.py
@@ -249,6 +249,9 @@ class DiscretizationConnectionElementGroup:
     def __init__(self, batches):
         self.batches = batches
 
+    def __repr__(self):
+        return f"{type(self).__name__}({self.batches})"
+
 
 # }}}
 
@@ -488,9 +491,10 @@ def _per_target_group_pick_info(
             if batch.from_group_index == source_group_index]
 
     # {{{ find and weed out duplicate dof pick lists
+    from pytools import unique
 
-    dof_pick_lists = list({tuple(batch_dof_pick_lists[bi])
-        for bi in batch_indices_for_this_source_group})
+    dof_pick_lists = list(unique(tuple(batch_dof_pick_lists[bi])
+        for bi in batch_indices_for_this_source_group))
     dof_pick_list_to_index = {
         p_ind: i for i, p_ind in enumerate(dof_pick_lists)}
     # shape: (number of pick lists, nunit_dofs_tgt)
diff --git a/meshmode/distributed.py b/meshmode/distributed.py
index ee0403ab..bbefd197 100644
--- a/meshmode/distributed.py
+++ b/meshmode/distributed.py
@@ -36,7 +36,7 @@
 """
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Mapping, Sequence, Set, Union, cast
+from typing import TYPE_CHECKING, Any, Hashable, List, Mapping, Sequence
 from warnings import warn
 
 import numpy as np
@@ -231,19 +231,9 @@ class MPIBoundaryCommSetupHelper:
     def __init__(self,
             mpi_comm: "mpi4py.MPI.Intracomm",
             actx: ArrayContext,
-            inter_rank_bdry_info: Union[
-                # new-timey
-                Sequence[InterRankBoundaryInfo],
-                # old-timey, for compatibility
-                Mapping[int, DirectDiscretizationConnection],
-                ],
+            inter_rank_bdry_info: Sequence[InterRankBoundaryInfo],
             bdry_grp_factory: ElementGroupFactory):
         """
-        :arg local_bdry_conns: A :class:`dict` mapping remote part to
-            `local_bdry_conn`, where `local_bdry_conn` is a
-            :class:`~meshmode.discretization.connection.DirectDiscretizationConnection`
-            that performs data exchange from the volume to the faces adjacent to
-            part `i_remote_part`.
         :arg bdry_grp_factory: Group factory to use when creating the remote-to-local
             boundary connections
         """
@@ -251,30 +241,7 @@ def __init__(self,
         self.array_context = actx
         self.i_local_rank = mpi_comm.Get_rank()
 
-        # {{{ normalize inter_rank_bdry_info
-
-        self._using_old_timey_interface = False
-
-        if isinstance(inter_rank_bdry_info, dict):
-            self._using_old_timey_interface = True
-            warn("Using the old-timey interface of MPIBoundaryCommSetupHelper. "
-                 "That's deprecated and will stop working in July 2022. "
-                 "Use the currently documented interface instead.",
-                 DeprecationWarning, stacklevel=2)
-
-            inter_rank_bdry_info = [
-                InterRankBoundaryInfo(
-                    local_part_id=self.i_local_rank,
-                    remote_part_id=remote_rank,
-                    remote_rank=remote_rank,
-                    local_boundary_connection=conn
-                )
-                for remote_rank, conn in inter_rank_bdry_info.items()]
-
-        # }}}
-
-        self.inter_rank_bdry_info = cast(
-            Sequence[InterRankBoundaryInfo], inter_rank_bdry_info)
+        self.inter_rank_bdry_info = inter_rank_bdry_info
 
         self.bdry_grp_factory = bdry_grp_factory
 
@@ -289,9 +256,13 @@ def __enter__(self):
         # the pickling ourselves.
 
         # to know when we're done
-        self.pending_recv_identifiers = {
+        self.pending_recv_identifiers = [
             (irbi.local_part_id, irbi.remote_part_id)
-            for irbi in self.inter_rank_bdry_info}
+            for irbi in self.inter_rank_bdry_info]
+
+        assert len(self.pending_recv_identifiers) \
+                == len(self.inter_rank_bdry_info) \
+                == len(set(self.pending_recv_identifiers))
 
         self.send_reqs = [
             self._internal_mpi_comm.isend(
@@ -327,14 +298,20 @@ def complete_some(self):
 
         status = MPI.Status()
 
-        # Wait for any receive
-        data = [self._internal_mpi_comm.recv(status=status)]
-        source_ranks = [status.source]
-
-        # Complete any other available receives while we're at it
-        while self._internal_mpi_comm.iprobe():
-            data.append(self._internal_mpi_comm.recv(status=status))
-            source_ranks.append(status.source)
+        # Wait for all receives
+        nrecvs = len(self.pending_recv_identifiers)
+        data = [None] * nrecvs
+        source_ranks = [None] * nrecvs
+
+        while nrecvs > 0:
+            r = self._internal_mpi_comm.recv(status=status)
+            key = (r[1], r[0])
+            loc = self.pending_recv_identifiers.index(key)
+            assert data[loc] is None
+            assert source_ranks[loc] is None
+            data[loc] = r
+            source_ranks[loc] = status.source
+            nrecvs -= 1
 
         remote_to_local_bdry_conns = {}
 
@@ -357,10 +334,7 @@ def complete_some(self):
             irbi = part_ids_to_irbi[local_part_id, remote_part_id]
             assert i_src_rank == irbi.remote_rank
 
-            if self._using_old_timey_interface:
-                key = remote_part_id
-            else:
-                key = (remote_part_id, local_part_id)
+            key = (remote_part_id, local_part_id)
 
             remote_to_local_bdry_conns[key] = (
                 make_partition_connection(
@@ -374,9 +348,9 @@ def complete_some(self):
 
             self.pending_recv_identifiers.remove((local_part_id, remote_part_id))
 
-        if not self.pending_recv_identifiers:
-            MPI.Request.waitall(self.send_reqs)
-            logger.info("bdry comm rank %d comm end", self.i_local_rank)
+        assert not self.pending_recv_identifiers
+        MPI.Request.waitall(self.send_reqs)
+        logger.info("bdry comm rank %d comm end", self.i_local_rank)
 
         return remote_to_local_bdry_conns
 
@@ -432,30 +406,35 @@ def get_partition_by_pymetis(mesh, num_parts, *, connectivity="facial", **kwargs
     return np.array(p)
 
 
-def membership_list_to_map(membership_list):
+def membership_list_to_map(
+        membership_list: np.ndarray[Any, Any]
+        ) -> Mapping[Hashable, np.ndarray]:
     """
     Convert a :class:`numpy.ndarray` that maps an index to a key into a
     :class:`dict` that maps a key to a set of indices (with each set of indices
     stored as a sorted :class:`numpy.ndarray`).
     """
+    from pytools import unique
     return {
         entry: np.where(membership_list == entry)[0]
-        for entry in set(membership_list)}
+        for entry in unique(list(membership_list))}
 
 
 # FIXME: Move somewhere else, since it's not strictly limited to distributed?
-def get_connected_parts(mesh: Mesh) -> "Set[PartID]":
+def get_connected_parts(mesh: Mesh) -> "Sequence[PartID]":
     """For a local mesh part in *mesh*, determine the set of connected parts."""
     assert mesh.facial_adjacency_groups is not None
 
-    return {
+    from pytools import unique
+
+    return tuple(unique(
         grp.part_id
         for fagrp_list in mesh.facial_adjacency_groups
         for grp in fagrp_list
-        if isinstance(grp, InterPartAdjacencyGroup)}
+        if isinstance(grp, InterPartAdjacencyGroup)))
 
 
-def get_connected_partitions(mesh: Mesh) -> "Set[PartID]":
+def get_connected_partitions(mesh: Mesh) -> "Sequence[PartID]":
     warn(
         "get_connected_partitions is deprecated and will stop working in June 2023. "
         "Use get_connected_parts instead.", DeprecationWarning, stacklevel=2)
diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 1b2a64dd..c28877da 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -25,7 +25,7 @@
 from dataclasses import dataclass, replace
 from functools import reduce
 from typing import (
-    Callable, Dict, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Union)
+    Callable, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union)
 
 import numpy as np
 import numpy.linalg as la
@@ -184,7 +184,7 @@ def _get_connected_parts(
         mesh: Mesh,
         part_id_to_part_index: Mapping[PartID, int],
         global_elem_to_part_elem: np.ndarray,
-        self_part_id: PartID) -> Set[PartID]:
+        self_part_id: PartID) -> Sequence[PartID]:
     """
     Find the parts that are connected to the current part.
 
@@ -196,10 +196,11 @@ def _get_connected_parts(
         :func:`_compute_global_elem_to_part_elem` for details.
     :arg self_part_id: The identifier of the part currently being created.
 
-    :returns: A :class:`set` of identifiers of the neighboring parts.
+    :returns: A sequence of identifiers of the neighboring parts.
     """
     self_part_index = part_id_to_part_index[self_part_id]
 
+    # This set is not used in a way that will cause nondeterminism.
     connected_part_indices = set()
 
     for igrp, facial_adj_list in enumerate(mesh.facial_adjacency_groups):
@@ -223,10 +224,12 @@ def _get_connected_parts(
                     elements_are_self & neighbors_are_other] + elem_base_j,
                 0])
 
-    return {
+    result = tuple(
         part_id
         for part_id, part_index in part_id_to_part_index.items()
-        if part_index in connected_part_indices}
+        if part_index in connected_part_indices)
+    assert len(set(result)) == len(result)
+    return result
 
 
 def _create_self_to_self_adjacency_groups(
@@ -305,7 +308,7 @@ def _create_self_to_other_adjacency_groups(
         self_part_id: PartID,
         self_mesh_groups: List[MeshElementGroup],
         self_mesh_group_elem_base: List[int],
-        connected_parts: Set[PartID]) -> List[List[InterPartAdjacencyGroup]]:
+        connected_parts: Sequence[PartID]) -> List[List[InterPartAdjacencyGroup]]:
     """
     Create self-to-other adjacency groups for the partitioned mesh.
 
diff --git a/setup.py b/setup.py
index c8a1d7e8..a78d0afb 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,7 @@ def main():
             "numpy",
             "modepy>=2020.2",
             "gmsh_interop",
-            "pytools>=2020.4.1",
+            "pytools>=2024.1.1",
 
             # 2019.1 is required for the Firedrake CIs, which use a very specific
             # version of Loopy.
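Aside (not part of the patch): the recurring change in PATCH 1 is swapping unordered set comprehensions for pytools.unique, which deduplicates while preserving first-seen order. A minimal sketch of the difference, assuming pytools >= 2024.1.1 (the version pinned above) and purely illustrative data:

    from pytools import unique

    batches = [("c", 2), ("a", 0), ("c", 2), ("b", 1)]

    # Set iteration follows hash order, which can vary between runs
    # (e.g. with PYTHONHASHSEED) and hence between MPI ranks:
    maybe_shuffled = list({item for item in batches})

    # pytools.unique keeps first-seen order, so the result is the same
    # on every run and every rank:
    stable = list(unique(batches))
    assert stable == [("c", 2), ("a", 0), ("b", 1)]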
From 4f6a81431a1a105480d6151ae6ff2007d6c29532 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Tue, 25 Jun 2024 17:46:38 -0500
Subject: [PATCH 2/7] misc fixes

---
 meshmode/distributed.py | 4 +++-
 test/test_partition.py  | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/meshmode/distributed.py b/meshmode/distributed.py
index bbefd197..f92f4b56 100644
--- a/meshmode/distributed.py
+++ b/meshmode/distributed.py
@@ -415,9 +415,11 @@ def membership_list_to_map(
     stored as a sorted :class:`numpy.ndarray`).
     """
     from pytools import unique
+
+    # FIXME: not clear why the sorted() call is necessary here
     return {
         entry: np.where(membership_list == entry)[0]
-        for entry in unique(list(membership_list))}
+        for entry in sorted(unique(membership_list))}
 
 
 # FIXME: Move somewhere else, since it's not strictly limited to distributed?
diff --git a/test/test_partition.py b/test/test_partition.py
index 9a66661d..f96e4ff4 100644
--- a/test/test_partition.py
+++ b/test/test_partition.py
@@ -455,7 +455,7 @@ def _test_connected_parts(mpi_comm, connected_parts):
     for i_remote_part in range(num_parts):
         if all_connected_masks[i_remote_part][mpi_comm.rank]:
             parts_connected_to_me.add(i_remote_part)
-    assert parts_connected_to_me == connected_parts
+    assert parts_connected_to_me == set(connected_parts)
 
 
 # TODO

From 4135433fdbddb30482b95c2d9cb75d8dfede1dc3 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Tue, 25 Jun 2024 17:51:24 -0500
Subject: [PATCH 3/7] restore old-timey interface (needed for test)

---
 meshmode/distributed.py | 39 +++++++++++++++++++++++++++++++++++----
 1 file changed, 35 insertions(+), 4 deletions(-)

diff --git a/meshmode/distributed.py b/meshmode/distributed.py
index f92f4b56..70cc5262 100644
--- a/meshmode/distributed.py
+++ b/meshmode/distributed.py
@@ -36,7 +36,7 @@
 """
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Hashable, List, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Hashable, List, Mapping, Sequence, Union, cast
 from warnings import warn
 
 import numpy as np
@@ -231,7 +231,12 @@ class MPIBoundaryCommSetupHelper:
     def __init__(self,
             mpi_comm: "mpi4py.MPI.Intracomm",
             actx: ArrayContext,
-            inter_rank_bdry_info: Sequence[InterRankBoundaryInfo],
+            inter_rank_bdry_info: Union[
+                # new-timey
+                Sequence[InterRankBoundaryInfo],
+                # old-timey, for compatibility
+                Mapping[int, DirectDiscretizationConnection],
+                ],
             bdry_grp_factory: ElementGroupFactory):
         """
         :arg bdry_grp_factory: Group factory to use when creating the remote-to-local
@@ -241,7 +246,30 @@ def __init__(self,
         self.array_context = actx
         self.i_local_rank = mpi_comm.Get_rank()
 
-        self.inter_rank_bdry_info = inter_rank_bdry_info
+        # {{{ normalize inter_rank_bdry_info
+
+        self._using_old_timey_interface = False
+
+        if isinstance(inter_rank_bdry_info, dict):
+            self._using_old_timey_interface = True
+            warn("Using the old-timey interface of MPIBoundaryCommSetupHelper. "
+                 "That's deprecated and will stop working in July 2022. "
+                 "Use the currently documented interface instead.",
+                 DeprecationWarning, stacklevel=2)
+
+            inter_rank_bdry_info = [
+                InterRankBoundaryInfo(
+                    local_part_id=self.i_local_rank,
+                    remote_part_id=remote_rank,
+                    remote_rank=remote_rank,
+                    local_boundary_connection=conn
+                )
+                for remote_rank, conn in inter_rank_bdry_info.items()]
+
+        # }}}
+
+        self.inter_rank_bdry_info = cast(
+            Sequence[InterRankBoundaryInfo], inter_rank_bdry_info)
 
         self.bdry_grp_factory = bdry_grp_factory
 
@@ -334,7 +362,10 @@ def complete_some(self):
             irbi = part_ids_to_irbi[local_part_id, remote_part_id]
             assert i_src_rank == irbi.remote_rank
 
-            key = (remote_part_id, local_part_id)
+            if self._using_old_timey_interface:
+                key = remote_part_id
+            else:
+                key = (remote_part_id, local_part_id)
 
             remote_to_local_bdry_conns[key] = (
                 make_partition_connection(

From 3825b1c6911248b5cbf2a78804de1641282bc098 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Thu, 27 Jun 2024 16:59:41 -0500
Subject: [PATCH 4/7] use dict for pending_recv_identifiers

---
 meshmode/distributed.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/meshmode/distributed.py b/meshmode/distributed.py
index 70cc5262..39f4c6f8 100644
--- a/meshmode/distributed.py
+++ b/meshmode/distributed.py
@@ -284,13 +284,12 @@ def __enter__(self):
         # the pickling ourselves.
 
         # to know when we're done
-        self.pending_recv_identifiers = [
-            (irbi.local_part_id, irbi.remote_part_id)
-            for irbi in self.inter_rank_bdry_info]
+        self.pending_recv_identifiers = {
+            (irbi.local_part_id, irbi.remote_part_id): i
+            for i, irbi in enumerate(self.inter_rank_bdry_info)}
 
         assert len(self.pending_recv_identifiers) \
-                == len(self.inter_rank_bdry_info) \
-                == len(set(self.pending_recv_identifiers))
+                == len(self.inter_rank_bdry_info)
 
         self.send_reqs = [
             self._internal_mpi_comm.isend(
@@ -334,7 +333,7 @@ def complete_some(self):
         while nrecvs > 0:
             r = self._internal_mpi_comm.recv(status=status)
             key = (r[1], r[0])
-            loc = self.pending_recv_identifiers.index(key)
+            loc = self.pending_recv_identifiers[key]
             assert data[loc] is None
             assert source_ranks[loc] is None
             data[loc] = r
@@ -377,7 +376,7 @@ def complete_some(self):
                     group_factory=self.bdry_grp_factory),
                 remote_group_infos=remote_group_infos))
 
-            self.pending_recv_identifiers.remove((local_part_id, remote_part_id))
+            del self.pending_recv_identifiers[(local_part_id, remote_part_id)]
 
         assert not self.pending_recv_identifiers
         MPI.Request.waitall(self.send_reqs)

From 89b773fd0f165de44f98cfba89dda8c199f71d58 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Thu, 27 Jun 2024 17:14:37 -0500
Subject: [PATCH 5/7] add some simple determinism tests

---
 test/test_partition.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/test/test_partition.py b/test/test_partition.py
index f96e4ff4..6ad5f4a9 100644
--- a/test/test_partition.py
+++ b/test/test_partition.py
@@ -387,6 +387,8 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
         part_id_to_part = partition_mesh(mesh,
                        membership_list_to_map(
                            np.random.randint(mpi_comm.size, size=mesh.nelements)))
+
+        assert list(part_id_to_part.keys()) == list(range(mpi_comm.size))
         parts = [part_id_to_part[i] for i in range(mpi_comm.size)]
 
         local_mesh = mpi_comm.scatter(parts)
@@ -424,6 +426,11 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
             conns = bdry_setup_helper.complete_some()
             if not conns:
                 break
+
+            expected_keys = list(range(mpi_comm.size))
+            expected_keys.remove(mpi_comm.rank)
+            assert list(conns.keys()) == expected_keys
+
             for i_remote_part, conn in conns.items():
                check_connection(actx, conn)
                remote_to_local_bdry_conns[i_remote_part] = conn

From c2567520a9cbd30c817dd49be4d07098db1f3a89 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Fri, 28 Jun 2024 09:53:02 -0500
Subject: [PATCH 6/7] add comment regarding deterministic order

---
 meshmode/distributed.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/meshmode/distributed.py b/meshmode/distributed.py
index 39f4c6f8..18402bf4 100644
--- a/meshmode/distributed.py
+++ b/meshmode/distributed.py
@@ -326,6 +326,8 @@ def complete_some(self):
         status = MPI.Status()
 
         # Wait for all receives
+        # Note: This is inefficient, but ensures a deterministic order of
+        # boundary setup.
         nrecvs = len(self.pending_recv_identifiers)
         data = [None] * nrecvs
         source_ranks = [None] * nrecvs

From dd8cdab1b35041cb842ee7b9a22fa083af0626d5 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Fri, 28 Jun 2024 18:26:52 -0500
Subject: [PATCH 7/7] reset example type ignore

---
 examples/parallel-vtkhdf.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/parallel-vtkhdf.py b/examples/parallel-vtkhdf.py
index a928176f..0d93ac7f 100644
--- a/examples/parallel-vtkhdf.py
+++ b/examples/parallel-vtkhdf.py
@@ -56,7 +56,8 @@ def main(*, ambient_dim: int) -> None:
         parts = [part_id_to_part[i] for i in range(comm.size)]
         local_mesh = comm.scatter(parts)
     else:
-        local_mesh = comm.scatter(None)
+        # Reason for type-ignore: presumed faulty type annotation in mpi4py
+        local_mesh = comm.scatter(None)  # type: ignore[arg-type]
 
     logger.info("[%4d] distributing mesh: finished", comm.rank)
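Aside (not part of the series): the key ordering asserted by the new tests in PATCH 5 comes from membership_list_to_map as modified in PATCH 2, which yields its keys via sorted(unique(...)). A small usage sketch with illustrative data:

    import numpy as np

    from meshmode.distributed import membership_list_to_map

    membership = np.array([2, 0, 2, 1, 0])
    part_to_elements = membership_list_to_map(membership)

    # Keys come back sorted, and each value is a sorted element-index
    # array, e.g. {0: array([1, 4]), 1: array([3]), 2: array([0, 2])}
    assert list(part_to_elements.keys()) == [0, 1, 2]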