Skip to content

Commit

Permalink
Merge branch 'main' into production-pilot
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Oct 13, 2024
2 parents e7749e7 + 2096efc commit f17c7bd
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 14 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ jobs:
# have a sufficient number of cores.
mpiexec -np 2 --oversubscribe python wave/wave-op-mpi.py --lazy
mpiexec -np 2 --oversubscribe python wave/wave-op-mpi.py --numpy
docs:
name: Documentation
runs-on: ubuntu-latest
Expand Down
22 changes: 15 additions & 7 deletions examples/wave/wave-op-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,19 +175,24 @@ def bump(actx, dcoll, t=0):


def main(ctx_factory, dim=2, order=3,
visualize=False, lazy=False, use_quad=False, use_nonaffine_mesh=False,
no_diagnostics=False):
cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)

visualize=False, lazy=False, numpy=False, use_quad=False,
use_nonaffine_mesh=False, no_diagnostics=False):
comm = MPI.COMM_WORLD
num_parts = comm.size

from grudge.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)
if lazy:
actx_class = get_reasonable_array_context_class(lazy=lazy,
distributed=True, numpy=numpy)

if numpy:
actx = actx_class(comm)
elif lazy:
cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)
actx = actx_class(comm, queue, mpi_base_tag=15000)
else:
cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)
actx = actx_class(comm, queue,
allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)),
force_device_scalars=True)
Expand Down Expand Up @@ -323,6 +328,8 @@ def rhs(t, w):
parser.add_argument("--visualize", action="store_true")
parser.add_argument("--lazy", action="store_true",
help="switch to a lazy computation mode")
parser.add_argument("--numpy", action="store_true",
help="switch to numpy-based array context")
parser.add_argument("--quad", action="store_true")
parser.add_argument("--nonaffine", action="store_true")
parser.add_argument("--no-diagnostics", action="store_true")
Expand All @@ -335,6 +342,7 @@ def rhs(t, w):
order=args.order,
visualize=args.visualize,
lazy=args.lazy,
numpy=args.numpy,
use_quad=args.quad,
use_nonaffine_mesh=args.nonaffine,
no_diagnostics=args.no_diagnostics)
Expand Down
47 changes: 43 additions & 4 deletions grudge/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,11 @@
_HAVE_FUSION_ACTX = False


from arraycontext import ArrayContext
from arraycontext import ArrayContext, NumpyArrayContext
from arraycontext.container import ArrayContainer
from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller
from arraycontext.pytest import (
_PytestNumpyArrayContextFactory,
_PytestPyOpenCLArrayContextFactoryWithClass,
_PytestPytatoPyOpenCLArrayContextFactory,
register_pytest_array_context_factory,
Expand Down Expand Up @@ -537,6 +538,26 @@ def clone(self):
# }}}


# {{{ distributed + numpy

class MPINumpyArrayContext(NumpyArrayContext, MPIBasedArrayContext):
"""An array context for using distributed computation with :mod:`numpy`
eager evaluation.
.. autofunction:: __init__
"""

def __init__(self, mpi_communicator) -> None:
super().__init__()

self.mpi_communicator = mpi_communicator

def clone(self):
return type(self)(self.mpi_communicator)

# }}}


# {{{ distributed + pytato array context subclasses

class MPIBasePytatoPyOpenCLArrayContext(
Expand Down Expand Up @@ -604,10 +625,19 @@ def __call__(self):
return self.actx_class(queue, allocator=alloc)


class PytestNumpyArrayContextFactory(_PytestNumpyArrayContextFactory):
actx_class = NumpyArrayContext

def __call__(self):
return self.actx_class()


register_pytest_array_context_factory("grudge.pyopencl",
PytestPyOpenCLArrayContextFactory)
register_pytest_array_context_factory("grudge.pytato-pyopencl",
PytestPytatoPyOpenCLArrayContextFactory)
register_pytest_array_context_factory("grudge.numpy",
PytestNumpyArrayContextFactory)

# }}}

Expand Down Expand Up @@ -639,13 +669,22 @@ def _get_single_grid_pytato_actx_class(distributed: bool) -> Type[ArrayContext]:

def get_reasonable_array_context_class(
lazy: bool = True, distributed: bool = True,
fusion: Optional[bool] = None,
fusion: Optional[bool] = None, numpy: bool = False,
) -> Type[ArrayContext]:
"""Returns a reasonable :class:`PyOpenCLArrayContext` currently
supported given the constraints of *lazy* and *distributed*."""
"""Returns a reasonable :class:`~arraycontext.ArrayContext` currently
supported given the constraints of *lazy*, *distributed*, and *numpy*."""
if fusion is None:
fusion = lazy

if numpy:
assert not (lazy or fusion)
if distributed:
actx_class = MPINumpyArrayContext
else:
actx_class = NumpyArrayContext

return actx_class

if lazy:
if fusion:
if not _HAVE_FUSION_ACTX:
Expand Down
2 changes: 1 addition & 1 deletion test/mesh_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,4 @@ def __init__(self, dim):

def get_mesh(self, resolution, mesh_order=4):
return mgen.generate_warped_rect_mesh(
dim=self.dim, order=4, nelements_side=6)
dim=self.dim, order=mesh_order, nelements_side=resolution)
4 changes: 3 additions & 1 deletion test/test_dt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@
from arraycontext import pytest_generate_tests_for_array_contexts

from grudge.array_context import (
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)


pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory])
PytestPytatoPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory])

from grudge import make_discretization_collection
#
Expand Down
4 changes: 3 additions & 1 deletion test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from meshmode.dof_array import flat_norm

from grudge.array_context import (
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
)
Expand All @@ -42,7 +43,8 @@
logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory])
PytestPytatoPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory])


# {{{ inverse metric
Expand Down

0 comments on commit f17c7bd

Please sign in to comment.