Skip to content

Commit

Permalink
Use actx module (#34)
Browse files Browse the repository at this point in the history
* Use actx module

* Use mirgecom@actx-init

* Unmispel profile

* extricate extraneous excessive extra overspecified interface argument

* Use production version of actx-init changeset.

* Go back to production
  • Loading branch information
MTCam authored Jul 6, 2023
1 parent 7382b6f commit 8f6f0c5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 32 deletions.
12 changes: 6 additions & 6 deletions driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@
if lazy:
raise ValueError("Can't use lazy and profiling together.")

from grudge.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)
from mirgecom.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(
lazy=lazy, distributed=True, profiling=args.profile)

restart_filename = None
if args.restart_file:
Expand All @@ -78,9 +79,8 @@
print(f"Running {sys.argv[0]}\n")

from y3prediction.prediction import main
main(restart_filename=restart_filename, target_filename=target_filename,
main(actx_class, restart_filename=restart_filename,
target_filename=target_filename,
user_input_file=input_file, log_path=log_path,
use_profiling=args.profile, use_logmgr=True,
use_overintegration=args.overintegration or args.esdg,
actx_class=actx_class, casename=casename,
lazy=lazy, use_esdg=args.esdg)
casename=casename, use_esdg=args.esdg)
35 changes: 9 additions & 26 deletions y3prediction/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import logging
import sys
import numpy as np
import pyopencl as cl
import numpy.linalg as la # noqa
import pyopencl.array as cla # noqa
import math
Expand Down Expand Up @@ -203,15 +202,10 @@ class _MuDiffFluidCommTag:


@mpi_entry_point
def main(ctx_factory=cl.create_some_context,
def main(actx_class,
restart_filename=None, target_filename=None,
use_profiling=False, use_logmgr=True, user_input_file=None,
use_overintegration=False, actx_class=None, casename=None,
lazy=False, log_path="log_data", use_esdg=False):

if actx_class is None:
raise RuntimeError("Array context class missing.")

user_input_file=None, use_overintegration=False,
casename=None, log_path="log_data", use_esdg=False):
# control log messages
logger = logging.getLogger(__name__)
logger.propagate = False
Expand All @@ -231,8 +225,6 @@ def main(ctx_factory=cl.create_some_context,
h2.addFilter(f2)
logger.addHandler(h2)

cl_ctx = ctx_factory()

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
Expand All @@ -254,23 +246,14 @@ def main(ctx_factory=cl.create_some_context,
os.makedirs(log_dir)
comm.Barrier()

logmgr = initialize_logmgr(use_logmgr,
logmgr = initialize_logmgr(True,
filename=logname, mode="wu", mpi_comm=comm)

if use_profiling:
queue = cl.CommandQueue(cl_ctx,
properties=cl.command_queue_properties.PROFILING_ENABLE)
else:
queue = cl.CommandQueue(cl_ctx)

# main array context for the simulation
from mirgecom.simutil import get_reasonable_memory_pool
alloc = get_reasonable_memory_pool(cl_ctx, queue)

if lazy:
actx = actx_class(comm, queue, mpi_base_tag=12000, allocator=alloc)
else:
actx = actx_class(comm, queue, allocator=alloc, force_device_scalars=True)
from mirgecom.array_context import initialize_actx, actx_class_is_profiling
actx = initialize_actx(actx_class, comm)
queue = getattr(actx, "queue", None)
use_profiling = actx_class_is_profiling(actx_class)
alloc = getattr(actx, "allocator", None)

# set up driver parameters
from mirgecom.simutil import configurate
Expand Down

0 comments on commit 8f6f0c5

Please sign in to comment.