Mesh Distribution: determinism fixes #416

Merged 7 commits on Jul 2, 2024
8 changes: 6 additions & 2 deletions meshmode/discretization/connection/direct.py
@@ -249,6 +249,9 @@ class DiscretizationConnectionElementGroup:
     def __init__(self, batches):
         self.batches = batches
 
+    def __repr__(self):
+        return f"{type(self).__name__}({self.batches})"
+
 # }}}


@@ -488,9 +491,10 @@ def _per_target_group_pick_info(
             if batch.from_group_index == source_group_index]
 
     # {{{ find and weed out duplicate dof pick lists
+    from pytools import unique
 
-    dof_pick_lists = list({tuple(batch_dof_pick_lists[bi])
-        for bi in batch_indices_for_this_source_group})
+    dof_pick_lists = list(unique(tuple(batch_dof_pick_lists[bi])
+        for bi in batch_indices_for_this_source_group))
     dof_pick_list_to_index = {
         p_ind: i for i, p_ind in enumerate(dof_pick_lists)}
     # shape: (number of pick lists, nunit_dofs_tgt)
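The change above swaps a `set` comprehension for `pytools.unique`, which keeps first-occurrence order. A minimal sketch of why that matters, using `dict.fromkeys` as a stand-in for `pytools.unique` (an assumption; the real helper comes from pytools >= 2024.1.1):

```python
# Stand-in for pytools.unique: order-preserving deduplication.
# Dicts preserve insertion order (Python 3.7+), so the first
# occurrence of each item fixes its position.
def unique(seq):
    return list(dict.fromkeys(seq))

batch_dof_pick_lists = [(3, 1), (0, 2), (3, 1), (5, 4), (0, 2)]

# A set() would deduplicate too, but its iteration order is arbitrary and
# can vary between runs; unique() always yields first-occurrence order.
dof_pick_lists = unique(batch_dof_pick_lists)
assert dof_pick_lists == [(3, 1), (0, 2), (5, 4)]

# Downstream index assignment is therefore deterministic:
dof_pick_list_to_index = {p: i for i, p in enumerate(dof_pick_lists)}
assert dof_pick_list_to_index[(5, 4)] == 2
```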
65 changes: 39 additions & 26 deletions meshmode/distributed.py
@@ -36,7 +36,7 @@
 """
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Mapping, Sequence, Set, Union, cast
+from typing import TYPE_CHECKING, Any, Hashable, List, Mapping, Sequence, Union, cast
 from warnings import warn
 
 import numpy as np
@@ -239,11 +239,6 @@ def __init__(self,
                 ],
             bdry_grp_factory: ElementGroupFactory):
         """
-        :arg local_bdry_conns: A :class:`dict` mapping remote part to
-            `local_bdry_conn`, where `local_bdry_conn` is a
-            :class:`~meshmode.discretization.connection.DirectDiscretizationConnection`
-            that performs data exchange from the volume to the faces adjacent to
-            part `i_remote_part`.
         :arg bdry_grp_factory: Group factory to use when creating the remote-to-local
             boundary connections
         """
@@ -290,8 +285,11 @@ def __enter__(self):
 
         # to know when we're done
         self.pending_recv_identifiers = {
-            (irbi.local_part_id, irbi.remote_part_id)
-            for irbi in self.inter_rank_bdry_info}
+            (irbi.local_part_id, irbi.remote_part_id): i
+            for i, irbi in enumerate(self.inter_rank_bdry_info)}
+
+        assert len(self.pending_recv_identifiers) \
+                == len(self.inter_rank_bdry_info)
 
         self.send_reqs = [
             self._internal_mpi_comm.isend(
@@ -327,14 +325,22 @@ def complete_some(self):
 
         status = MPI.Status()
 
-        # Wait for any receive
-        data = [self._internal_mpi_comm.recv(status=status)]
-        source_ranks = [status.source]
-
-        # Complete any other available receives while we're at it
-        while self._internal_mpi_comm.iprobe():
-            data.append(self._internal_mpi_comm.recv(status=status))
-            source_ranks.append(status.source)
+        # Wait for all receives
+        # Note: This is inefficient, but ensures a deterministic order of
+        # boundary setup.
+        nrecvs = len(self.pending_recv_identifiers)
+        data = [None] * nrecvs
+        source_ranks = [None] * nrecvs
+
+        while nrecvs > 0:
+            r = self._internal_mpi_comm.recv(status=status)
+            key = (r[1], r[0])
+            loc = self.pending_recv_identifiers[key]
+            assert data[loc] is None
+            assert source_ranks[loc] is None
+            data[loc] = r
+            source_ranks[loc] = status.source
+            nrecvs -= 1
 
         remote_to_local_bdry_conns = {}

@@ -372,11 +378,11 @@ def complete_some(self):
                     group_factory=self.bdry_grp_factory),
                 remote_group_infos=remote_group_infos))
 
-            self.pending_recv_identifiers.remove((local_part_id, remote_part_id))
+            del self.pending_recv_identifiers[(local_part_id, remote_part_id)]
 
-        if not self.pending_recv_identifiers:
-            MPI.Request.waitall(self.send_reqs)
-            logger.info("bdry comm rank %d comm end", self.i_local_rank)
+        assert not self.pending_recv_identifiers
+        MPI.Request.waitall(self.send_reqs)
+        logger.info("bdry comm rank %d comm end", self.i_local_rank)
 
         return remote_to_local_bdry_conns
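The rewritten `complete_some` trades a bit of communication overlap for determinism: instead of completing receives in whatever order messages happen to arrive, it waits for all of them and files each into a slot fixed by `pending_recv_identifiers`. A self-contained sketch of that slot-filling pattern, with hypothetical part IDs standing in for the MPI traffic:

```python
import random

# Analogue of pending_recv_identifiers: (local_part_id, remote_part_id) -> slot.
pending = {("a", "b"): 0, ("a", "c"): 1, ("a", "d"): 2}

# Each message leads with the sender's (local part, remote part), so the
# receiver recovers its own key by swapping the first two entries
# (mirroring `key = (r[1], r[0])` above).
messages = [("b", "a", 10), ("c", "a", 20), ("d", "a", 30)]
random.shuffle(messages)  # arrival order is nondeterministic

data = [None] * len(pending)
for r in messages:
    loc = pending[(r[1], r[0])]
    assert data[loc] is None  # each slot is filled exactly once
    data[loc] = r[2]

assert data == [10, 20, 30]  # identical on every run, regardless of arrival
```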

@@ -432,30 +438,37 @@ def get_partition_by_pymetis(mesh, num_parts, *, connectivity="facial", **kwargs
     return np.array(p)
 
 
-def membership_list_to_map(membership_list):
+def membership_list_to_map(
+        membership_list: np.ndarray[Any, Any]
+        ) -> Mapping[Hashable, np.ndarray]:
     """
     Convert a :class:`numpy.ndarray` that maps an index to a key into a
     :class:`dict` that maps a key to a set of indices (with each set of indices
     stored as a sorted :class:`numpy.ndarray`).
     """
+    from pytools import unique
+
+    # FIXME: not clear why the sorted() call is necessary here
     return {
         entry: np.where(membership_list == entry)[0]
-        for entry in set(membership_list)}
+        for entry in sorted(unique(membership_list))}
 
 
 # FIXME: Move somewhere else, since it's not strictly limited to distributed?
-def get_connected_parts(mesh: Mesh) -> "Set[PartID]":
+def get_connected_parts(mesh: Mesh) -> "Sequence[PartID]":
     """For a local mesh part in *mesh*, determine the set of connected parts."""
     assert mesh.facial_adjacency_groups is not None
 
-    return {
+    from pytools import unique
+
+    return tuple(unique(
         grp.part_id
         for fagrp_list in mesh.facial_adjacency_groups
         for grp in fagrp_list
-        if isinstance(grp, InterPartAdjacencyGroup)}
+        if isinstance(grp, InterPartAdjacencyGroup)))
 
 
-def get_connected_partitions(mesh: Mesh) -> "Set[PartID]":
+def get_connected_partitions(mesh: Mesh) -> "Sequence[PartID]":
     warn(
         "get_connected_partitions is deprecated and will stop working in June 2023. "
         "Use get_connected_parts instead.", DeprecationWarning, stacklevel=2)
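With the `sorted(unique(...))` change, `membership_list_to_map` returns its keys in sorted order, so iteration over the result is reproducible across runs and ranks. A hedged usage sketch (assumes meshmode with this PR installed):

```python
import numpy as np
from meshmode.distributed import membership_list_to_map

membership = np.array([2, 0, 1, 0, 2, 2])
part_to_elements = membership_list_to_map(membership)

# Keys come out sorted, so iteration order is stable across runs.
assert list(part_to_elements.keys()) == [0, 1, 2]
# np.where returns indices in ascending order.
assert part_to_elements[2].tolist() == [0, 4, 5]
```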
15 changes: 9 additions & 6 deletions meshmode/mesh/processing.py
@@ -25,7 +25,7 @@
 from dataclasses import dataclass, replace
 from functools import reduce
 from typing import (
-    Callable, Dict, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Union)
+    Callable, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union)
 
 import numpy as np
 import numpy.linalg as la
@@ -184,7 +184,7 @@ def _get_connected_parts(
         mesh: Mesh,
         part_id_to_part_index: Mapping[PartID, int],
         global_elem_to_part_elem: np.ndarray,
-        self_part_id: PartID) -> Set[PartID]:
+        self_part_id: PartID) -> Sequence[PartID]:
     """
     Find the parts that are connected to the current part.
 
@@ -196,10 +196,11 @@ def _get_connected_parts(
         :func:`_compute_global_elem_to_part_elem`` for details.
     :arg self_part_id: The identifier of the part currently being created.
 
-    :returns: A :class:`set` of identifiers of the neighboring parts.
+    :returns: A sequence of identifiers of the neighboring parts.
     """
     self_part_index = part_id_to_part_index[self_part_id]
 
+    # This set is not used in a way that will cause nondeterminism.
     connected_part_indices = set()
 
     for igrp, facial_adj_list in enumerate(mesh.facial_adjacency_groups):
@@ -223,10 +224,12 @@ def _get_connected_parts(
                 elements_are_self & neighbors_are_other]
             + elem_base_j, 0])
 
-    return {
+    result = tuple(
         part_id
         for part_id, part_index in part_id_to_part_index.items()
-        if part_index in connected_part_indices}
+        if part_index in connected_part_indices)
+    assert len(set(result)) == len(result)
+    return result
 
 
 def _create_self_to_self_adjacency_groups(
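The comment added above flags a subtlety: `connected_part_indices` is still a `set`, but it is only used for membership tests; the output order comes from iterating `part_id_to_part_index`, whose dict insertion order is fixed. A small sketch of the pattern, with hypothetical part IDs:

```python
# Insertion-ordered dict fixes the output order.
part_id_to_part_index = {"left": 0, "right": 1, "top": 2}

# Unordered set, but only queried with `in`, never iterated.
connected_part_indices = {2, 0}

result = tuple(
    part_id
    for part_id, part_index in part_id_to_part_index.items()
    if part_index in connected_part_indices)

assert result == ("left", "top")  # order dictated by the dict, not the set
```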
@@ -305,7 +308,7 @@ def _create_self_to_other_adjacency_groups(
         self_part_id: PartID,
         self_mesh_groups: List[MeshElementGroup],
         self_mesh_group_elem_base: List[int],
-        connected_parts: Set[PartID]) -> List[List[InterPartAdjacencyGroup]]:
+        connected_parts: Sequence[PartID]) -> List[List[InterPartAdjacencyGroup]]:
     """
     Create self-to-other adjacency groups for the partitioned mesh.
 
2 changes: 1 addition & 1 deletion setup.py
@@ -41,7 +41,7 @@ def main():
             "numpy",
             "modepy>=2020.2",
             "gmsh_interop",
-            "pytools>=2020.4.1",
+            "pytools>=2024.1.1",
 
             # 2019.1 is required for the Firedrake CIs, which use an very specific
             # version of Loopy.
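This bump presumably exists to pick up `pytools.unique`, the order-preserving deduplication helper that the changes in direct.py and distributed.py import; older pytools releases may not ship it.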
9 changes: 8 additions & 1 deletion test/test_partition.py
@@ -387,6 +387,8 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
     part_id_to_part = partition_mesh(mesh,
             membership_list_to_map(
                 np.random.randint(mpi_comm.size, size=mesh.nelements)))
+
+    assert list(part_id_to_part.keys()) == list(range(mpi_comm.size))
     parts = [part_id_to_part[i] for i in range(mpi_comm.size)]
 
     local_mesh = mpi_comm.scatter(parts)
@@ -424,6 +426,11 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
         conns = bdry_setup_helper.complete_some()
         if not conns:
             break
+
+        expected_keys = list(range(mpi_comm.size))
+        expected_keys.remove(mpi_comm.rank)
+        assert list(conns.keys()) == expected_keys
+
         for i_remote_part, conn in conns.items():
             check_connection(actx, conn)
             remote_to_local_bdry_conns[i_remote_part] = conn
@@ -455,7 +462,7 @@ def _test_connected_parts(mpi_comm, connected_parts):
     for i_remote_part in range(num_parts):
         if all_connected_masks[i_remote_part][mpi_comm.rank]:
             parts_connected_to_me.add(i_remote_part)
-    assert parts_connected_to_me == connected_parts
+    assert parts_connected_to_me == set(connected_parts)
 
 
 # TODO