diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py
index 49ff8c97..d466c0d5 100644
--- a/meshmode/discretization/connection/direct.py
+++ b/meshmode/discretization/connection/direct.py
@@ -249,6 +249,9 @@ class DiscretizationConnectionElementGroup:
     def __init__(self, batches):
         self.batches = batches

+    def __repr__(self):
+        return f"{type(self).__name__}({self.batches})"
+
 # }}}


@@ -488,9 +491,10 @@ def _per_target_group_pick_info(
             if batch.from_group_index == source_group_index]

     # {{{ find and weed out duplicate dof pick lists
+    from pytools import unique

-    dof_pick_lists = list({tuple(batch_dof_pick_lists[bi])
-            for bi in batch_indices_for_this_source_group})
+    dof_pick_lists = list(unique(tuple(batch_dof_pick_lists[bi])
+            for bi in batch_indices_for_this_source_group))
     dof_pick_list_to_index = {
             p_ind: i for i, p_ind in enumerate(dof_pick_lists)}
     # shape: (number of pick lists, nunit_dofs_tgt)
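Illustration (not part of the patch): the direct.py hunk above swaps a set comprehension for pytools.unique so that the numbering of the deduplicated DOF pick lists follows first-seen order rather than hash-based set iteration order. A minimal sketch of the difference, assuming only that pytools.unique (hence the pytools>=2024.1.1 bump in setup.py below) yields unique items in the order they first appear; the pick-list data is made up:

    from pytools import unique

    # made-up per-batch pick lists, containing one duplicate
    batch_dof_pick_lists = [(0, 2, 1), (3, 5, 4), (0, 2, 1)]

    # set-based deduplication: element order follows hash iteration order,
    # so the index later assigned to each pick list is essentially arbitrary
    via_set = list({pl for pl in batch_dof_pick_lists})

    # unique() keeps first-seen order, so the numbering is reproducible
    via_unique = list(unique(batch_dof_pick_lists))
    assert via_unique == [(0, 2, 1), (3, 5, 4)]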
diff --git a/meshmode/distributed.py b/meshmode/distributed.py
index ee0403ab..18402bf4 100644
--- a/meshmode/distributed.py
+++ b/meshmode/distributed.py
@@ -36,7 +36,7 @@
 """

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Mapping, Sequence, Set, Union, cast
+from typing import TYPE_CHECKING, Any, Hashable, List, Mapping, Sequence, Union, cast
 from warnings import warn

 import numpy as np
@@ -239,11 +239,6 @@ def __init__(self,
             ],
             bdry_grp_factory: ElementGroupFactory):
         """
-        :arg local_bdry_conns: A :class:`dict` mapping remote part to
-            `local_bdry_conn`, where `local_bdry_conn` is a
-            :class:`~meshmode.discretization.connection.DirectDiscretizationConnection`
-            that performs data exchange from the volume to the faces adjacent to
-            part `i_remote_part`.
         :arg bdry_grp_factory: Group factory to use when creating the
             remote-to-local boundary connections
         """
@@ -290,8 +285,11 @@ def __enter__(self):

         # to know when we're done
         self.pending_recv_identifiers = {
-            (irbi.local_part_id, irbi.remote_part_id)
-            for irbi in self.inter_rank_bdry_info}
+            (irbi.local_part_id, irbi.remote_part_id): i
+            for i, irbi in enumerate(self.inter_rank_bdry_info)}
+
+        assert len(self.pending_recv_identifiers) \
+            == len(self.inter_rank_bdry_info)

         self.send_reqs = [
             self._internal_mpi_comm.isend(
@@ -327,14 +325,22 @@ def complete_some(self):

         status = MPI.Status()

-        # Wait for any receive
-        data = [self._internal_mpi_comm.recv(status=status)]
-        source_ranks = [status.source]
-
-        # Complete any other available receives while we're at it
-        while self._internal_mpi_comm.iprobe():
-            data.append(self._internal_mpi_comm.recv(status=status))
-            source_ranks.append(status.source)
+        # Wait for all receives
+        # Note: This is inefficient, but ensures a deterministic order of
+        # boundary setup.
+        nrecvs = len(self.pending_recv_identifiers)
+        data = [None] * nrecvs
+        source_ranks = [None] * nrecvs
+
+        while nrecvs > 0:
+            r = self._internal_mpi_comm.recv(status=status)
+            key = (r[1], r[0])
+            loc = self.pending_recv_identifiers[key]
+            assert data[loc] is None
+            assert source_ranks[loc] is None
+            data[loc] = r
+            source_ranks[loc] = status.source
+            nrecvs -= 1

         remote_to_local_bdry_conns = {}

@@ -372,11 +378,11 @@
                         group_factory=self.bdry_grp_factory),
                     remote_group_infos=remote_group_infos))

-            self.pending_recv_identifiers.remove((local_part_id, remote_part_id))
+            del self.pending_recv_identifiers[(local_part_id, remote_part_id)]

-        if not self.pending_recv_identifiers:
-            MPI.Request.waitall(self.send_reqs)
-            logger.info("bdry comm rank %d comm end", self.i_local_rank)
+        assert not self.pending_recv_identifiers
+        MPI.Request.waitall(self.send_reqs)
+        logger.info("bdry comm rank %d comm end", self.i_local_rank)

         return remote_to_local_bdry_conns

@@ -432,30 +438,37 @@ def get_partition_by_pymetis(mesh, num_parts, *, connectivity="facial", **kwargs
     return np.array(p)


-def membership_list_to_map(membership_list):
+def membership_list_to_map(
+        membership_list: np.ndarray[Any, Any]
+        ) -> Mapping[Hashable, np.ndarray]:
     """
     Convert a :class:`numpy.ndarray` that maps an index to a key into a
     :class:`dict` that maps a key to a set of indices (with each set of
     indices stored as a sorted :class:`numpy.ndarray`).
     """
+    from pytools import unique
+
+    # FIXME: not clear why the sorted() call is necessary here
     return {
         entry: np.where(membership_list == entry)[0]
-        for entry in set(membership_list)}
+        for entry in sorted(unique(membership_list))}


 # FIXME: Move somewhere else, since it's not strictly limited to distributed?
-def get_connected_parts(mesh: Mesh) -> "Set[PartID]":
+def get_connected_parts(mesh: Mesh) -> "Sequence[PartID]":
     """For a local mesh part in *mesh*, determine the set of connected parts."""
     assert mesh.facial_adjacency_groups is not None

-    return {
+    from pytools import unique
+
+    return tuple(unique(
         grp.part_id
         for fagrp_list in mesh.facial_adjacency_groups
         for grp in fagrp_list
-        if isinstance(grp, InterPartAdjacencyGroup)}
+        if isinstance(grp, InterPartAdjacencyGroup)))


-def get_connected_partitions(mesh: Mesh) -> "Set[PartID]":
+def get_connected_partitions(mesh: Mesh) -> "Sequence[PartID]":
     warn(
         "get_connected_partitions is deprecated and will stop working in June 2023. "
         "Use get_connected_parts instead.", DeprecationWarning, stacklevel=2)
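Illustration (not part of the patch): membership_list_to_map and get_connected_parts above stop iterating over sets of part identifiers, whose iteration order can vary from run to run when the identifiers are strings (hash randomization). A small sketch of the key ordering the patched membership_list_to_map produces; the part names are invented:

    import numpy as np
    from pytools import unique

    # made-up membership list: element i belongs to the named part
    membership_list = np.array(["wall", "fluid", "wall", "fluid", "outlet"])

    # sorted(unique(...)) gives a reproducible key order, independent of
    # PYTHONHASHSEED; iterating set(membership_list) would not guarantee that
    part_to_elements = {
        entry: np.where(membership_list == entry)[0]
        for entry in sorted(unique(membership_list))}

    assert list(part_to_elements.keys()) == ["fluid", "outlet", "wall"]
    assert part_to_elements["wall"].tolist() == [0, 2]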
diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 1b2a64dd..c28877da 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -25,7 +25,7 @@
 from dataclasses import dataclass, replace
 from functools import reduce
 from typing import (
-    Callable, Dict, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Union)
+    Callable, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union)

 import numpy as np
 import numpy.linalg as la
@@ -184,7 +184,7 @@ def _get_connected_parts(
         mesh: Mesh,
         part_id_to_part_index: Mapping[PartID, int],
         global_elem_to_part_elem: np.ndarray,
-        self_part_id: PartID) -> Set[PartID]:
+        self_part_id: PartID) -> Sequence[PartID]:
     """
     Find the parts that are connected to the current part.

@@ -196,10 +196,11 @@ def _get_connected_parts(
         :func:`_compute_global_elem_to_part_elem`` for details.
     :arg self_part_id: The identifier of the part currently being created.

-    :returns: A :class:`set` of identifiers of the neighboring parts.
+    :returns: A sequence of identifiers of the neighboring parts.
     """
     self_part_index = part_id_to_part_index[self_part_id]

+    # This set is not used in a way that will cause nondeterminism.
     connected_part_indices = set()

     for igrp, facial_adj_list in enumerate(mesh.facial_adjacency_groups):
@@ -223,10 +224,12 @@ def _get_connected_parts(
                     elements_are_self & neighbors_are_other] + elem_base_j, 0])

-    return {
+    result = tuple(
         part_id
         for part_id, part_index in part_id_to_part_index.items()
-        if part_index in connected_part_indices}
+        if part_index in connected_part_indices)
+    assert len(set(result)) == len(result)
+    return result


 def _create_self_to_self_adjacency_groups(
@@ -305,7 +308,7 @@ def _create_self_to_other_adjacency_groups(
         self_part_id: PartID,
         self_mesh_groups: List[MeshElementGroup],
         self_mesh_group_elem_base: List[int],
-        connected_parts: Set[PartID]) -> List[List[InterPartAdjacencyGroup]]:
+        connected_parts: Sequence[PartID]) -> List[List[InterPartAdjacencyGroup]]:
     """
     Create self-to-other adjacency groups for the partitioned mesh.

diff --git a/setup.py b/setup.py
index c8a1d7e8..a78d0afb 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,7 @@ def main():
             "numpy",
             "modepy>=2020.2",
             "gmsh_interop",
-            "pytools>=2020.4.1",
+            "pytools>=2024.1.1",

             # 2019.1 is required for the Firedrake CIs, which use an very specific
             # version of Loopy.
diff --git a/test/test_partition.py b/test/test_partition.py
index 9a66661d..6ad5f4a9 100644
--- a/test/test_partition.py
+++ b/test/test_partition.py
@@ -387,6 +387,8 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
    part_id_to_part = partition_mesh(mesh, membership_list_to_map(
        np.random.randint(mpi_comm.size, size=mesh.nelements)))
+
+    assert list(part_id_to_part.keys()) == list(range(mpi_comm.size))
    parts = [part_id_to_part[i] for i in range(mpi_comm.size)]

    local_mesh = mpi_comm.scatter(parts)

@@ -424,6 +426,11 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
            conns = bdry_setup_helper.complete_some()
            if not conns:
                break
+
+            expected_keys = list(range(mpi_comm.size))
+            expected_keys.remove(mpi_comm.rank)
+            assert list(conns.keys()) == expected_keys
+
            for i_remote_part, conn in conns.items():
                check_connection(actx, conn)
                remote_to_local_bdry_conns[i_remote_part] = conn
@@ -455,7 +462,7 @@ def _test_connected_parts(mpi_comm, connected_parts):
    for i_remote_part in range(num_parts):
        if all_connected_masks[i_remote_part][mpi_comm.rank]:
            parts_connected_to_me.add(i_remote_part)
-    assert parts_connected_to_me == connected_parts
+    assert parts_connected_to_me == set(connected_parts)


 # TODO
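Illustration (not part of the patch): the complete_some change above replaces "process whichever message arrives first" with a fixed slot per expected message, keyed by (local_part_id, remote_part_id), so the assembled boundary connections come out in the same order regardless of arrival order. A simplified, MPI-free sketch of that bookkeeping; the part identifiers and message layout are invented for the example:

    # each expected message gets a fixed slot, as in pending_recv_identifiers
    pending_recv_identifiers = {("vol", 1): 0, ("vol", 2): 1, ("vol", 3): 2}

    # messages may arrive in any order; each starts with the sender's
    # (local part id, remote part id), mirroring the isend side
    arrived = [
        (3, "vol", "data from part 3"),
        (1, "vol", "data from part 1"),
        (2, "vol", "data from part 2"),
    ]

    data = [None] * len(pending_recv_identifiers)
    for msg in arrived:
        # seen from the receiver, the key is (local_part_id, remote_part_id),
        # hence the swap of the first two entries
        key = (msg[1], msg[0])
        data[pending_recv_identifiers[key]] = msg

    # slot order, not arrival order, determines the result order
    assert [msg[0] for msg in data] == [1, 2, 3]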