From d8b6b2306422eeb115f2066f3022d2661d84333f Mon Sep 17 00:00:00 2001 From: Tulio Date: Thu, 21 Sep 2023 19:32:03 -0500 Subject: [PATCH 1/2] Rebase and clean history --- doc/operators/fluid_wall_coupled.rst | 8 +- examples/README.md | 16 +- examples/mult_coupled_vols.geo | 77 ++ examples/multiple-coupled-volumes.py | 1184 +++++++++++++++++ examples/multiple-volumes.py | 421 ------ mirgecom/materials/carbon_fiber.py | 65 + mirgecom/multiphysics/__init__.py | 1 + .../multiphysics_coupled_fluid_wall.py | 1046 +++++++++++++++ test/test_multiphysics.py | 596 ++++++++- test/test_wallmodel.py | 340 +++++ 10 files changed, 3310 insertions(+), 444 deletions(-) create mode 100644 examples/mult_coupled_vols.geo create mode 100644 examples/multiple-coupled-volumes.py delete mode 100644 examples/multiple-volumes.py create mode 100644 mirgecom/multiphysics/multiphysics_coupled_fluid_wall.py create mode 100644 test/test_wallmodel.py diff --git a/doc/operators/fluid_wall_coupled.rst b/doc/operators/fluid_wall_coupled.rst index e3bcced1e..5f21b0da2 100644 --- a/doc/operators/fluid_wall_coupled.rst +++ b/doc/operators/fluid_wall_coupled.rst @@ -5,10 +5,16 @@ Coupled Fluid-Wall Operators 2) :class:`~mirgecom.multiphysics.phenolics_coupled_fluid_wall`. +3) :class:`~mirgecom.multiphysics.multiphysics_coupled_fluid_wall`. + Heat conduction coupling ^^^^^^^^^^^^^^^^^^^^^^^^ .. automodule:: mirgecom.multiphysics.thermally_coupled_fluid_wall +Phenolics coupling +^^^^^^^^^^^^^^^^^^ +.. automodule:: mirgecom.multiphysics.phenolics_coupled_fluid_wall + Porous flow coupling ^^^^^^^^^^^^^^^^^^^^ -.. automodule:: mirgecom.multiphysics.phenolics_coupled_fluid_wall +.. automodule:: mirgecom.multiphysics.multiphysics_coupled_fluid_wall diff --git a/examples/README.md b/examples/README.md index 0df93f6cb..28f5a96f2 100644 --- a/examples/README.md +++ b/examples/README.md @@ -4,17 +4,19 @@ This directory has a collection of examples to demonstrate and test *MIRGE-Com* capabilities. All of the example exercise some unique feature of *MIRGE-Com*. The examples and the unique features they exercise are as follows: -- `autoignition.py`: Chemistry verification case with Pyrometheus -- `heat-source.py`: Diffusion operator +- `scalar-advdiff.py`: Scalar advection-diffusion verification case - `lump.py`: Lump advection, advection verification case -- `mixture.py`: Mixture EOS with Pyrometheus - `scalar-lump.py`: Scalar component lump advection verification case -- `pulse.py`: Acoustic pulse in a box, wall boundary test case -- `sod.py`: Sod's shock case: Fluid test case with strong shock -- `vortex.py`: Isentropic vortex advection: outflow boundaries, verification +- `pulse.py`: Acoustic pulse in a box, outflow boundaries test case +- `vortex.py`: Isentropic vortex advection: prescribed boundaries, verification - `hotplate.py`: Isothermal BC verification (prescribed exact soln) +- `sod.py`: Sod's shock tube case: Fluid test case with strong shock - `doublemach.py`: AV test case - `poiseuille.py`: Poiseuille flow verification case - `poiseuille-multispecies.py`: Poiseuille flow with passive scalars -- `scalar-advdiff.py`: Scalar advection-diffusion verification case +- `autoignition.py`: Chemistry verification case with Pyrometheus +- `mixture.py`: Mixture EOS with Pyrometheus - `combozzle.py`: Prediction-relevant testing, kitchen sink, many options +- `heat-source.py`: Diffusion operator +- `thermally-coupled`: Fluid-solid interaction for heat transfer +- `multiple-coupled-volumes`: Fluid-porous solid interaction diff --git a/examples/mult_coupled_vols.geo b/examples/mult_coupled_vols.geo new file mode 100644 index 000000000..1f4cfdacb --- /dev/null +++ b/examples/mult_coupled_vols.geo @@ -0,0 +1,77 @@ +Point( 1) = {-0.50,-0.05,0.0}; +Point( 2) = {-0.50, 0.05,0.0}; +Point( 3) = {-0.20,-0.05,0.0}; +Point( 4) = {-0.20, 0.05,0.0}; +Point( 5) = {-0.10,-0.05,0.0}; +Point( 6) = {-0.10, 0.05,0.0}; +Point( 7) = { 0.00,-0.05,0.0}; +Point( 8) = { 0.00, 0.05,0.0}; +Point( 9) = {+0.10,-0.05,0.0}; +Point(10) = {+0.10, 0.05,0.0}; +Point(11) = {+0.20,-0.05,0.0}; +Point(12) = {+0.20, 0.05,0.0}; +Point(13) = {+0.50,-0.05,0.0}; +Point(14) = {+0.50, 0.05,0.0}; + +Line( 1) = {1,2}; + +Line( 2) = {2,4}; +Line( 3) = {4,6}; +Line( 4) = {6,8}; +Line( 5) = {8,10}; +Line( 6) = {10,12}; +Line( 7) = {12,14}; +Line( 8) = {14,13}; +Line( 9) = {13,11}; +Line(10) = {11,9}; +Line(11) = {9,7}; +Line(12) = {7,5}; +Line(13) = {5,3}; +Line(14) = {3,1}; + +Line(15) = {4,3}; +Line(16) = {6,5}; +Line(18) = {10,9}; +Line(19) = {12,11}; + +Transfinite Line {1} = 6 Using Progression 1.0; +Transfinite Line {2} = 16 Using Progression 1.0; +Transfinite Line {3} = 6 Using Progression 1.0; +Transfinite Line {4} = 6 Using Progression 1.0; +Transfinite Line {5} = 6 Using Progression 1.0; +Transfinite Line {6} = 6 Using Progression 1.0; +Transfinite Line {7} = 16 Using Progression 1.0; +Transfinite Line {8} = 6 Using Progression 1.0; +Transfinite Line {9} = 16 Using Progression 1.0; +Transfinite Line {10} = 6 Using Progression 1.0; +Transfinite Line {11} = 6 Using Progression 1.0; +Transfinite Line {12} = 6 Using Progression 1.0; +Transfinite Line {13} = 6 Using Progression 1.0; +Transfinite Line {14} = 16 Using Progression 1.0; +Transfinite Line {15} = 6 Using Progression 1.0; +Transfinite Line {16} = 6 Using Progression 1.0; +Transfinite Line {17} = 6 Using Progression 1.0; +Transfinite Line {18} = 6 Using Progression 1.0; +Transfinite Line {19} = 6 Using Progression 1.0; + +Line Loop(11) = { 1,2,15,14}; +Line Loop(12) = {-15,3,16,13}; +Line Loop(13) = {-16,4,5,18,11,12}; +Line Loop(14) = {-18,6,19,10}; +Line Loop(15) = {-19,7,8,9}; + +Plane Surface(11) = {11}; +Plane Surface(12) = {12}; +Plane Surface(13) = {13}; +Plane Surface(14) = {14}; +Plane Surface(15) = {15}; + +Physical Surface("Fluid") = {11,15}; +Physical Surface("Sample") = {12,14}; +Physical Surface("Holder") = {13}; + +Physical Curve("Fluid Hot") = {1}; +Physical Curve("Fluid Cold") = {8}; +Physical Curve("Fluid Sides") = {2,7,9,14}; +Physical Curve("Sample Sides") = {3,6,10,13}; +Physical Curve("Holder Sides") = {4,5,11,12}; diff --git a/examples/multiple-coupled-volumes.py b/examples/multiple-coupled-volumes.py new file mode 100644 index 000000000..e0fd499e5 --- /dev/null +++ b/examples/multiple-coupled-volumes.py @@ -0,0 +1,1184 @@ +"""Demonstrates coupling of multiple domains.""" + +__copyright__ = """ +Copyright (C) 2023 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import logging +import sys +import gc +import os +import numpy as np +from warnings import warn +from grudge.shortcuts import make_visualizer +from grudge.dof_desc import ( + DOFDesc, VolumeDomainTag +) +from grudge.dof_desc import DISCR_TAG_BASE, DISCR_TAG_QUAD + +from mirgecom.discretization import create_discretization_collection +from mirgecom.utils import force_evaluation +from mirgecom.simutil import ( + check_step, distribute_mesh, write_visfile, + check_naninf_local, global_reduce +) +from mirgecom.restart import write_restart_file +from mirgecom.io import make_init_message +from mirgecom.mpi import mpi_entry_point +from mirgecom.steppers import advance_state +from mirgecom.boundary import IsothermalWallBoundary, AdiabaticSlipBoundary +from mirgecom.fluid import make_conserved +from mirgecom.transport import SimpleTransport +import cantera +from mirgecom.eos import PyrometheusMixture +from mirgecom.gas_model import ( + GasModel, + make_fluid_state, + make_operator_fluid_states +) +from mirgecom.logging_quantities import ( + initialize_logmgr, + logmgr_add_cl_device_info, + logmgr_set_time, + logmgr_add_device_memory_usage, +) +from mirgecom.navierstokes import ( + grad_t_operator, + grad_cv_operator, + ns_operator +) +from mirgecom.multiphysics.multiphysics_coupled_fluid_wall import ( + add_interface_boundaries as add_multiphysics_interface_boundaries, + add_interface_boundaries_no_grad as add_multiphysics_interface_boundaries_no_grad +) +from mirgecom.multiphysics.thermally_coupled_fluid_wall import ( + add_interface_boundaries as add_thermal_interface_boundaries, + add_interface_boundaries_no_grad as add_thermal_interface_boundaries_no_grad +) +from mirgecom.diffusion import ( + diffusion_operator, + grad_operator as wall_grad_t_operator, + NeumannDiffusionBoundary +) +from mirgecom.wall_model import ( + SolidWallConservedVars, + SolidWallDependentVars, + SolidWallState, + SolidWallModel, + PorousWallTransport, + PorousFlowModel +) +from mirgecom.mechanisms import get_mechanism_input +from mirgecom.thermochemistry import get_pyrometheus_wrapper_class_from_cantera +from mirgecom.limiter import bound_preserving_limiter + +from logpyle import IntervalTimer, set_dt + +from pytools.obj_array import make_obj_array + +######################################################################### + + +class _FluidGradCVTag: + pass + + +class _FluidGradTempTag: + pass + + +class _SampleGradCVTag: + pass + + +class _SampleGradTempTag: + pass + + +class _HolderGradTempTag: + pass + + +class _FluidOperatorTag: + pass + + +class _SampleOperatorTag: + pass + + +class _HolderOperatorTag: + pass + + +class _FluidOpStatesTag: + pass + + +class _WallOpStatesTag: + pass + + +class FluidInitializer: + + def __init__(self, species_left, species_right): + self._yl = species_left + self._yr = species_right + + def __call__(self, x_vec, gas_model): + + actx = x_vec[0].array_context + eos = gas_model.eos + + hot_temp = 2000.0 + cold_temp = 300.0 + + aux = 0.5*(1.0 - actx.np.tanh(1.0/(.01)*(x_vec[0] + 0.25))) + y1 = self._yl*aux + y2 = self._yr*(1.0-aux) + y = y1+y2 + + pressure = 101325.0 + x_vec[0]*0.0 + temperature = cold_temp + \ + (hot_temp - cold_temp)*.5*(1. - actx.np.tanh(1.0/.01*(x_vec[0]+.25))) + + mass = eos.get_density(pressure, temperature, species_mass_fractions=y) + momentum = make_obj_array([0.0*x_vec[0], 0.0*x_vec[0]]) + specmass = mass * y + energy = mass * eos.get_internal_energy(temperature, + species_mass_fractions=y) + + return make_conserved(dim=2, mass=mass, energy=energy, + momentum=momentum, species_mass=specmass) + + +class HolderWallModel: + """Model for calculating wall quantities.""" + def __init__(self, density_func, enthalpy_func, heat_capacity_func, + thermal_conductivity_func): + self._density_func = density_func + self._enthalpy_func = enthalpy_func + self._heat_capacity_func = heat_capacity_func + self._thermal_conductivity_func = thermal_conductivity_func + + def density(self): + return self._density_func() + + def heat_capacity(self): + return self._heat_capacity_func() + + def enthalpy(self, temperature): + return self._enthalpy_func(temperature) + + def thermal_diffusivity(self, mass, temperature, + thermal_conductivity=None): + if thermal_conductivity is None: + thermal_conductivity = self.thermal_conductivity() + return thermal_conductivity/(mass * self.heat_capacity()) + + def thermal_conductivity(self): + return self._thermal_conductivity_func() + + +class SingleLevelFilter(logging.Filter): + def __init__(self, passlevel, reject): + self.passlevel = passlevel + self.reject = reject + + def filter(self, record): + if self.reject: + return (record.levelno != self.passlevel) + else: + return (record.levelno == self.passlevel) + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class MyRuntimeError(RuntimeError): + """Simple exception to kill the simulation.""" + + pass + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +@mpi_entry_point +def main(actx_class, use_logmgr=True, casename=None, restart_filename=None): + + from mpi4py import MPI + comm = MPI.COMM_WORLD + rank = 0 + rank = comm.Get_rank() + nparts = comm.Get_size() + + from mirgecom.array_context import initialize_actx, actx_class_is_profiling + actx = initialize_actx(actx_class, comm) + queue = getattr(actx, "queue", None) + use_profiling = actx_class_is_profiling(actx_class) + + # ~~~~~~~~~~~~~~~~~~ + + rst_path = "./" + viz_path = "./" + vizname = viz_path+casename + rst_pattern = rst_path+"{cname}-{step:06d}-{rank:04d}.pkl" + + # default i/o frequencies + nviz = 100 + nrestart = 25000 + nhealth = 1 + nstatus = 100 + ngarbage = 50 + + # default timestepping control + integrator = "ssprk43" + current_dt = 2.5e-2 + t_final = 2.5e-1 + + local_dt = False + constant_cfl = False + current_cfl = 0.4 + + # discretization and model control + order = 1 + use_overintegration = False + + # wall stuff + temp_wall = 300.0 + wall_penalty_amount = 1.0 + + use_radiation = True # or False + emissivity = 1.0 + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + dim = 2 + + if integrator == "ssprk43": + from mirgecom.integrators.ssprk import ssprk43_step + timestepper = ssprk43_step + force_eval = True + + if rank == 0: + print("\n#### Simulation control data: ####") + print(f"\tnviz = {nviz}") + print(f"\tnrestart = {nrestart}") + print(f"\tnhealth = {nhealth}") + print(f"\tnstatus = {nstatus}") + print(f"\tcurrent_dt = {current_dt}") + print(f"\tt_final = {t_final}") + print(f"\torder = {order}") + print(f"\tTime integration = {integrator}") + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + local_path = os.path.dirname(os.path.abspath(__file__)) + "/" + mesh_filename = "mult_coupled_vols-v2.msh" + geo_filename = "mult_coupled_vols.geo" + omesh_filename = "mult_coupled_vols.msh" + + mesh_path = local_path + mesh_filename + geo_path = local_path + geo_filename + omesh_path = local_path + omesh_filename + + restart_step = None + if restart_file is None: + + if rank == 0: + os.system(f"rm -rf {omesh_path} {mesh_path}") + os.system(f"gmsh {geo_path} -2 {omesh_path}") + os.system(f"gmsh {omesh_path} -save -format msh2 -o {mesh_path}") + + comm.Barrier() + + current_step = 0 + first_step = current_step + 0 + current_t = 0.0 + + if rank == 0: + print(f"Reading mesh from {mesh_path}") + + def get_mesh_data(): + from meshmode.mesh.io import read_gmsh + mesh, tag_to_elements = read_gmsh( + mesh_path, force_ambient_dim=dim, + return_tag_to_elements_map=True) + volume_to_tags = { + "fluid": ["Fluid"], + "sample": ["Sample"], + "holder": ["Holder"] + } + return mesh, tag_to_elements, volume_to_tags + + volume_to_local_mesh_data, global_nelements = distribute_mesh( + comm, get_mesh_data) + + else: # Restart + from mirgecom.restart import read_restart_data + restart_data = read_restart_data(actx, restart_file) + restart_step = restart_data["step"] + volume_to_local_mesh_data = restart_data["volume_to_local_mesh_data"] + global_nelements = restart_data["global_nelements"] + restart_order = int(restart_data["order"]) + first_step = restart_step+0 + + assert comm.Get_size() == restart_data["num_parts"] + + local_nelements = ( + + volume_to_local_mesh_data["fluid"][0].nelements + + volume_to_local_mesh_data["sample"][0].nelements + + volume_to_local_mesh_data["holder"][0].nelements) + + dcoll = create_discretization_collection( + actx, + volume_meshes={ + vol: mesh + for vol, (mesh, _) in volume_to_local_mesh_data.items()}, + order=order) + + if use_overintegration: + quadrature_tag = DISCR_TAG_QUAD + else: + quadrature_tag = DISCR_TAG_BASE + + if rank == 0: + logger.info("Done making discretization") + + dd_vol_fluid = DOFDesc(VolumeDomainTag("fluid"), DISCR_TAG_BASE) + dd_vol_sample = DOFDesc(VolumeDomainTag("sample"), DISCR_TAG_BASE) + dd_vol_holder = DOFDesc(VolumeDomainTag("holder"), DISCR_TAG_BASE) + + fluid_nodes = actx.thaw(dcoll.nodes(dd_vol_fluid)) + sample_nodes = actx.thaw(dcoll.nodes(dd_vol_sample)) + holder_nodes = actx.thaw(dcoll.nodes(dd_vol_holder)) + + fluid_zeros = force_evaluation(actx, fluid_nodes[0]*0.0) + sample_zeros = force_evaluation(actx, sample_nodes[0]*0.0) + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # {{{ Set up initial state using Cantera + + # Use Cantera for initialization + mech_input = get_mechanism_input("air_3sp") + + cantera_soln = cantera.Solution(name="gas", yaml=mech_input) + nspecies = cantera_soln.n_species + + temp_cantera = 300.0 + pres_cantera = cantera.one_atm # pylint: disable=no-member + + # Set Cantera internal gas temperature, pressure, and mole fractios + x_left = np.zeros(nspecies) + x_left[cantera_soln.species_index("O2")] = 0.0 + x_left[cantera_soln.species_index("N2")] = 1.0 + + cantera_soln.TPX = temp_cantera, pres_cantera, x_left + y_left = cantera_soln.Y + + x_right = np.zeros(nspecies) + x_right[cantera_soln.species_index("O2")] = 1.0 + x_right[cantera_soln.species_index("N2")] = 0.0 + + cantera_soln.TPX = temp_cantera, pres_cantera, x_right + y_right = cantera_soln.Y + + # }}} + + # {{{ Create Pyrometheus thermochemistry object & EOS + + # Import Pyrometheus EOS + pyrometheus_mechanism = get_pyrometheus_wrapper_class_from_cantera( + cantera_soln, temperature_niter=3)(actx.np) + + temperature_seed = 300.0 + eos = PyrometheusMixture(pyrometheus_mechanism, + temperature_guess=temperature_seed) + + species_names = pyrometheus_mechanism.species_names + print(f"Pyrometheus mechanism species names {species_names}") + + # }}} + + # {{{ Initialize transport model + + fluid_transport = SimpleTransport(viscosity=0.0, thermal_conductivity=0.1, + species_diffusivity=np.zeros(nspecies,) + 0.001) + + base_transport = SimpleTransport(viscosity=0.0, thermal_conductivity=0.2, + species_diffusivity=np.zeros(nspecies,) + 0.001) + sample_transport = PorousWallTransport(base_transport=base_transport) + + # }}} + + # ~~~~~~~~~~~~~~ + + # {{{ Initialize wall model + + import mirgecom.materials.carbon_fiber as my_material + sample_density = 0.1*1600.0 + sample_zeros + fiber = my_material.FiberEOS(char_mass=0.0, virgin_mass=160.0) + + # }}} + + # ~~~~~~~~~~~~~~ + + # {{{ Initialize wall model + + wall_holder_rho = 10.0 + wall_holder_cp = 1000.0 + wall_holder_kappa = 2.00 + + def _get_holder_enthalpy(temperature, **kwargs): + return wall_holder_cp * temperature + + def _get_holder_heat_capacity(**kwargs): + return wall_holder_cp + + def _get_holder_thermal_conductivity(**kwargs): + return wall_holder_kappa + + holder_wall_model = SolidWallModel( + enthalpy_func=_get_holder_enthalpy, + heat_capacity_func=_get_holder_heat_capacity, + thermal_conductivity_func=_get_holder_thermal_conductivity) + + # }}} + + gas_model_fluid = GasModel(eos=eos, transport=fluid_transport) + gas_model_sample = PorousFlowModel(eos=eos, wall_eos=fiber, + transport=sample_transport) + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + def _limit_fluid_cv(cv, pressure, temperature, dd=None): + + # limit species + spec_lim = make_obj_array([ + bound_preserving_limiter(dcoll, cv.species_mass_fractions[i], + mmin=0.0, mmax=1.0, modify_average=True, dd=dd) + for i in range(nspecies) + ]) + + # normalize to ensure sum_Yi = 1.0 + aux = cv.mass*0.0 + for i in range(0, nspecies): + aux = aux + spec_lim[i] + spec_lim = spec_lim/aux + + # recompute density + mass_lim = gas_model_fluid.eos.get_density(pressure=pressure, + temperature=temperature, species_mass_fractions=spec_lim) + + # recompute energy + energy_lim = mass_lim*(gas_model_fluid.eos.get_internal_energy( + temperature, species_mass_fractions=spec_lim) + + 0.5*np.dot(cv.velocity, cv.velocity) + ) + + # make a new CV with the limited variables + return make_conserved(dim=dim, mass=mass_lim, energy=energy_lim, + momentum=mass_lim*cv.velocity, species_mass=mass_lim*spec_lim) + + def _get_fluid_state(cv, temp_seed): + return make_fluid_state(cv=cv, gas_model=gas_model_fluid, + temperature_seed=temp_seed, limiter_func=_limit_fluid_cv, + limiter_dd=dd_vol_fluid) + + get_fluid_state = actx.compile(_get_fluid_state) + + # ~~~~~~~~~~ + + def _limit_sample_cv(cv, wv, pressure, temperature, dd=None): + + # limit species + spec_lim = make_obj_array([ + bound_preserving_limiter(dcoll, cv.species_mass_fractions[i], + mmin=0.0, mmax=1.0, modify_average=True, dd=dd) + for i in range(nspecies) + ]) + + # normalize to ensure sum_Yi = 1.0 + aux = cv.mass*0.0 + for i in range(0, nspecies): + aux = aux + spec_lim[i] + spec_lim = spec_lim/aux + # recompute gas density + mass_lim = wv.void_fraction*gas_model_sample.eos.get_density( + pressure=pressure, temperature=temperature, + species_mass_fractions=spec_lim) + + # recompute gas energy + energy_gas = mass_lim*( + gas_model_sample.eos.get_internal_energy( + temperature, species_mass_fractions=spec_lim) + + 0.5*np.dot(cv.velocity, cv.velocity) + ) + + # compute solid energy + energy_solid = \ + wv.density*gas_model_sample.wall_eos.enthalpy(temperature, wv.tau) + + # the total energy is a composition of both solid and gas + energy = energy_gas + energy_solid + + # make a new CV with the limited variables + return make_conserved(dim=dim, mass=mass_lim, energy=energy, + momentum=mass_lim*cv.velocity, species_mass=mass_lim*spec_lim) + + def _get_sample_state(cv, wv, temp_seed): + return make_fluid_state(cv=cv, gas_model=gas_model_sample, + material_densities=wv, + temperature_seed=temp_seed, + limiter_func=_limit_sample_cv, limiter_dd=dd_vol_sample) + + get_sample_state = actx.compile(_get_sample_state) + + # ~~~~~~~~~~ + + def _get_holder_state(wv): + dep_vars = holder_wall_model.dependent_vars(wv) + return SolidWallState(cv=wv, dv=dep_vars) + + get_holder_state = actx.compile(_get_holder_state) + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + from mirgecom.materials.initializer import ( + PorousWallInitializer, + SolidWallInitializer + ) + + fluid_init = FluidInitializer(species_left=y_left, species_right=y_right) + + sample_init = PorousWallInitializer(pressure=101325.0, temperature=300.0, + species=y_right, material_densities=sample_density) + + holder_init = SolidWallInitializer(temperature=300.0, + material_densities=wall_holder_rho) + +############################################################################## + + if restart_file is None: + if rank == 0: + logging.info("Initializing soln.") + + fluid_tseed = temperature_seed + fluid_zeros + sample_tseed = temp_wall + sample_zeros + + fluid_cv = fluid_init(fluid_nodes, gas_model_fluid) + sample_cv = sample_init(2, sample_nodes, gas_model_sample) + holder_cv = holder_init(holder_nodes, holder_wall_model) + + else: + current_step = restart_step + current_t = restart_data["t"] + + if rank == 0: + logger.info("Restarting soln.") + + if restart_order != order: + restart_dcoll = create_discretization_collection( + actx, + volume_meshes={ + vol: mesh + for vol, (mesh, _) in volume_to_local_mesh_data.items()}, + order=restart_order) + from meshmode.discretization.connection import make_same_mesh_connection + fluid_connection = make_same_mesh_connection( + actx, + dcoll.discr_from_dd(dd_vol_fluid), + restart_dcoll.discr_from_dd(dd_vol_fluid) + ) + sample_connection = make_same_mesh_connection( + actx, + dcoll.discr_from_dd(dd_vol_sample), + restart_dcoll.discr_from_dd(dd_vol_sample) + ) + holder_connection = make_same_mesh_connection( + actx, + dcoll.discr_from_dd(dd_vol_holder), + restart_dcoll.discr_from_dd(dd_vol_holder) + ) + fluid_cv = fluid_connection(restart_data["fluid_cv"]) + fluid_tseed = fluid_connection(restart_data["fluid_temperature_seed"]) + sample_cv = sample_connection(restart_data["sample_cv"]) + sample_tseed = sample_connection(restart_data["wall_temperature_seed"]) + sample_density = sample_connection(restart_data["sample_density"]) + holder_cv = holder_connection(restart_data["holder_cv"]) + else: + fluid_cv = restart_data["fluid_cv"] + fluid_tseed = restart_data["fluid_temperature_seed"] + sample_cv = restart_data["sample_cv"] + sample_tseed = restart_data["sample_temperature_seed"] + sample_density = restart_data["sample_density"] + holder_cv = restart_data["holder_cv"] + + fluid_cv = force_evaluation(actx, fluid_cv) + fluid_tseed = force_evaluation(actx, fluid_tseed) + fluid_state = get_fluid_state(fluid_cv, fluid_tseed) + + sample_cv = force_evaluation(actx, sample_cv) + sample_tseed = force_evaluation(actx, sample_tseed) + sample_density = force_evaluation(actx, sample_density) + sample_state = get_sample_state(sample_cv, sample_density, sample_tseed) + + holder_cv = force_evaluation(actx, holder_cv) + holder_state = get_holder_state(holder_cv) + +############################################################################## + + original_casename = casename + casename = f"{casename}-d{dim}p{order}e{global_nelements}n{nparts}" + logmgr = initialize_logmgr(use_logmgr, filename=(f"{casename}.sqlite"), + mode="wo", mpi_comm=comm) + + vis_timer = None + if logmgr: + logmgr_add_cl_device_info(logmgr, queue) + logmgr_add_device_memory_usage(logmgr, queue) + logmgr_set_time(logmgr, current_step, current_t) + + logmgr.add_watches([ + ("step.max", "step = {value}, "), + ("dt.max", "dt: {value:1.5e} s, "), + ("t_sim.max", "sim time: {value:1.5e} s, "), + ("t_step.max", "--- step walltime: {value:5g} s\n") + ]) + + try: + logmgr.add_watches(["memory_usage_python.max", + "memory_usage_gpu.max"]) + except KeyError: + pass + + if use_profiling: + logmgr.add_watches(["pyopencl_array_time.max"]) + + vis_timer = IntervalTimer("t_vis", "Time spent visualizing") + logmgr.add_quantity(vis_timer) + + gc_timer = IntervalTimer("t_gc", "Time spent garbage collecting") + logmgr.add_quantity(gc_timer) + +############################################################################## + + fluid_boundaries = { + dd_vol_fluid.trace("Fluid Hot").domain_tag: + IsothermalWallBoundary(wall_temperature=2000.0), + dd_vol_fluid.trace("Fluid Cold").domain_tag: + IsothermalWallBoundary(wall_temperature=300.0), + dd_vol_fluid.trace("Fluid Sides").domain_tag: AdiabaticSlipBoundary(), + } + + # ~~~~~~~~~~ + sample_boundaries = { + dd_vol_sample.trace("Sample Sides").domain_tag: AdiabaticSlipBoundary() + } + + # ~~~~~~~~~~ + holder_boundaries = { + dd_vol_holder.trace("Holder Sides").domain_tag: NeumannDiffusionBoundary(0.0) + } + +############################################################################## + + fluid_visualizer = make_visualizer(dcoll, volume_dd=dd_vol_fluid) + sample_visualizer = make_visualizer(dcoll, volume_dd=dd_vol_sample) + holder_visualizer = make_visualizer(dcoll, volume_dd=dd_vol_holder) + + initname = original_casename + eosname = eos.__class__.__name__ + init_message = make_init_message(dim=dim, order=order, + nelements=local_nelements, global_nelements=global_nelements, + dt=current_dt, t_final=t_final, nstatus=nstatus, nviz=nviz, + t_initial=current_t, cfl=current_cfl, constant_cfl=constant_cfl, + initname=initname, eosname=eosname, casename=casename) + + if rank == 0: + logger.info(init_message) + +############################################################################## + + def my_write_viz(step, t, dt, fluid_state, sample_state, holder_state): + + fluid_viz_fields = [ + ("rho_g", fluid_state.cv.mass), + ("rhoU_g", fluid_state.cv.momentum), + ("rhoE_g", fluid_state.cv.energy), + ("pressure", fluid_state.pressure), + ("temperature", fluid_state.temperature), + ("Vx", fluid_state.velocity[0]), + ("Vy", fluid_state.velocity[1]), + ("dt", dt[0] if local_dt else None)] + fluid_viz_fields.extend( + ("Y_"+species_names[i], fluid_state.cv.species_mass_fractions[i]) + for i in range(nspecies)) + + sample_viz_fields = [ + ("rho_g", sample_state.cv.mass), + ("rhoU_g", sample_state.cv.momentum), + ("rhoE_b", sample_state.cv.energy), + ("pressure", sample_state.pressure), + ("temperature", sample_state.temperature), + ("sample_mass", sample_state.wv.material_densities), + ("Vx", sample_state.velocity[0]), + ("Vy", sample_state.velocity[1]), + ("kappa", sample_state.thermal_conductivity)] + sample_viz_fields.extend( + ("Y_"+species_names[i], sample_state.cv.species_mass_fractions[i]) + for i in range(nspecies)) + + holder_viz_fields = [ + ("holder_mass", holder_state.cv.mass), + ("rhoE_s", holder_state.cv.energy), + ("temperature", holder_state.dv.temperature), + ("kappa", holder_state.dv.thermal_conductivity)] + + write_visfile(dcoll, fluid_viz_fields, fluid_visualizer, + vizname=vizname+"-fluid", step=step, t=t, overwrite=True, comm=comm) + write_visfile(dcoll, sample_viz_fields, sample_visualizer, + vizname=vizname+"-sample", step=step, t=t, overwrite=True, comm=comm) + write_visfile(dcoll, holder_viz_fields, holder_visualizer, + vizname=vizname+"-holder", step=step, t=t, overwrite=True, comm=comm) + + def my_write_restart(step, t, fluid_state, sample_state, holder_state): + if rank == 0: + print("Writing restart file...") + + restart_fname = rst_pattern.format(cname=casename, step=step, + rank=rank) + if restart_fname != restart_filename: + restart_data = { + "volume_to_local_mesh_data": volume_to_local_mesh_data, + "fluid_cv": fluid_state.cv, + "fluid_temperature_seed": fluid_state.temperature, + "sample_cv": sample_state.cv, + "sample_density": sample_state.wv.material_densities, + "sample_temperature_seed": sample_state.temperature, + "holder_cv": holder_state.cv, + "holder_temperature_seed": holder_state.dv.temperature, + "nspecies": nspecies, + "t": t, + "step": step, + "order": order, + "global_nelements": global_nelements, + "num_parts": nparts + } + + write_restart_file(actx, restart_data, restart_fname, comm) + +########################################################################## + + def my_health_check(cv, dv): + health_error = False + pressure = force_evaluation(actx, dv.pressure) + temperature = force_evaluation(actx, dv.temperature) + + if global_reduce(check_naninf_local(dcoll, "vol", pressure), op="lor"): + health_error = True + logger.info(f"{rank=}: NANs/Infs in pressure data.") + + if global_reduce(check_naninf_local(dcoll, "vol", temperature), op="lor"): + health_error = True + logger.info(f"{rank=}: NANs/Infs in temperature data.") + + return health_error + +############################################################################## + + def my_pre_step(step, t, dt, state): + + if logmgr: + logmgr.tick_before() + + fluid_cv, fluid_tseed, \ + sample_cv, sample_tseed, sample_density, \ + holder_cv = state + + fluid_cv = force_evaluation(actx, fluid_cv) + fluid_tseed = force_evaluation(actx, fluid_tseed) + sample_cv = force_evaluation(actx, sample_cv) + sample_tseed = force_evaluation(actx, sample_tseed) + sample_density = force_evaluation(actx, sample_density) + holder_cv = force_evaluation(actx, holder_cv) + + # construct species-limited fluid state + fluid_state = get_fluid_state(fluid_cv, fluid_tseed) + fluid_cv = fluid_state.cv + + # construct species-limited solid state + sample_state = get_sample_state(sample_cv, sample_density, sample_tseed) + sample_cv = sample_state.cv + + # construct species-limited solid state + holder_state = get_holder_state(holder_cv) + + try: + state = make_obj_array([ + fluid_cv, fluid_state.temperature, + sample_cv, sample_state.temperature, sample_density, + holder_cv]) + + do_garbage = check_step(step=step, interval=ngarbage) + do_viz = check_step(step=step, interval=nviz) + do_restart = check_step(step=step, interval=nrestart) + do_health = check_step(step=step, interval=nhealth) + + if do_garbage: + with gc_timer.start_sub_timer(): + warn("Running gc.collect() to work around memory growth issue ") + gc.collect() + + if do_health: + health_errors = global_reduce( + my_health_check(fluid_state.cv, fluid_state.dv), op="lor") + if health_errors: + if rank == 0: + logger.info("Fluid solution failed health check.") + raise MyRuntimeError("Failed simulation health check.") + + if do_viz: + my_write_viz(step=step, t=t, dt=dt, fluid_state=fluid_state, + sample_state=sample_state, holder_state=holder_state) + + if do_restart: + my_write_restart(step, t, fluid_state, sample_state, holder_state) + + except MyRuntimeError: + if rank == 0: + logger.info("Errors detected; attempting graceful exit.") + my_write_viz(step=step, t=t, dt=dt, fluid_state=fluid_state, + sample_state=sample_state, holder_state=holder_state) + raise + + return state, dt + + def my_rhs(time, state): + + fluid_cv, fluid_tseed, \ + sample_cv, sample_tseed, sample_density, \ + holder_cv = state + + # construct species-limited fluid state + fluid_state = make_fluid_state(cv=fluid_cv, gas_model=gas_model_fluid, + temperature_seed=fluid_tseed, + limiter_func=_limit_fluid_cv, limiter_dd=dd_vol_fluid) + fluid_cv = fluid_state.cv + + # construct species-limited solid state + sample_state = make_fluid_state(cv=sample_cv, gas_model=gas_model_sample, + material_densities=sample_density, + temperature_seed=sample_tseed, + limiter_func=_limit_sample_cv, limiter_dd=dd_vol_sample) + sample_cv = sample_state.cv + + # construct species-limited solid state + holder_state = _get_holder_state(holder_cv) + holder_cv = holder_state.cv + # ~~~~~~~~~~~~~ + + fluid_all_boundaries_no_grad, sample_all_boundaries_no_grad = \ + add_multiphysics_interface_boundaries_no_grad( + dcoll, dd_vol_fluid, dd_vol_sample, + gas_model_fluid, gas_model_sample, + fluid_state, sample_state, + fluid_boundaries, sample_boundaries, + limiter_func_fluid=_limit_fluid_cv, + limiter_func_wall=_limit_sample_cv, + interface_noslip=True, interface_radiation=use_radiation) + + fluid_all_boundaries_no_grad, holder_all_boundaries_no_grad = \ + add_thermal_interface_boundaries_no_grad( + dcoll, gas_model_fluid, # FIXME remove gas_model + dd_vol_fluid, dd_vol_holder, + fluid_state, holder_state.dv.thermal_conductivity, + holder_state.dv.temperature, + fluid_all_boundaries_no_grad, holder_boundaries, + interface_noslip=True, interface_radiation=use_radiation) + + sample_all_boundaries_no_grad, holder_all_boundaries_no_grad = \ + add_thermal_interface_boundaries_no_grad( + dcoll, gas_model_sample, # FIXME remove gas_model + dd_vol_sample, dd_vol_holder, + sample_state, holder_state.dv.thermal_conductivity, + holder_state.dv.temperature, + sample_all_boundaries_no_grad, holder_all_boundaries_no_grad, + interface_noslip=True, interface_radiation=False) + + # ~~~~~~~~~~~~~~ + + fluid_operator_states_quad = make_operator_fluid_states( + dcoll, fluid_state, gas_model_fluid, fluid_all_boundaries_no_grad, + quadrature_tag, dd=dd_vol_fluid, comm_tag=_FluidOpStatesTag, + limiter_func=_limit_fluid_cv) + + sample_operator_states_quad = make_operator_fluid_states( + dcoll, sample_state, gas_model_sample, sample_all_boundaries_no_grad, + quadrature_tag, dd=dd_vol_sample, comm_tag=_WallOpStatesTag, + limiter_func=_limit_sample_cv) + + # ~~~~~~~~~~~~~~ + + # fluid grad CV + fluid_grad_cv = grad_cv_operator( + dcoll, gas_model_fluid, fluid_all_boundaries_no_grad, fluid_state, + time=time, quadrature_tag=quadrature_tag, dd=dd_vol_fluid, + operator_states_quad=fluid_operator_states_quad, + comm_tag=_FluidGradCVTag + ) + + # fluid grad T + fluid_grad_temperature = grad_t_operator( + dcoll, gas_model_fluid, fluid_all_boundaries_no_grad, fluid_state, + time=time, quadrature_tag=quadrature_tag, dd=dd_vol_fluid, + operator_states_quad=fluid_operator_states_quad, + comm_tag=_FluidGradTempTag + ) + + # sample grad CV + sample_grad_cv = grad_cv_operator( + dcoll, gas_model_sample, sample_all_boundaries_no_grad, sample_state, + time=time, quadrature_tag=quadrature_tag, dd=dd_vol_sample, + operator_states_quad=sample_operator_states_quad, + comm_tag=_SampleGradCVTag + ) + + # sample grad T + sample_grad_temperature = grad_t_operator( + dcoll, gas_model_sample, sample_all_boundaries_no_grad, sample_state, + time=time, quadrature_tag=quadrature_tag, dd=dd_vol_sample, + operator_states_quad=sample_operator_states_quad, + comm_tag=_SampleGradTempTag + ) + + # holder grad T + holder_grad_temperature = wall_grad_t_operator( + dcoll, holder_state.dv.thermal_conductivity, + holder_all_boundaries_no_grad, holder_state.dv.temperature, + quadrature_tag=quadrature_tag, dd=dd_vol_holder, + comm_tag=_HolderGradTempTag + ) + + # ~~~~~~~~~~~~~~~~~ + + fluid_all_boundaries, sample_all_boundaries = \ + add_multiphysics_interface_boundaries( + dcoll, dd_vol_fluid, dd_vol_sample, + gas_model_fluid, gas_model_sample, + fluid_state, sample_state, + fluid_grad_cv, sample_grad_cv, + fluid_grad_temperature, sample_grad_temperature, + fluid_boundaries, sample_boundaries, + limiter_func_fluid=_limit_fluid_cv, + limiter_func_wall=_limit_sample_cv, + interface_noslip=True, interface_radiation=use_radiation, + wall_emissivity=emissivity, sigma=5.67e-8, + ambient_temperature=300.0, + wall_penalty_amount=wall_penalty_amount) + + fluid_all_boundaries, holder_all_boundaries = \ + add_thermal_interface_boundaries( + dcoll, gas_model_fluid, # FIXME remove gas_model + dd_vol_fluid, dd_vol_holder, + fluid_state, holder_state.dv.thermal_conductivity, + holder_state.dv.temperature, + fluid_grad_temperature, holder_grad_temperature, + fluid_all_boundaries, holder_boundaries, + interface_noslip=True, interface_radiation=use_radiation, + wall_emissivity=emissivity, sigma=5.67e-8, + ambient_temperature=300.0, + wall_penalty_amount=wall_penalty_amount) + + sample_all_boundaries, holder_all_boundaries = \ + add_thermal_interface_boundaries( + dcoll, gas_model_sample, # FIXME remove gas_model + dd_vol_sample, dd_vol_holder, + sample_state, holder_state.dv.thermal_conductivity, + holder_state.dv.temperature, + sample_grad_temperature, holder_grad_temperature, + sample_all_boundaries, holder_all_boundaries, + interface_noslip=True, interface_radiation=False, + wall_penalty_amount=wall_penalty_amount) + + # ~~~~~~~~~~~~~ + + fluid_rhs = ns_operator( + dcoll, gas_model_fluid, fluid_state, fluid_all_boundaries, + time=time, quadrature_tag=quadrature_tag, dd=dd_vol_fluid, + operator_states_quad=fluid_operator_states_quad, + grad_cv=fluid_grad_cv, grad_t=fluid_grad_temperature, + comm_tag=_FluidOperatorTag, inviscid_terms_on=False) + + sample_rhs = ns_operator( + dcoll, gas_model_sample, sample_state, sample_all_boundaries, + time=time, quadrature_tag=quadrature_tag, dd=dd_vol_sample, + operator_states_quad=sample_operator_states_quad, + grad_cv=sample_grad_cv, grad_t=sample_grad_temperature, + comm_tag=_SampleOperatorTag, inviscid_terms_on=False) + + holder_energy_rhs = diffusion_operator( + dcoll, holder_state.dv.thermal_conductivity, holder_all_boundaries, + holder_state.dv.temperature, + penalty_amount=wall_penalty_amount, quadrature_tag=quadrature_tag, + dd=dd_vol_holder, grad_u=holder_grad_temperature, + comm_tag=_HolderOperatorTag) + + holder_rhs = SolidWallConservedVars( + mass=actx.np.zeros_like(holder_state.dv.temperature), + energy=holder_energy_rhs) + + sample_mass_rhs = sample_zeros + + # ~~~~~~~~~~~~~ + return make_obj_array([ + fluid_rhs, fluid_zeros, + sample_rhs, sample_zeros, sample_mass_rhs, + holder_rhs]) + + def my_post_step(step, t, dt, state): + if step == first_step + 1: + with gc_timer.start_sub_timer(): + gc.collect() + # Freeze the objects that are still alive so they will not + # be considered in future gc collections. + logger.info("Freezing GC objects to reduce overhead of " + "future GC collections") + gc.freeze() + + if logmgr: + set_dt(logmgr, dt) + logmgr.tick_after() + + return state, dt + +############################################################################## + + stepper_state = make_obj_array([ + fluid_state.cv, fluid_state.temperature, sample_state.cv, + sample_state.temperature, sample_state.wv.material_densities, + holder_state.cv + ]) + + dt = 1.0*current_dt + t = 1.0*current_t + + if rank == 0: + logging.info("Stepping.") + + final_step, final_t, stepper_state = \ + advance_state(rhs=my_rhs, timestepper=timestepper, + pre_step_callback=my_pre_step, + post_step_callback=my_post_step, + istep=current_step, dt=dt, t=t, t_final=t_final, + force_eval=force_eval, state=stepper_state) + + # Dump the final data + if rank == 0: + logger.info("Checkpointing final state ...") + + fluid_cv, fluid_tseed, \ + sample_cv, sample_tseed, sample_density, \ + holder_cv = stepper_state + + fluid_state = get_fluid_state(fluid_cv, fluid_tseed) + sample_state = get_sample_state(sample_cv, sample_density, sample_tseed) + holder_state = get_holder_state(holder_cv) + + my_write_viz(step=final_step, t=final_t, dt=dt, fluid_state=fluid_state, + sample_state=sample_state, holder_state=holder_state) + + my_write_restart(final_step, final_t, fluid_state, sample_state, + holder_state) + + if logmgr: + logmgr.close() + elif use_profiling: + print(actx.tabulate_profiling_data()) + + sys.exit() + + +if __name__ == "__main__": + logging.basicConfig(format="%(message)s", level=logging.INFO) + + import argparse + parser = argparse.ArgumentParser(description="MIRGE-Com 1D Flame Driver") + parser.add_argument("-r", "--restart_file", type=ascii, + dest="restart_file", nargs="?", action="store", + help="simulation restart file") + parser.add_argument("-i", "--input_file", type=ascii, + dest="input_file", nargs="?", action="store", + help="simulation config file") + parser.add_argument("-c", "--casename", type=ascii, + dest="casename", nargs="?", action="store", + help="simulation case name") + parser.add_argument("--profiling", action="store_true", default=False, + help="enable kernel profiling [OFF]") + parser.add_argument("--log", action="store_true", default=True, + help="enable logging profiling [ON]") + parser.add_argument("--lazy", action="store_true", default=False, + help="enable lazy evaluation [OFF]") + parser.add_argument("--numpy", action="store_true", + help="use numpy-based eager actx.") + parser.add_argument("--esdg", action="store_true", + help="use flux-differencing/entropy stable DG for inviscid computations.") + + args = parser.parse_args() + + warn("Automatically turning off DV logging. MIRGE-Com Issue(578)") + lazy = args.lazy + if args.profiling: + if lazy: + raise ValueError("Can't use lazy and profiling together.") + + from mirgecom.array_context import get_reasonable_array_context_class + actx_class = get_reasonable_array_context_class( + lazy=args.lazy, distributed=True, profiling=args.profiling, numpy=args.numpy) + + # for writing output + casename = "coupled_volumes" + if args.casename: + print(f"Custom casename {args.casename}") + casename = (args.casename).replace("'", "") + else: + print(f"Default casename {casename}") + + restart_file = None + if args.restart_file: + restart_file = (args.restart_file).replace("'", "") + print(f"Restarting from file: {restart_file}") + + input_file = None + if args.input_file: + input_file = (args.input_file).replace("'", "") + print(f"Reading user input from {args.input_file}") + else: + print("No user input file, using default values") + + print(f"Running {sys.argv[0]}\n") + + main(actx_class, use_logmgr=args.log, casename=casename, + restart_filename=restart_file) diff --git a/examples/multiple-volumes.py b/examples/multiple-volumes.py deleted file mode 100644 index 8feed339a..000000000 --- a/examples/multiple-volumes.py +++ /dev/null @@ -1,421 +0,0 @@ -""" -Demonstrate multiple non-interacting volumes. - -Runs several acoustic pulse simulations with different pulse amplitudes -simultaneously. -""" - -__copyright__ = """ -Copyright (C) 2020 University of Illinois Board of Trustees -""" - -__license__ = """ -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - -import logging -from mirgecom.mpi import mpi_entry_point -import numpy as np -from functools import partial -from pytools.obj_array import make_obj_array - -from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa -from grudge.shortcuts import make_visualizer -from grudge.dof_desc import VolumeDomainTag, DISCR_TAG_BASE, DISCR_TAG_QUAD, DOFDesc - -from mirgecom.discretization import create_discretization_collection -from mirgecom.euler import ( - euler_operator, - extract_vars_for_logging -) -from mirgecom.simutil import ( - get_sim_timestep, - generate_and_distribute_mesh -) -from mirgecom.io import make_init_message - -from mirgecom.integrators import rk4_step -from mirgecom.steppers import advance_state -from mirgecom.boundary import AdiabaticSlipBoundary -from mirgecom.initializers import ( - Lump, - AcousticPulse -) -from mirgecom.eos import IdealSingleGas -from mirgecom.gas_model import ( - GasModel, - make_fluid_state -) -from logpyle import IntervalTimer, set_dt -from mirgecom.logging_quantities import ( - initialize_logmgr, - logmgr_add_many_discretization_quantities, - logmgr_add_cl_device_info, - logmgr_add_device_memory_usage, - set_sim_state -) - -logger = logging.getLogger(__name__) - - -class MyRuntimeError(RuntimeError): - """Simple exception to kill the simulation.""" - - pass - - -@mpi_entry_point -def main(actx_class, use_esdg=False, - use_overintegration=False, use_leap=False, - casename=None, rst_filename=None): - """Drive the example.""" - if casename is None: - casename = "mirgecom" - - from mpi4py import MPI - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - num_parts = comm.Get_size() - - from mirgecom.simutil import global_reduce as _global_reduce - global_reduce = partial(_global_reduce, comm=comm) - - logmgr = initialize_logmgr(True, - filename=f"{casename}.sqlite", mode="wu", mpi_comm=comm) - - from mirgecom.array_context import initialize_actx, actx_class_is_profiling - actx = initialize_actx(actx_class, comm) - queue = getattr(actx, "queue", None) - use_profiling = actx_class_is_profiling(actx_class) - - # timestepping control - current_step = 0 - if use_leap: - from leap.rk import RK4MethodBuilder - timestepper = RK4MethodBuilder("state") - else: - timestepper = rk4_step - t_final = 0.1 - current_cfl = 1.0 - current_dt = .005 - current_t = 0 - constant_cfl = False - - # some i/o frequencies - nstatus = 1 - nrestart = 5 - nviz = 100 - nhealth = 1 - - dim = 2 - - # Run simulations with several different pulse amplitudes simultaneously - pulse_amplitudes = [0.01, 0.1, 1.0] - nvolumes = len(pulse_amplitudes) - - rst_path = "restart_data/" - rst_pattern = ( - rst_path + "{cname}-{step:04d}-{rank:04d}.pkl" - ) - if rst_filename: # read the grid from restart data - rst_filename = f"{rst_filename}-{rank:04d}.pkl" - from mirgecom.restart import read_restart_data - restart_data = read_restart_data(actx, rst_filename) - local_prototype_mesh = restart_data["local_prototype_mesh"] - global_prototype_nelements = restart_data["global_prototype_nelements"] - assert restart_data["num_parts"] == num_parts - else: # generate the grids from scratch - from meshmode.mesh.generation import generate_regular_rect_mesh - generate_mesh = partial(generate_regular_rect_mesh, - a=(-1,)*dim, b=(1,)*dim, nelements_per_axis=(16,)*dim) - local_prototype_mesh, global_prototype_nelements = \ - generate_and_distribute_mesh(comm, generate_mesh) - - volume_to_local_mesh = {i: local_prototype_mesh for i in range(nvolumes)} - - local_nelements = local_prototype_mesh.nelements * nvolumes - global_nelements = global_prototype_nelements * nvolumes - - order = 3 - dcoll = create_discretization_collection(actx, volume_to_local_mesh, order=order) - - volume_dds = [ - DOFDesc(VolumeDomainTag(i), DISCR_TAG_BASE) - for i in range(nvolumes)] - - if use_overintegration: - quadrature_tag = DISCR_TAG_QUAD - else: - quadrature_tag = None - - vis_timer = None - - if logmgr: - logmgr_add_cl_device_info(logmgr, queue) - logmgr_add_device_memory_usage(logmgr, queue) - - def extract_vars(i, dim, cvs, eos): - name_to_field = extract_vars_for_logging(dim, cvs[i], eos) - return { - name + f"_{i}": field - for name, field in name_to_field.items()} - - def units(quantity): - return "" - - for i in range(nvolumes): - logmgr_add_many_discretization_quantities( - logmgr, dcoll, dim, partial(extract_vars, i), units, - dd=volume_dds[i]) - - vis_timer = IntervalTimer("t_vis", "Time spent visualizing") - logmgr.add_quantity(vis_timer) - - logmgr.add_watches([ - ("step.max", "step = {value}, "), - ("t_sim.max", "sim time: {value:1.6e} s\n"), - ("t_step.max", "------- step walltime: {value:6g} s, "), - ("t_log.max", "log walltime: {value:6g} s\n") - ]) - - for i in range(nvolumes): - logmgr.add_watches([ - (f"min_pressure_{i}", "------- P (vol. " + str(i) - + ") (min, max) (Pa) = ({value:1.9e}, "), - (f"max_pressure_{i}", "{value:1.9e})\n"), - ]) - - eos = IdealSingleGas() - gas_model = GasModel(eos=eos) - wall = AdiabaticSlipBoundary() - if rst_filename: - current_t = restart_data["t"] - current_step = restart_data["step"] - current_cvs = restart_data["cvs"] - if logmgr: - from mirgecom.logging_quantities import logmgr_set_time - logmgr_set_time(logmgr, current_step, current_t) - else: - # Set the current state from time 0 - def init(nodes, pulse_amplitude): - vel = np.zeros(shape=(dim,)) - orig = np.zeros(shape=(dim,)) - background = Lump( - dim=dim, center=orig, velocity=vel, rhoamp=0.0)(nodes) - return AcousticPulse( - dim=dim, - amplitude=pulse_amplitude, - width=0.1, - center=orig)(x_vec=nodes, cv=background, eos=eos) - current_cvs = make_obj_array([ - init(actx.thaw(dcoll.nodes(dd)), pulse_amplitude) - for dd, pulse_amplitude in zip(volume_dds, pulse_amplitudes)]) - - current_fluid_states = [make_fluid_state(cv, gas_model) for cv in current_cvs] - - visualizers = [make_visualizer(dcoll, volume_dd=dd) for dd in volume_dds] - - initname = "multiple-volumes" - eosname = eos.__class__.__name__ - init_message = make_init_message(dim=dim, order=order, - nelements=local_nelements, - global_nelements=global_nelements, - dt=current_dt, t_final=t_final, nstatus=nstatus, - nviz=nviz, cfl=current_cfl, - constant_cfl=constant_cfl, initname=initname, - eosname=eosname, casename=casename) - if rank == 0: - logger.info(init_message) - - def my_get_timestep(step, t, dt, fluid_states): - return min([ - get_sim_timestep( - dcoll, fluid_state, t, dt, current_cfl, t_final, constant_cfl) - for fluid_state in fluid_states]) - - def my_write_viz(step, t, cvs, dvs=None): - if dvs is None: - dvs = [eos.dependent_vars(cv) for cv in cvs] - for i in range(nvolumes): - viz_fields = [ - ("cv", cvs[i]), - ("dv", dvs[i])] - from mirgecom.simutil import write_visfile - write_visfile( - dcoll, viz_fields, visualizers[i], vizname=casename + f"-{i}", - step=step, t=t, overwrite=True, vis_timer=vis_timer, comm=comm) - - def my_write_restart(step, t, cvs): - rst_fname = rst_pattern.format(cname=casename, step=step, rank=rank) - if rst_fname != rst_filename: - rst_data = { - "local_prototype_mesh": local_prototype_mesh, - "cvs": cvs, - "t": t, - "step": step, - "order": order, - "global_nelements": global_nelements, - "num_parts": num_parts - } - from mirgecom.restart import write_restart_file - write_restart_file(actx, rst_data, rst_fname, comm) - - def my_health_check(pressures): - health_error = False - for dd, pressure in zip(volume_dds, pressures): - from mirgecom.simutil import check_naninf_local, check_range_local - if check_naninf_local(dcoll, dd, pressure) \ - or check_range_local(dcoll, dd, pressure, 1e-2, 10): - health_error = True - logger.info(f"{rank=}: Invalid pressure data found.") - break - return health_error - - def my_pre_step(step, t, dt, state): - cvs = state - fluid_states = [make_fluid_state(cv, gas_model) for cv in cvs] - dvs = [fluid_state.dv for fluid_state in fluid_states] - - try: - - if logmgr: - logmgr.tick_before() - - from mirgecom.simutil import check_step - do_viz = check_step(step=step, interval=nviz) - do_restart = check_step(step=step, interval=nrestart) - do_health = check_step(step=step, interval=nhealth) - - if do_health: - pressures = [dv.pressure for dv in dvs] - health_errors = global_reduce(my_health_check(pressures), op="lor") - if health_errors: - if rank == 0: - logger.info("Fluid solution failed health check.") - raise MyRuntimeError("Failed simulation health check.") - - if do_restart: - my_write_restart(step=step, t=t, cvs=cvs) - - if do_viz: - my_write_viz(step=step, t=t, cvs=cvs, dvs=dvs) - - except MyRuntimeError: - if rank == 0: - logger.info("Errors detected; attempting graceful exit.") - my_write_viz(step=step, t=t, cvs=cvs) - my_write_restart(step=step, t=t, cvs=cvs) - raise - - dt = my_get_timestep(step=step, t=t, dt=dt, fluid_states=fluid_states) - - return cvs, dt - - def my_post_step(step, t, dt, state): - # Logmgr needs to know about EOS, dt, dim? - # imo this is a design/scope flaw - if logmgr: - set_dt(logmgr, dt) - set_sim_state(logmgr, dim, state, eos) - logmgr.tick_after() - return state, dt - - def my_rhs(t, state): - cvs = state - fluid_states = [make_fluid_state(cv, gas_model) for cv in cvs] - return make_obj_array([ - euler_operator( - dcoll, state=fluid_state, time=t, - boundaries={dd.trace(BTAG_ALL).domain_tag: wall}, - gas_model=gas_model, quadrature_tag=quadrature_tag, - dd=dd, comm_tag=dd, use_esdg=use_esdg) - for dd, fluid_state in zip(volume_dds, fluid_states)]) - - current_dt = my_get_timestep( - current_step, current_t, current_dt, current_fluid_states) - - current_step, current_t, current_cvs = \ - advance_state(rhs=my_rhs, timestepper=timestepper, - pre_step_callback=my_pre_step, - post_step_callback=my_post_step, dt=current_dt, - state=current_cvs, t=current_t, t_final=t_final) - - # Dump the final data - if rank == 0: - logger.info("Checkpointing final state ...") - final_fluid_states = [make_fluid_state(cv, gas_model) for cv in current_cvs] - final_dvs = [fluid_state.dv for fluid_state in final_fluid_states] - - my_write_viz(step=current_step, t=current_t, cvs=current_cvs, dvs=final_dvs) - my_write_restart(step=current_step, t=current_t, cvs=current_cvs) - - if logmgr: - logmgr.close() - elif use_profiling: - print(actx.tabulate_profiling_data()) - - finish_tol = 1e-16 - assert np.abs(current_t - t_final) < finish_tol - - -if __name__ == "__main__": - import argparse - casename = "multiple-volumes" - parser = argparse.ArgumentParser(description=f"MIRGE-Com Example: {casename}") - parser.add_argument("--overintegration", action="store_true", - help="use overintegration in the RHS computations") - parser.add_argument("--lazy", action="store_true", - help="switch to a lazy computation mode") - parser.add_argument("--profiling", action="store_true", - help="turn on detailed performance profiling") - parser.add_argument("--esdg", action="store_true", - help="use entropy-stable DG for inviscid terms") - parser.add_argument("--leap", action="store_true", - help="use leap timestepper") - parser.add_argument("--numpy", action="store_true", - help="use numpy-based eager actx.") - parser.add_argument("--restart_file", help="root name of restart file") - parser.add_argument("--casename", help="casename to use for i/o") - args = parser.parse_args() - - from warnings import warn - from mirgecom.simutil import ApplicationOptionsError - if args.esdg: - if not args.lazy and not args.numpy: - raise ApplicationOptionsError("ESDG requires lazy or numpy context.") - if not args.overintegration: - warn("ESDG requires overintegration, enabling --overintegration.") - - from mirgecom.array_context import get_reasonable_array_context_class - actx_class = get_reasonable_array_context_class( - lazy=args.lazy, distributed=True, profiling=args.profiling, numpy=args.numpy) - - logging.basicConfig(format="%(message)s", level=logging.INFO) - if args.casename: - casename = args.casename - rst_filename = None - if args.restart_file: - rst_filename = args.restart_file - - main(actx_class, - use_leap=args.leap, use_esdg=args.esdg, - use_overintegration=args.overintegration or args.esdg, - casename=casename, rst_filename=rst_filename) - -# vim: foldmethod=marker diff --git a/mirgecom/materials/carbon_fiber.py b/mirgecom/materials/carbon_fiber.py index 5a614672f..94aa91c87 100644 --- a/mirgecom/materials/carbon_fiber.py +++ b/mirgecom/materials/carbon_fiber.py @@ -106,6 +106,71 @@ def get_source_terms(self, temperature, tau, rhoY_o2) -> DOFArray: # noqa N803 return (mw_co/mw_o2 + mw_o/mw_o2 - 1)*rhoY_o2*k*eff_surf_area +# TODO per MTC review, can we generalize the oxidation model? +# should we keep this in the driver? +class Y3_Oxidation_Model(Oxidation): # noqa N801 + r"""Evaluate the source terms for the Y3 model of carbon fiber oxidation. + + Follows ``A. Martin, AIAA 2013-2636'', using a single reaction given by + .. math:: + C_{(s)} + O_2 \to CO_2 + + .. automethod:: get_source_terms + """ + + def __init__(self, wall_material): + self._material = wall_material + + def _get_wall_effective_surface_area_fiber(self, tau) -> DOFArray: + r"""Evaluate the effective surface of the fibers. + + The fiber radius as a function of mass loss $\tau$ is given by + .. math:: + \tau = \frac{m}{m_0} = \frac{\pi r^2/L}{\pi r_0^2/L} = \frac{r^2}{r_0^2} + """ + actx = tau.array_context + + original_fiber_radius = 5e-6 # half the diameter + fiber_radius = original_fiber_radius*actx.np.sqrt(tau) + + epsilon_0 = self._material.volume_fraction(tau=1.0) + return 2.0*epsilon_0/original_fiber_radius**2*fiber_radius + + def get_source_terms(self, temperature, tau, rhoY_o2): # noqa N803 + r"""Return the effective source terms for the oxidation. + + Parameters + ---------- + temperature: meshmode.dof_array.DOFArray + tau: meshmode.dof_array.DOFArray + the progress ratio of the oxidation + ox_mass: meshmode.dof_array.DOFArray + the mass fraction of oxygen + + Returns + ------- + meshmode.dof_array.DOFArray + The tuple (\omega_{C}, \omega_{O_2}, \omega_{CO_2}) + """ + actx = temperature.array_context + + mw_c = 12.011 + mw_o = 15.999 + mw_o2 = mw_o*2 + mw_co2 = 44.010 + univ_gas_const = 8.31446261815324 # J/(K-mol) + + eff_surf_area = self._get_wall_effective_surface_area_fiber(tau) + + k_f = 1.0e5*actx.np.exp(-120000.0/(univ_gas_const*temperature)) + + m_dot_c = - rhoY_o2/mw_o2 * mw_c * eff_surf_area * k_f + m_dot_o2 = - rhoY_o2/mw_o2 * mw_o2 * eff_surf_area * k_f + m_dot_co2 = + rhoY_o2/mw_o2 * mw_co2 * eff_surf_area * k_f + + return m_dot_c, m_dot_o2, m_dot_co2 + + class FiberEOS(PorousWallEOS): """Evaluate the properties of the solid state containing only fibers. diff --git a/mirgecom/multiphysics/__init__.py b/mirgecom/multiphysics/__init__.py index c37b7d9db..9b911749b 100644 --- a/mirgecom/multiphysics/__init__.py +++ b/mirgecom/multiphysics/__init__.py @@ -27,6 +27,7 @@ __doc__ = """ .. automodule:: mirgecom.multiphysics.thermally_coupled_fluid_wall .. automodule:: mirgecom.multiphysics.phenolics_coupled_fluid_wall +.. automodule:: mirgecom.multiphysics.multiphysics_coupled_fluid_wall .. autofunction:: make_interface_boundaries """ diff --git a/mirgecom/multiphysics/multiphysics_coupled_fluid_wall.py b/mirgecom/multiphysics/multiphysics_coupled_fluid_wall.py new file mode 100644 index 000000000..6a510b6cb --- /dev/null +++ b/mirgecom/multiphysics/multiphysics_coupled_fluid_wall.py @@ -0,0 +1,1046 @@ +r""":mod:`mirgecom.multiphysics.multiphysics_coupled_fluid_wall` for +fully-coupled fluid and wall. + +Couples a fluid subdomain governed by the compressible Navier-Stokes equations +with a wall subdomain governed by the porous media equation by enforcing +continuity of quantities and their respective fluxes + +.. math:: + q_\text{fluid} &= q_\text{wall} \\ + - D_\text{fluid} \nabla q_\text{fluid} \cdot \hat{n} &= + - D_\text{wall} \nabla q_\text{wall} \cdot \hat{n}. + +at the interface. + +.. autofunction:: add_interface_boundaries_no_grad +.. autofunction:: add_interface_boundaries +.. autoclass:: InterfaceFluidBoundary +.. autoclass:: InterfaceWallBoundary +""" + +__copyright__ = """ +Copyright (C) 2023 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import numpy as np + +from grudge.trace_pair import ( + inter_volume_trace_pairs, + TracePair +) +from grudge.dof_desc import as_dofdesc + +from mirgecom.fluid import make_conserved +from mirgecom.fluid import species_mass_fraction_gradient +from mirgecom.math import harmonic_mean +from mirgecom.transport import GasTransportVars +from mirgecom.boundary import MengaldoBoundaryCondition +# from mirgecom.viscous import ( +# viscous_stress_tensor, +# ) +from mirgecom.gas_model import ( + make_fluid_state, + ViscousFluidState +) +from mirgecom.diffusion import diffusion_flux +from mirgecom.utils import project_from_base +from mirgecom.wall_model import PorousFlowModel + + +class _CVInterVolTag: + pass + + +class _TemperatureInterVolTag: + pass + + +class _MatDensityInterVolTag: + pass + + +class _GradCVInterVolTag: + pass + + +class _GradTemperatureInterVolTag: + pass + + +class _MultiphysicsCoupledHarmonicMeanBoundaryComponent: + """Setup of the coupling between both sides of the interface.""" + + def __init__(self, state_plus, interface_noslip, interface_radiation, + boundary_velocity=None, grad_cv_plus=None, grad_t_plus=None): + r"""Initialize coupling interface. + + Arguments *grad_cv_plus* and *grad_t_plus*, are only required if the + boundary will be used to compute the viscous flux. + + Parameters + ---------- + state_plus: :class:`~mirgecom.gas_model.FluidState` + Fluid state on either wall or fluid side. + + interface_noslip: bool + If `True`, interface boundaries on the fluid side will be treated + as no-slip walls. If `False` they will be treated as slip walls. + + interface_radiation: bool + If `True`, radiation is accounted for as a sink term in the coupling. + If `False` they will be treated as slip walls. + + boundary_velocity: float + If there is a normal velocity prescribed at the boundary. + + grad_cv_plus: :class:`meshmode.dof_array.DOFArray` or None + CV gradient from the wall side. + + grad_t_plus: :class:`meshmode.dof_array.DOFArray` or None + Temperature gradient from the wall side. + """ + self._state_plus = state_plus + self._no_slip = interface_noslip + self._radiation = interface_radiation + self._boundary_velocity = boundary_velocity + self._grad_cv_plus = grad_cv_plus + self._grad_t_plus = grad_t_plus + + def state_plus(self, dcoll, dd_bdry, gas_model, state_minus, **kwargs): + """State to enforce inviscid BC at the interface.""" + # This is only used for inviscid flux, so I dont think I have to exactly + # use the plus side but rather a state that enforces the BC. Thus, + # a "no-slip wall"-like implementation is used. + # For the viscous fluxes/gradients, the actual plus state is used. + if self._boundary_velocity is not None: + # actx = state_minus.cv.mass.array_context + # normal = actx.thaw(dcoll.normal(dd_bdry)) + # momentum_plus = (2.0*state_minus.cv.mass*self._boundary_velocity*normal + # - state_minus.cv.momentum) + + # cv_plus = make_conserved(dim=dcoll.dim, + # mass=state_minus.cv.mass, + # energy=state_minus.cv.energy, + # momentum=momentum_plus, + # species_mass=state_minus.cv.species_mass) + + # return ViscousFluidState(cv=cv_plus, dv=state_minus.dv, + # tv=state_minus.tv) + raise NotImplementedError + + if self._no_slip is True: + # use the same implementation from no-slip walls + cv_plus = make_conserved(dim=dcoll.dim, + mass=state_minus.cv.mass, + energy=state_minus.cv.energy, + momentum=-state_minus.cv.momentum, + species_mass=state_minus.cv.species_mass) + + return ViscousFluidState(cv=cv_plus, dv=state_minus.dv, + tv=state_minus.tv) + + raise NotImplementedError + + def state_bc(self, dcoll, dd_bdry, gas_model, state_minus, **kwargs): + """State to enforce viscous BC at the interface.""" + actx = state_minus.array_context + + state_plus = project_from_base(dcoll, dd_bdry, self._state_plus) + + u_bc = self.velocity_bc(dcoll, dd_bdry, state_minus) + t_bc = self.temperature_bc(dcoll, dd_bdry, state_minus) + y_bc = self.species_mass_fractions_bc(dcoll, dd_bdry, state_minus) + + material_densities = state_minus.wv.material_densities \ + if isinstance(gas_model, PorousFlowModel) else None + + # the gradient that matters is the intrinsic density, not the bulk one. + # thus, has to consider the presence of 'epsilon' to avoid jump in mass. + if isinstance(gas_model, PorousFlowModel): + # wall side, where the plus is the fluid + epsilon_plus = state_minus.wv.void_fraction + mass_bc = 0.5*(state_minus.mass_density + + state_plus.mass_density*epsilon_plus) + + tau = gas_model.decomposition_progress(material_densities) + from mirgecom.wall_model import PorousWallVars + from mirgecom.gas_model import PorousFlowFluidState + wv = PorousWallVars( + material_densities=material_densities, + tau=tau, + density=gas_model.solid_density(material_densities), + void_fraction=gas_model.wall_eos.void_fraction(tau), + emissivity=gas_model.wall_eos.emissivity(tau), + permeability=gas_model.wall_eos.permeability(tau), + tortuosity=gas_model.wall_eos.tortuosity(tau)) + + total_energy_bc = mass_bc*gas_model.eos.get_internal_energy(t_bc, y_bc) \ + + wv.density*gas_model.wall_eos.enthalpy(t_bc, tau) + + smoothness_mu = actx.np.zeros_like(state_minus.cv.mass) + smoothness_kappa = actx.np.zeros_like(state_minus.cv.mass) + smoothness_beta = actx.np.zeros_like(state_minus.cv.mass) + + else: + # fluid side, where the plus is the wall + epsilon_plus = state_plus.wv.void_fraction + mass_bc = 0.5*(state_minus.mass_density + + state_plus.mass_density/epsilon_plus) + + internal_energy_bc = gas_model.eos.get_internal_energy( + temperature=t_bc, species_mass_fractions=y_bc) + total_energy_bc = mass_bc*(internal_energy_bc + 0.5*np.dot(u_bc, u_bc)) + + smoothness_mu = state_minus.dv.smoothness_mu + smoothness_kappa = state_minus.dv.smoothness_kappa + smoothness_beta = state_minus.dv.smoothness_beta + + cv_bc = make_conserved(dim=dcoll.dim, + mass=mass_bc, + momentum=mass_bc*u_bc, + energy=total_energy_bc, + species_mass=mass_bc*y_bc) + + state_bc = make_fluid_state(cv=cv_bc, gas_model=gas_model, + temperature_seed=t_bc, + smoothness_mu=smoothness_mu, + smoothness_kappa=smoothness_kappa, + smoothness_beta=smoothness_beta, + material_densities=material_densities) + + new_mu = state_minus.tv.viscosity + + new_kappa = state_minus.tv.thermal_conductivity if self._radiation else \ + harmonic_mean(state_minus.tv.thermal_conductivity, + state_plus.tv.thermal_conductivity) + + new_diff = harmonic_mean(state_minus.tv.species_diffusivity, + state_plus.tv.species_diffusivity) + + new_tv = GasTransportVars( + bulk_viscosity=state_bc.tv.bulk_viscosity, + viscosity=new_mu, + thermal_conductivity=new_kappa, + species_diffusivity=new_diff) + + if isinstance(gas_model, PorousFlowModel): + return PorousFlowFluidState(cv=state_bc.cv, dv=state_bc.dv, + tv=new_tv, wv=wv) + + return ViscousFluidState(cv=state_bc.cv, dv=state_bc.dv, tv=new_tv) + + def velocity_bc(self, dcoll, dd_bdry, state_minus): + """Velocity at the interface. + + The velocity can be non-zero due to the blowing or zero if no-slip. + """ + u_minus = state_minus.cv.velocity + + # if there is mass blowing normal to the surface + if self._boundary_velocity is not None: + # actx = state_minus.cv.mass.array_context + # normal = actx.thaw(dcoll.normal(dd_bdry)) + # return self._boundary_velocity*normal + + raise NotImplementedError + + # if the coupling involves a no-slip wall: + if self._no_slip: + return u_minus*0.0 + + raise NotImplementedError + + def species_mass_fractions_bc(self, dcoll, dd_bdry, state_minus): + """Species mass fractions at the interface.""" + y_minus = state_minus.species_mass_fractions + y_plus = project_from_base(dcoll, dd_bdry, + self._state_plus.species_mass_fractions) + + actx = state_minus.array_context + diff_minus = state_minus.tv.species_diffusivity + diff_plus = project_from_base(dcoll, dd_bdry, + self._state_plus.tv.species_diffusivity) + diff_sum = diff_minus + diff_plus + # for cases with zero species diffusion + return actx.np.where( + actx.np.greater(diff_sum, 0.0), + (y_minus * diff_minus + y_plus * diff_plus)/diff_sum, + y_plus + y_minus) + + def temperature_bc(self, dcoll, dd_bdry, state_minus): + """Temperature at the interface.""" + t_minus = state_minus.temperature + t_plus = project_from_base(dcoll, dd_bdry, self._state_plus.temperature) + + kappa_minus = state_minus.tv.thermal_conductivity + kappa_plus = project_from_base(dcoll, dd_bdry, + self._state_plus.tv.thermal_conductivity) + kappa_sum = kappa_minus + kappa_plus + return (t_minus * kappa_minus + t_plus * kappa_plus)/kappa_sum + + def grad_cv_bc(self, dcoll, dd_bdry, gas_model, state_minus, grad_cv_minus, + normal, **kwargs): + """Gradient averaging for viscous flux.""" + if self._grad_cv_plus is None: + raise ValueError( + "Boundary does not have external CV gradient data.") + + grad_cv_plus = project_from_base(dcoll, dd_bdry, self._grad_cv_plus) + + # if the coupling involves a no-slip wall: + if self._no_slip: + grad_cv_bc = (grad_cv_plus + grad_cv_minus)/2 + return make_conserved(dim=dcoll.dim, + mass=grad_cv_bc.mass, + momentum=grad_cv_minus.momentum, + energy=grad_cv_bc.energy, + species_mass=grad_cv_bc.species_mass) + + raise NotImplementedError + + def grad_temperature_bc(self, dcoll, dd_bdry, grad_t_minus): + """Gradient averaging for viscous flux.""" + if self._grad_t_plus is None: + raise ValueError( + "Boundary does not have external temperature gradient data.") + + grad_t_plus = project_from_base(dcoll, dd_bdry, self._grad_t_plus) + return (grad_t_plus + grad_t_minus)/2 + + +class InterfaceFluidBoundary(MengaldoBoundaryCondition): + """Boundary for the fluid side on the interface between fluid and wall. + + .. automethod:: __init__ + .. automethod:: state_plus + .. automethod:: state_bc + .. automethod:: grad_cv_bc + .. automethod:: temperature_bc + .. automethod:: grad_temperature_bc + .. automethod:: viscous_divergence_flux + """ + + def __init__(self, state_plus, interface_noslip, interface_radiation, + boundary_velocity=None, + grad_cv_plus=None, grad_t_plus=None, + flux_penalty_amount=None, lengthscales_minus=None): + r"""Initialize InterfaceFluidBoundary. + + Arguments *grad_cv_plus*, *grad_t_plus*, *flux_penalty_amount*, and + *lengthscales_minus* are only required if the boundary will be used to + compute the viscous flux. + + Parameters + ---------- + state_plus: :class:`~mirgecom.gas_model.FluidState` + Fluid state from the wall side, i.e., porous flow. + + interface_noslip: bool + If `True`, interface boundaries on the fluid side will be treated + as no-slip walls. If `False` they will be treated as slip walls. + + interface_radiation: bool + If `True`, radiation is accounted for as a sink term in the coupling. + If `False` they will be treated as slip walls. + + boundary_velocity: float + If there is a prescribed velocity at the boundary. + + grad_cv_plus: :class:`meshmode.dof_array.DOFArray` or None + CV gradient from the wall side. + + grad_t_plus: :class:`meshmode.dof_array.DOFArray` or None + Temperature gradient from the wall side. + + flux_penalty_amount: float or None + Coefficient $c$ for the interior penalty on the heat flux. + + lengthscales_minus: :class:`meshmode.dof_array.DOFArray` or None + Characteristic mesh spacing $h^-$. + """ + self._state_plus = state_plus + self._radiation = interface_radiation + self._grad_cv_plus = grad_cv_plus + self._grad_t_plus = grad_t_plus + self._flux_penalty_amount = flux_penalty_amount + self._lengthscales_minus = lengthscales_minus + + self._coupled = _MultiphysicsCoupledHarmonicMeanBoundaryComponent( + state_plus=state_plus, + boundary_velocity=boundary_velocity, + interface_noslip=interface_noslip, + interface_radiation=interface_radiation, + grad_cv_plus=grad_cv_plus, + grad_t_plus=grad_t_plus) + + def state_plus(self, dcoll, dd_bdry, gas_model, state_minus, **kwargs): + """State to enforce inviscid BC at the interface.""" + # Don't bother replacing anything since this is just for inviscid + return self._coupled.state_plus(dcoll, dd_bdry, gas_model, state_minus, **kwargs) + + def state_bc(self, dcoll, dd_bdry, gas_model, state_minus, **kwargs): + """State to enforce viscous BC at the interface.""" + dd_bdry = as_dofdesc(dd_bdry) + return self._coupled.state_bc(dcoll, dd_bdry, gas_model, state_minus) + + def temperature_bc(self, dcoll, dd_bdry, state_minus, **kwargs): + r"""Interface temperature to enforce viscous BC.""" + if self._radiation: + t_plus = project_from_base(dcoll, dd_bdry, self._state_plus.temperature) + return 0.5*(t_plus + state_minus.temperature) + return self._coupled.temperature_bc(dcoll, dd_bdry, state_minus) + + def grad_cv_bc(self, dcoll, dd_bdry, gas_model, state_minus, grad_cv_minus, + normal, **kwargs): + """Gradient of CV to enforce viscous BC.""" + return self._coupled.grad_cv_bc(dcoll, dd_bdry, gas_model, state_minus, + grad_cv_minus, normal, **kwargs) + + def grad_temperature_bc(self, dcoll, dd_bdry, grad_t_minus, normal, **kwargs): + r"""Gradient of temperature to enforce viscous BC. + + If using radiation, uses $\nabla T_{bc} = \nabla T^{-}$. + Else, the simple averaging of gradient at both sides is used instead. + """ + if self._radiation: + return grad_t_minus + return self._coupled.grad_temperature_bc(dcoll, dd_bdry, grad_t_minus) + + def viscous_divergence_flux( + self, dcoll, dd_bdry, gas_model, state_minus, grad_cv_minus, + grad_t_minus, numerical_flux_func=None, **kwargs): + r"""Return the viscous flux at the interface boundaries. + + It is defined by + :meth:`mirgecom.boundary.MengaldoBoundaryCondition.viscous_divergence_flux` + + For radiation cases: + ..math:: + + \nabla T_{bc} = \nabla T^- + \kappa_{bc} = \kappa^- + """ + dd_bdry = as_dofdesc(dd_bdry) + + base_viscous_flux = super().viscous_divergence_flux( + dcoll=dcoll, dd_bdry=dd_bdry, gas_model=gas_model, + state_minus=state_minus, numerical_flux_func=numerical_flux_func, + grad_cv_minus=grad_cv_minus, grad_t_minus=grad_t_minus, **kwargs) + + penalization = 0.0 + + # state_plus = project_from_base(dcoll, dd_bdry, self._state_plus) + + # state_bc = self.state_bc(dcoll=dcoll, dd_bdry=dd_bdry, + # gas_model=gas_model, state_minus=state_minus, **kwargs) + + # lengthscales_minus = project_from_base( + # dcoll, dd_bdry, self._lengthscales_minus) + + # penalty = self._flux_penalty_amount/lengthscales_minus + # tau_momentum = penalty * state_bc.tv.viscosity + # tau_energy = penalty * state_bc.tv.thermal_conductivity + # tau_species = penalty * state_bc.tv.species_diffusivity + + # penalization = make_conserved(dim=dcoll.dim, + # mass=state_minus.cv.mass*0.0, + # energy=tau_energy*( + # state_plus.temperature - state_minus.temperature), + # momentum=tau_momentum*( + # state_plus.cv.momentum - state_minus.cv.momentum), + # species_mass=tau_species*( + # state_plus.cv.species_mass - state_minus.cv.species_mass) + # ) + + return base_viscous_flux + penalization + + +class InterfaceWallBoundary(MengaldoBoundaryCondition): + """Boundary for the wall side of the fluid-wall interface. + + .. automethod:: __init__ + .. automethod:: state_plus + .. automethod:: state_bc + .. automethod:: grad_cv_bc + .. automethod:: temperature_bc + .. automethod:: grad_temperature_bc + .. automethod:: viscous_divergence_flux + """ + + def __init__(self, state_plus, interface_noslip, interface_radiation, + wall_emissivity=None, sigma=None, u_ambient=None, + grad_cv_plus=None, grad_t_plus=None, + flux_penalty_amount=None, lengthscales_minus=None): + r"""Initialize InterfaceWallBoundary. + + Arguments *grad_cv_plus*, *grad_t_plus*, *flux_penalty_amount*, and + *lengthscales_minus* are only required if the boundary will be used to + compute the viscous flux. + + Parameters + ---------- + state_plus: :class:`~mirgecom.gas_model.FluidState` + Fluid state from the fluid side. + + interface_noslip: bool + If `True`, interface boundaries on the fluid side will be treated + as no-slip walls. If `False` they will be treated as slip walls. + + grad_cv_plus: :class:`meshmode.dof_array.DOFArray` or None + CV gradient from the fluid side. + + grad_t_plus: :class:`meshmode.dof_array.DOFArray` or None + Temperature gradient from the fluid side. + + flux_penalty_amount: float or None + Coefficient $c$ for the interior penalty on the viscous fluxes. + + lengthscales_minus: :class:`meshmode.dof_array.DOFArray` or None + Characteristic mesh spacing $h^-$. + """ + self._state_plus = state_plus + self._radiation = interface_radiation + self._grad_cv_plus = grad_cv_plus + self._grad_t_plus = grad_t_plus + self._emissivity = wall_emissivity + self._sigma = sigma + self._u_ambient = u_ambient + self._flux_penalty_amount = flux_penalty_amount + self._lengthscales_minus = lengthscales_minus + + self._coupled = _MultiphysicsCoupledHarmonicMeanBoundaryComponent( + state_plus=state_plus, + interface_noslip=interface_noslip, + interface_radiation=interface_radiation, + grad_cv_plus=grad_cv_plus, + grad_t_plus=grad_t_plus) + + def state_plus(self, dcoll, dd_bdry, gas_model, state_minus, **kwargs): + """State to enforce inviscid BC at the interface.""" + # Don't bother replacing anything since this is just for inviscid + return self._coupled.state_plus(dcoll, dd_bdry, state_minus, **kwargs) + + def state_bc(self, dcoll, dd_bdry, gas_model, state_minus, **kwargs): + """State to enforce viscous BC at the interface.""" + dd_bdry = as_dofdesc(dd_bdry) + return self._coupled.state_bc(dcoll, dd_bdry, gas_model, state_minus) + + def temperature_bc(self, dcoll, dd_bdry, state_minus, **kwargs): + """Interface temperature to enforce viscous BC. + + If using radiation, uses $T_{bc} = T^{+}$. Else, the simple averaging + of temperature at both sides is used instead. + """ + if self._radiation: + return project_from_base(dcoll, dd_bdry, self._state_plus.temperature) + return self._coupled.temperature_bc(dcoll, dd_bdry, state_minus) + + def grad_cv_bc(self, dcoll, dd_bdry, gas_model, state_minus, grad_cv_minus, + normal, **kwargs): + """Gradient of CV to enforce viscous BC.""" + return self._coupled.grad_cv_bc(dcoll, dd_bdry, gas_model, state_minus, + grad_cv_minus, normal, **kwargs) + + def grad_temperature_bc(self, dcoll, dd_bdry, grad_t_minus, normal, **kwargs): + """Gradient of temperature to enforce viscous BC.""" + if self._radiation: + return grad_t_minus + return self._coupled.grad_temperature_bc(dcoll, dd_bdry, grad_t_minus) + + def viscous_divergence_flux(self, dcoll, dd_bdry, gas_model, state_minus, + grad_cv_minus, grad_t_minus, numerical_flux_func=None, **kwargs): + """Return the viscous flux at the interface boundaries. + + It is defined by + :meth:`mirgecom.boundary.MengaldoBoundaryCondition.viscous_divergence_flux` + """ + dd_bdry = as_dofdesc(dd_bdry) + + state_plus = project_from_base(dcoll, dd_bdry, self._state_plus) + + state_bc = self.state_bc(dcoll=dcoll, dd_bdry=dd_bdry, gas_model=gas_model, + state_minus=state_minus, **kwargs) + + base_viscous_flux = super().viscous_divergence_flux( + dcoll=dcoll, dd_bdry=dd_bdry, gas_model=gas_model, + state_minus=state_minus, numerical_flux_func=numerical_flux_func, + grad_cv_minus=grad_cv_minus, grad_t_minus=grad_t_minus, **kwargs) + + # lengthscales_minus = project_from_base(dcoll, dd_bdry, + # self._lengthscales_minus) + + # penalty = self._flux_penalty_amount/lengthscales_minus + # tau_momentum = penalty * state_bc.tv.viscosity + # tau_energy = penalty * state_bc.tv.thermal_conductivity + # tau_species = penalty * state_bc.tv.species_diffusivity + + if self._radiation: + radiation_spec = [self._emissivity is None, + self._sigma is None, + self._u_ambient is None] + if sum(radiation_spec) != 0: + raise TypeError( + "Arguments 'wall_emissivity', 'sigma' and 'ambient_temperature'" + "are required if using surface radiation.") + + actx = state_minus.cv.mass.array_context + normal = actx.thaw(dcoll.normal(dd_bdry)) + + kappa_plus = state_plus.thermal_conductivity + grad_t_plus = project_from_base(dcoll, dd_bdry, self._grad_t_plus) + + # species flux + from arraycontext import outer + grad_y_minus = species_mass_fraction_gradient(state_minus.cv, + grad_cv_minus) + grad_y_plus = species_mass_fraction_gradient(state_plus.cv, + self._grad_cv_plus) + rho_bc = 0.5*(state_minus.cv.mass + state_plus.cv.mass) + grad_y_bc = 1.0/(state_minus.wv.void_fraction + 1.0)*( + state_minus.wv.void_fraction*grad_y_minus + grad_y_plus) + d_bc = state_bc.species_diffusivity + y_bc = state_bc.species_mass_fractions + + species_flux = -rho_bc*( + d_bc.reshape(-1, 1)*grad_y_bc + - outer(y_bc, sum(d_bc.reshape(-1, 1)*grad_y_bc))) + species_mass_flux = -species_flux@normal + + # heat flux due to shear, species and thermal conduction + # tau = viscous_stress_tensor(state_bc, grad_cv_bc) + h_alpha = state_bc.species_enthalpies + heat_flux = ( + # np.dot(tau, state_bc.velocity) + - sum(h_alpha.reshape(-1, 1) * species_flux) + - diffusion_flux(kappa_plus, grad_t_plus))@normal + + # radiation sink term + wall_emissivity = project_from_base(dcoll, dd_bdry, self._emissivity) + radiation = wall_emissivity * self._sigma * ( + state_minus.temperature**4 - self._u_ambient**4) + + penalization = 0.0 + # penalization = make_conserved(dim=dcoll.dim, + # mass=state_minus.cv.mass*0.0, + # energy=tau_energy*( + # state_plus.temperature - state_minus.temperature), + # momentum=tau_momentum*( + # state_plus.cv.momentum - state_minus.cv.momentum), + # species_mass=tau_species*( + # state_plus.cv.species_mass - state_minus.cv.species_mass) + # ) + + return (base_viscous_flux.replace(energy=heat_flux - radiation, + species_mass=species_mass_flux) + + penalization) + + else: + penalization = 0.0 + # penalization = make_conserved(dim=dcoll.dim, + # mass=state_minus.cv.mass*0.0, + # energy=tau_energy*( + # state_plus.temperature - state_minus.temperature), + # momentum=tau_momentum*( + # state_plus.cv.momentum - state_minus.cv.momentum), + # species_mass=tau_species*( + # state_plus.cv.species_mass - state_minus.cv.species_mass) + # ) + + return base_viscous_flux + penalization + + +def _getattr_ish(obj, name): + if obj is None: + return None + else: + return getattr(obj, name) + + +def _state_inter_volume_trace_pairs( + dcoll, fluid_dd, wall_dd, gas_model_fluid, gas_model_wall, + fluid_state, wall_state, limiter_func_fluid, limiter_func_wall): + """Exchange state across the fluid-wall interface.""" + actx = fluid_state.cv.mass.array_context + + # exchange CV + pairwise_cv = {(fluid_dd, wall_dd): + (fluid_state.cv, wall_state.cv)} + cv_pairs = inter_volume_trace_pairs( + dcoll, pairwise_cv, comm_tag=_CVInterVolTag) + + fluid_to_wall_cv_tpairs = cv_pairs[fluid_dd, wall_dd] + wall_to_fluid_cv_tpairs = cv_pairs[wall_dd, fluid_dd] + + # exchange temperature + pairwise_temp = {(fluid_dd, wall_dd): + (fluid_state.temperature, wall_state.temperature)} + temperature_seed_pairs = inter_volume_trace_pairs( + dcoll, pairwise_temp, comm_tag=_TemperatureInterVolTag) + + fluid_to_wall_tseed_tpairs = temperature_seed_pairs[fluid_dd, wall_dd] + wall_to_fluid_tseed_tpairs = temperature_seed_pairs[wall_dd, fluid_dd] + + # exchange material densities. It is zero on the fluid side... + from pytools.obj_array import make_obj_array + ncomponents = len(wall_state.wv.material_densities) + if ncomponents == 1: + zeros = actx.np.zeros_like(fluid_state.cv.mass) + else: + zeros = make_obj_array([actx.np.zeros_like(fluid_state.cv.mass) + for i in range(ncomponents)]) + pairwise_dens = {(fluid_dd, wall_dd): + (zeros, wall_state.wv.material_densities)} + material_densities_pairs = inter_volume_trace_pairs( + dcoll, pairwise_dens, comm_tag=_MatDensityInterVolTag) + + fluid_to_wall_mass_tpairs = material_densities_pairs[fluid_dd, wall_dd] + wall_to_fluid_mass_tpairs = material_densities_pairs[wall_dd, fluid_dd] + + return { + (fluid_dd, wall_dd): [TracePair( + cv_pair.dd, + interior=make_fluid_state( + cv_pair.int, gas_model_wall, + temperature_seed=_getattr_ish(tseed_pair, "int"), + material_densities=_getattr_ish(material_densities_pair, "int"), + limiter_func=limiter_func_wall, limiter_dd=cv_pair.dd), + exterior=make_fluid_state( + cv_pair.ext, gas_model_fluid, + temperature_seed=_getattr_ish(tseed_pair, "ext"), + material_densities=_getattr_ish(material_densities_pair, "ext"), + limiter_func=limiter_func_fluid, limiter_dd=cv_pair.dd)) + for cv_pair, tseed_pair, material_densities_pair in zip( + fluid_to_wall_cv_tpairs, + fluid_to_wall_tseed_tpairs, + fluid_to_wall_mass_tpairs)], + (wall_dd, fluid_dd): [TracePair( + cv_pair.dd, + interior=make_fluid_state( + cv_pair.int, gas_model_fluid, + temperature_seed=_getattr_ish(tseed_pair, "int"), + material_densities=_getattr_ish(material_densities_pair, "int"), + limiter_func=limiter_func_fluid, limiter_dd=cv_pair.dd), + exterior=make_fluid_state( + cv_pair.ext, gas_model_wall, + temperature_seed=_getattr_ish(tseed_pair, "ext"), + material_densities=_getattr_ish(material_densities_pair, "ext"), + limiter_func=limiter_func_wall, limiter_dd=cv_pair.dd)) + for cv_pair, tseed_pair, material_densities_pair in zip( + wall_to_fluid_cv_tpairs, + wall_to_fluid_tseed_tpairs, + wall_to_fluid_mass_tpairs)]} + + +def _grad_cv_inter_volume_trace_pairs( + dcoll, fluid_dd, wall_dd, fluid_grad_cv, wall_grad_cv): + """Exchange CV gradients across the fluid-wall interface.""" + pairwise_grad_cv = {(fluid_dd, wall_dd): (fluid_grad_cv, wall_grad_cv)} + return inter_volume_trace_pairs( + dcoll, pairwise_grad_cv, comm_tag=_GradCVInterVolTag) + + +def _grad_temperature_inter_volume_trace_pairs( + dcoll, fluid_dd, wall_dd, fluid_grad_temperature, wall_grad_temperature): + """Exchange temperature gradient across the fluid-wall interface.""" + pairwise_grad_temperature = { + (fluid_dd, wall_dd): + (fluid_grad_temperature, wall_grad_temperature)} + return inter_volume_trace_pairs( + dcoll, pairwise_grad_temperature, comm_tag=_GradTemperatureInterVolTag) + + +def add_interface_boundaries_no_grad( + dcoll, fluid_dd, wall_dd, + gas_model_fluid, gas_model_wall, + fluid_state, wall_state, + fluid_boundaries, wall_boundaries, + interface_noslip, interface_radiation, + *, + limiter_func_fluid=None, limiter_func_wall=None, + boundary_velocity=None): + r"""Return the interface of the subdomains for gradient calculation. + + Used to apply the boundary fluxes at the interface between fluid and + wall domains. + + Parameters + ---------- + dcoll: class:`~grudge.discretization.DiscretizationCollection` + A discretization collection encapsulating the DG elements + + fluid_dd: :class:`grudge.dof_desc.DOFDesc` + DOF descriptor for the fluid volume. + + wall_dd: :class:`grudge.dof_desc.DOFDesc` + DOF descriptor for the wall volume. + + fluid_boundaries: + Dictionary of boundary objects for the fluid subdomain, one for each + :class:`~grudge.dof_desc.BoundaryDomainTag` that represents a domain + boundary. + + wall_boundaries: + Dictionary of boundary objects for the wall subdomain, one for each + :class:`~grudge.dof_desc.BoundaryDomainTag` that represents a domain + boundary. + + fluid_state: :class:`~mirgecom.gas_model.FluidState` + Fluid state object with the conserved state and dependent + quantities for the fluid volume. + + wall_state: :class:`~mirgecom.gas_model.FluidState` + Wall state object with the conserved state and dependent + quantities for the wall volume. + + interface_noslip: bool + If `True`, interface boundaries on the fluid side will be treated as + no-slip walls. If `False` they will be treated as slip walls. + + interface_radiation: bool + If `True`, interface includes a radiation sink term in the heat flux + on the wall side and prescribes the temperature on the fluid side. + Additional arguments *wall_emissivity*, *sigma*, and + *ambient_temperature* are required if enabled. + + boundary_velocity: float or :class:`meshmode.dof_array.DOFArray` + Normal velocity due to pyrolysis outgas. Only required for simplified + analysis of composite material. + + Returns + ------- + The tuple `(fluid_interface_boundaries, wall_interface_boundaries)`. + """ + if interface_noslip is False: + from warnings import warn + warn("Only no-slip coupling is implemented", UserWarning, stacklevel=2) + raise NotImplementedError + + fluid_boundaries = { + as_dofdesc(bdtag).domain_tag: bdry + for bdtag, bdry in fluid_boundaries.items()} + wall_boundaries = { + as_dofdesc(bdtag).domain_tag: bdry + for bdtag, bdry in wall_boundaries.items()} + + # Construct boundaries for the fluid-wall interface; no gradients + # yet because that's what we're trying to compute + + state_inter_volume_trace_pairs = \ + _state_inter_volume_trace_pairs(dcoll, fluid_dd, wall_dd, + gas_model_fluid, gas_model_wall, + fluid_state, wall_state, + limiter_func_fluid, limiter_func_wall) + + # Construct interface boundaries without gradient + + fluid_interface_boundaries_no_grad = { + state_tpair.dd.domain_tag: InterfaceFluidBoundary( + state_plus=state_tpair.ext, + interface_noslip=interface_noslip, + interface_radiation=interface_radiation, + boundary_velocity=boundary_velocity) + for state_tpair in state_inter_volume_trace_pairs[wall_dd, fluid_dd]} + + wall_interface_boundaries_no_grad = { + state_tpair.dd.domain_tag: InterfaceWallBoundary( + state_plus=state_tpair.ext, + interface_noslip=interface_noslip, + interface_radiation=interface_radiation) + for state_tpair in state_inter_volume_trace_pairs[fluid_dd, wall_dd]} + + # Augment the domain boundaries with the interface boundaries + + fluid_all_boundaries_no_grad = {} + fluid_all_boundaries_no_grad.update(fluid_boundaries) + fluid_all_boundaries_no_grad.update(fluid_interface_boundaries_no_grad) + + wall_all_boundaries_no_grad = {} + wall_all_boundaries_no_grad.update(wall_boundaries) + wall_all_boundaries_no_grad.update(wall_interface_boundaries_no_grad) + + return fluid_all_boundaries_no_grad, wall_all_boundaries_no_grad + + +def add_interface_boundaries( + dcoll, + fluid_dd, wall_dd, + fluid_gas_model, wall_gas_model, + fluid_state, wall_state, + fluid_grad_cv, wall_grad_cv, + fluid_grad_temperature, wall_grad_temperature, + fluid_boundaries, wall_boundaries, + interface_noslip, interface_radiation, + *, + limiter_func_fluid=None, limiter_func_wall=None, + boundary_velocity=None, + wall_emissivity=None, sigma=None, ambient_temperature=None, + wall_penalty_amount=None): + r"""Return the interface of the subdomains for viscous fluxes. + + Used to apply the boundary fluxes at the interface between fluid and + wall domains. + + Parameters + ---------- + dcoll: class:`~grudge.discretization.DiscretizationCollection` + A discretization collection encapsulating the DG elements + + fluid_dd: :class:`grudge.dof_desc.DOFDesc` + DOF descriptor for the fluid volume. + + wall_dd: :class:`grudge.dof_desc.DOFDesc` + DOF descriptor for the wall volume. + + fluid_boundaries: + Dictionary of boundary objects for the fluid subdomain, one for each + :class:`~grudge.dof_desc.BoundaryDomainTag` that represents a domain + boundary. + + wall_boundaries: + Dictionary of boundary objects for the wall subdomain, one for each + :class:`~grudge.dof_desc.BoundaryDomainTag` that represents a domain + boundary. + + fluid_state: :class:`~mirgecom.gas_model.FluidState` + Fluid state object with the conserved state and dependent + quantities for the fluid volume. + + wall_state: :class:`~mirgecom.gas_model.FluidState` + Wall state object with the conserved state and dependent + quantities for the wall volume. + + interface_noslip: bool + If `True`, interface boundaries on the fluid side will be treated as + no-slip walls. If `False` they will be treated as slip walls. + + interface_radiation: bool + If `True`, interface includes a radiation sink term in the heat flux + on the wall side and prescribes the temperature on the fluid side. See + :class:`InterfaceWallBoundary` + for details. Additional arguments *wall_emissivity*, *sigma*, and + *ambient_temperature* are required if enabled. + + wall_emissivity: float or :class:`meshmode.dof_array.DOFArray` + Emissivity of the wall material. + + sigma: float + Stefan-Boltzmann constant. + + ambient_temperature: :class:`meshmode.dof_array.DOFArray` + Ambient temperature of the environment. + + boundary_velocity: float or :class:`meshmode.dof_array.DOFArray` + Normal velocity due to pyrolysis outgas. Only required for simplified + analysis of composite material. + + wall_penalty_amount: float + Coefficient $c$ for the interior penalty on the heat flux. See + :class:`InterfaceFluidBoundary` + for details. + + Returns + ------- + The tuple `(fluid_interface_boundaries, wall_interface_boundaries)`. + """ + if wall_penalty_amount is None: + # FIXME: After verifying the form of the penalty term, figure out what + # value makes sense to use as a default here + wall_penalty_amount = 0.05 + + # Set up the interface boundaries + + fluid_boundaries = { + as_dofdesc(bdtag).domain_tag: bdry + for bdtag, bdry in fluid_boundaries.items()} + wall_boundaries = { + as_dofdesc(bdtag).domain_tag: bdry + for bdtag, bdry in wall_boundaries.items()} + + # Exchange information + + state_inter_volume_trace_pairs = \ + _state_inter_volume_trace_pairs(dcoll, + fluid_dd, wall_dd, + fluid_gas_model, wall_gas_model, + fluid_state, wall_state, + limiter_func_fluid, limiter_func_wall) + + grad_cv_inter_vol_tpairs = _grad_cv_inter_volume_trace_pairs( + dcoll, fluid_dd, wall_dd, fluid_grad_cv, wall_grad_cv) + + grad_temperature_inter_vol_tpairs = _grad_temperature_inter_volume_trace_pairs( + dcoll, fluid_dd, wall_dd, fluid_grad_temperature, wall_grad_temperature) + + # Construct interface boundaries with temperature gradient + + import grudge.op as op + from grudge.dt_utils import characteristic_lengthscales + actx = fluid_state.cv.mass.array_context + fluid_lengthscales = characteristic_lengthscales(actx, dcoll, fluid_dd) + wall_lengthscales = characteristic_lengthscales(actx, dcoll, wall_dd) + + fluid_interface_boundaries = { + state_tpair.dd.domain_tag: InterfaceFluidBoundary( + state_plus=state_tpair.ext, + interface_noslip=interface_noslip, + interface_radiation=interface_radiation, + boundary_velocity=boundary_velocity, + grad_cv_plus=grad_cv_tpair.ext, + grad_t_plus=grad_temperature_tpair.ext, + lengthscales_minus=op.project(dcoll, fluid_dd, state_tpair.dd, + fluid_lengthscales), + flux_penalty_amount=wall_penalty_amount) + for state_tpair, grad_cv_tpair, grad_temperature_tpair in zip( + state_inter_volume_trace_pairs[wall_dd, fluid_dd], + grad_cv_inter_vol_tpairs[wall_dd, fluid_dd], + grad_temperature_inter_vol_tpairs[wall_dd, fluid_dd])} + + wall_interface_boundaries = { + state_tpair.dd.domain_tag: InterfaceWallBoundary( + state_plus=state_tpair.ext, + interface_noslip=interface_noslip, + interface_radiation=interface_radiation, + wall_emissivity=wall_emissivity, + sigma=sigma, + u_ambient=ambient_temperature, + grad_cv_plus=grad_cv_tpair.ext, + grad_t_plus=grad_temperature_tpair.ext, + lengthscales_minus=op.project(dcoll, wall_dd, state_tpair.dd, + wall_lengthscales), + flux_penalty_amount=wall_penalty_amount) + for state_tpair, grad_cv_tpair, grad_temperature_tpair in zip( + state_inter_volume_trace_pairs[fluid_dd, wall_dd], + grad_cv_inter_vol_tpairs[fluid_dd, wall_dd], + grad_temperature_inter_vol_tpairs[fluid_dd, wall_dd])} + + # Augment the domain boundaries with the interface boundaries + + fluid_all_boundaries = {} + fluid_all_boundaries.update(fluid_boundaries) + fluid_all_boundaries.update(fluid_interface_boundaries) + + wall_all_boundaries = {} + wall_all_boundaries.update(wall_boundaries) + wall_all_boundaries.update(wall_interface_boundaries) + + return fluid_all_boundaries, wall_all_boundaries diff --git a/test/test_multiphysics.py b/test/test_multiphysics.py index 12ad33caa..f2ecd4d06 100644 --- a/test/test_multiphysics.py +++ b/test/test_multiphysics.py @@ -1,4 +1,4 @@ -__copyright__ = """Copyright (C) 2022 University of Illinois Board of Trustees""" +__copyright__ = """Copyright (C) 2023 University of Illinois Board of Trustees""" __license__ = """ Permission is hereby granted, free of charge, to any person obtaining a copy @@ -20,11 +20,11 @@ THE SOFTWARE. """ +import pytest +import cantera import numpy as np from dataclasses import replace from functools import partial -import pyopencl.array as cla # noqa -import pyopencl.clmath as clmath # noqa from pytools.obj_array import make_obj_array import pymbolic as pmbl import grudge.op as op @@ -42,7 +42,7 @@ NeumannDiffusionBoundary) from grudge.dof_desc import DOFDesc, VolumeDomainTag, DISCR_TAG_BASE, DISCR_TAG_QUAD from mirgecom.discretization import create_discretization_collection -from mirgecom.eos import IdealSingleGas +from mirgecom.eos import IdealSingleGas, PyrometheusMixture from mirgecom.transport import SimpleTransport from mirgecom.fluid import make_conserved from mirgecom.gas_model import ( @@ -52,14 +52,27 @@ from mirgecom.boundary import ( AdiabaticNoslipWallBoundary, IsothermalWallBoundary, + DummyBoundary ) from mirgecom.multiphysics.thermally_coupled_fluid_wall import ( basic_coupled_ns_heat_operator as coupled_ns_heat_operator, ) +from mirgecom.navierstokes import ( + grad_t_operator, grad_cv_operator, ns_operator +) +from mirgecom.multiphysics.multiphysics_coupled_fluid_wall import ( + add_interface_boundaries_no_grad as add_multiphysics_interface_bdries_no_grad, + add_interface_boundaries as add_multiphysics_interface_bdries +) from meshmode.array_context import ( # noqa pytest_generate_tests_for_pyopencl_array_context as pytest_generate_tests) -import pytest +from mirgecom.wall_model import PorousFlowModel, PorousWallTransport +from mirgecom.materials.initializer import PorousWallInitializer +from mirgecom.materials.prescribed_porous_material import PrescribedMaterialEOS +from mirgecom.mechanisms import get_mechanism_input +from mirgecom.thermochemistry import get_pyrometheus_wrapper_class_from_cantera + import logging logger = logging.getLogger(__name__) @@ -511,7 +524,8 @@ def test_thermally_coupled_fluid_wall_with_radiation( mom_y = 0.0*fluid_nodes[0] momentum = make_obj_array([mom_x, mom_y]) - temperature = 2.0 - (2.0 - 1.33938)*(1.0 - fluid_nodes[0])/1.0 + grad_t_fluid_exact = (2.0 - 1.33938)/1.0 + temperature = 2.0 + grad_t_fluid_exact*(fluid_nodes[0] - 1.0) mass = 1.0 + 0.0*fluid_nodes[0] energy = mass*eos.get_internal_energy(temperature) @@ -543,25 +557,577 @@ def test_thermally_coupled_fluid_wall_with_radiation( DirichletDiffusionBoundary(base_solid_temp), } - wall_temperature = 1.0 + (1.0 - 1.33938)*(2.0 + solid_nodes[0])/(-2.0) + grad_t_solid_exact = (1.0 - 1.33938)/(-2.0) + wall_temperature = 1.0 + grad_t_solid_exact*(2.0 + solid_nodes[0]) + + fluid_rhs, wall_rhs, fluid_grad_cv, fluid_grad_t, solid_grad_t = \ + coupled_ns_heat_operator( + dcoll, gas_model, dd_vol_fluid, dd_vol_solid, fluid_boundaries, + solid_boundaries, fluid_state, wall_kappa, wall_temperature, + time=0.0, quadrature_tag=quadrature_tag, interface_noslip=False, + interface_radiation=True, sigma=2.0, + ambient_temperature=0.0, wall_emissivity=wall_emissivity, + return_gradients=True) - fluid_rhs, wall_energy_rhs = coupled_ns_heat_operator( - dcoll, gas_model, dd_vol_fluid, dd_vol_solid, fluid_boundaries, - solid_boundaries, fluid_state, wall_kappa, wall_temperature, - time=0.0, quadrature_tag=quadrature_tag, interface_noslip=False, - interface_radiation=True, sigma=2.0, - ambient_temperature=0.0, wall_emissivity=wall_emissivity, - ) + if visualize: + from grudge.shortcuts import make_visualizer + viz_fluid = make_visualizer(dcoll, order, volume_dd=dd_vol_fluid) + viz_solid = make_visualizer(dcoll, order, volume_dd=dd_vol_solid) + if use_overintegration: + viz_suffix = f"over_{order}" + else: + viz_suffix = f"{order}" + + viz_fluid.write_vtk_file( + f"thermally_coupled_radiation_{viz_suffix}_fluid.vtu", [ + ("temperature", fluid_state.dv.temperature), + ("grad_T", fluid_grad_t), + ], overwrite=True) + viz_solid.write_vtk_file( + f"thermally_coupled_radiation_{viz_suffix}_solid.vtu", [ + ("temperature", wall_temperature), + ("grad_T", solid_grad_t), + ], overwrite=True) + + fluid_grad = \ + (fluid_grad_t[0] - grad_t_fluid_exact)/fluid_state.dv.temperature + solid_grad = \ + (solid_grad_t[0] - grad_t_solid_exact)/wall_temperature + + assert actx.to_numpy(op.norm(dcoll, fluid_grad, np.inf)) < 1e-9 + assert actx.to_numpy(op.norm(dcoll, solid_grad, np.inf)) < 1e-9 # Check that steady-state solution has 0 RHS fluid_rhs = fluid_rhs.energy - solid_rhs = wall_energy_rhs/(wall_cp*wall_rho) + solid_rhs = wall_rhs/(wall_cp*wall_rho) assert actx.to_numpy(op.norm(dcoll, fluid_rhs, np.inf)) < 1e-4 assert actx.to_numpy(op.norm(dcoll, solid_rhs, np.inf)) < 1e-4 +@pytest.mark.parametrize("order", [1, 3, 5]) +# @pytest.mark.parametrize("use_overintegration", [False, True]) +@pytest.mark.parametrize("use_overintegration", [False]) +@pytest.mark.parametrize("use_radiation", [False, True]) +def test_multiphysics_coupled_fluid_wall_for_temperature( + actx_factory, order, use_overintegration, use_radiation, visualize=False): + """Check the NS-coupled fluid/wall interface with radiation. + + Analytic solution prescribed as initial condition, then the RHS is assessed + to ensure that it is nearly zero. + """ + actx = actx_factory() + + dim = 2 + + nelems = 15 + global_mesh = get_box_mesh(dim, -0.2, 0.1, nelems) + + mgrp, = global_mesh.groups + x = global_mesh.vertices[0, mgrp.vertex_indices] + x_elem_avg = np.sum(x, axis=1)/x.shape[0] + volume_to_elements = { + "Fluid": np.where(x_elem_avg > 0)[0], + "Solid": np.where(x_elem_avg <= 0)[0]} + + from meshmode.mesh.processing import partition_mesh + volume_meshes = partition_mesh(global_mesh, volume_to_elements) + + dcoll = create_discretization_collection( + actx, volume_meshes, order=order, quadrature_order=2*order+1) + + if use_overintegration: + quadrature_tag = DISCR_TAG_QUAD + else: + quadrature_tag = None + + dd_vol_fluid = DOFDesc(VolumeDomainTag("Fluid"), DISCR_TAG_BASE) + dd_vol_solid = DOFDesc(VolumeDomainTag("Solid"), DISCR_TAG_BASE) + + fluid_nodes = actx.thaw(dcoll.nodes(dd=dd_vol_fluid)) + solid_nodes = actx.thaw(dcoll.nodes(dd=dd_vol_solid)) + + # Use Cantera for initialization + cantera_soln = cantera.Solution(name="gas", + yaml=get_mechanism_input("air_3sp")) + nspecies = cantera_soln.n_species + y_species = np.zeros(nspecies) + y_species[cantera_soln.species_index("N2")] = 1.0 + + pyrometheus_mechanism = get_pyrometheus_wrapper_class_from_cantera( + cantera_soln, temperature_niter=3)(actx.np) + + eos = PyrometheusMixture(pyrometheus_mechanism, temperature_guess=500.0) + + # Gas model + + fluid_transport = SimpleTransport(viscosity=0.0, thermal_conductivity=0.1, + species_diffusivity=np.zeros(nspecies,)) + gas_model_fluid = GasModel(eos=eos, transport=fluid_transport) + + virgin_mass = 10.0 + + def enthalpy_func(temperature): + return 1000.0*temperature + + def heat_capacity_func(temperature): + return 1000.0 + + def thermal_conductivity_func(temperature=None): + return 0.2 if use_radiation else 1.0 + + def permeability_func(tau=None): + return 1.0e-10 + + def emissivity_func(tau=None): + return 1.0 + + def tortuosity_func(tau=None): + return 1.2 + + def volume_fraction_func(tau=None): + return 0.1 + + def decomposition_progress_func(mass): + return 1.0 + mass*0.0 + + my_material = PrescribedMaterialEOS(enthalpy_func, heat_capacity_func, + thermal_conductivity_func, volume_fraction_func, permeability_func, + emissivity_func, tortuosity_func, decomposition_progress_func) + solid_density = virgin_mass + solid_nodes[0]*0.0 + + kappa = 0.2 if use_radiation else 1.0 + base_transport = SimpleTransport(viscosity=0.0, thermal_conductivity=kappa, + species_diffusivity=np.zeros(nspecies,)) + solid_transport = PorousWallTransport(base_transport=base_transport) + gas_model_solid = PorousFlowModel(eos=eos, transport=solid_transport, + wall_eos=my_material) + + base_fluid_temp = 1500.0 + base_solid_temp = 300.0 + temp_interface = 400.468540568359649 if use_radiation else 500.0 + + # Fluid cv + + grad_t_fluid_exact = (base_fluid_temp - temp_interface)/(0.1) + temperature = temp_interface + grad_t_fluid_exact*fluid_nodes[0] + mass = 1.0 + 0.0*fluid_nodes[0] + energy = mass*eos.get_internal_energy(temperature, y_species) + mom_x = 0.0*fluid_nodes[0] + mom_y = 0.0*fluid_nodes[0] + momentum = make_obj_array([mom_x, mom_y]) + + cv = make_conserved(dim=dim, mass=mass, energy=energy, + momentum=momentum, species_mass=y_species*mass) + fluid_state = make_fluid_state(cv=cv, gas_model=gas_model_fluid, + temperature_seed=500.0) + + # Solid cv + + grad_t_solid_exact = (base_solid_temp - temp_interface)/(-0.2) + solid_temperature = temp_interface + grad_t_solid_exact*solid_nodes[0] + + sample_init = PorousWallInitializer(density=1.0, + temperature=solid_temperature, species=y_species, + material_densities=solid_density) + + wv = sample_init(dim, solid_nodes, gas_model_solid) + solid_state = make_fluid_state(cv=wv, gas_model=gas_model_solid, + material_densities=solid_density, temperature_seed=solid_temperature) + + # Setup boundaries and interface + + fluid_boundaries = { + dd_vol_fluid.trace("-2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_fluid.trace("+2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_fluid.trace("+1").domain_tag: + IsothermalWallBoundary(wall_temperature=base_fluid_temp), + } + + solid_boundaries = { + dd_vol_solid.trace("-2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_solid.trace("+2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_solid.trace("-1").domain_tag: + IsothermalWallBoundary(wall_temperature=base_solid_temp), + } + + fluid_all_boundaries_no_grad, solid_all_boundaries_no_grad = \ + add_multiphysics_interface_bdries_no_grad( + dcoll, dd_vol_fluid, dd_vol_solid, + gas_model_fluid, gas_model_solid, + fluid_state, solid_state, + fluid_boundaries, solid_boundaries, + interface_noslip=True, interface_radiation=use_radiation) + + fluid_grad_cv = grad_cv_operator( + dcoll, gas_model_fluid, fluid_all_boundaries_no_grad, fluid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_fluid) + + fluid_grad_temperature = grad_t_operator( + dcoll, gas_model_fluid, fluid_all_boundaries_no_grad, fluid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_fluid) + + solid_grad_cv = grad_cv_operator( + dcoll, gas_model_solid, solid_all_boundaries_no_grad, solid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_solid) + + solid_grad_temperature = grad_t_operator( + dcoll, gas_model_solid, solid_all_boundaries_no_grad, solid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_solid) + + fluid_all_boundaries, solid_all_boundaries = \ + add_multiphysics_interface_bdries( + dcoll, dd_vol_fluid, dd_vol_solid, + gas_model_fluid, gas_model_solid, + fluid_state, solid_state, + fluid_grad_cv, solid_grad_cv, + fluid_grad_temperature, solid_grad_temperature, + fluid_boundaries, solid_boundaries, + interface_noslip=True, interface_radiation=use_radiation, + wall_emissivity=1.0, sigma=5.67e-8, ambient_temperature=300.0) + + fluid_rhs = ns_operator( + dcoll, gas_model_fluid, fluid_state, fluid_all_boundaries, + quadrature_tag=quadrature_tag, dd=dd_vol_fluid, + grad_cv=fluid_grad_cv, grad_t=fluid_grad_temperature, + inviscid_terms_on=False) + + solid_rhs = ns_operator( + dcoll, gas_model_solid, solid_state, solid_all_boundaries, + quadrature_tag=quadrature_tag, dd=dd_vol_solid, + grad_cv=solid_grad_cv, grad_t=solid_grad_temperature, + inviscid_terms_on=False) + + if visualize: + from grudge.shortcuts import make_visualizer + viz_fluid = make_visualizer(dcoll, order+3, volume_dd=dd_vol_fluid) + viz_solid = make_visualizer(dcoll, order+3, volume_dd=dd_vol_solid) + if use_overintegration: + viz_suffix = f"over_{order}" + else: + viz_suffix = f"{order}" + + if use_radiation: + coupling_suffix = "radiation" + else: + coupling_suffix = "temperature" + + viz_fluid.write_vtk_file( + f"multiphysics_coupled_{coupling_suffix}_{viz_suffix}_fluid.vtu", [ + ("cv", fluid_state.cv), + ("dv", fluid_state.dv), + ("grad_T", fluid_grad_temperature), + ("rhs", fluid_rhs), + ], overwrite=True) + viz_solid.write_vtk_file( + f"multiphysics_coupled_{coupling_suffix}_{viz_suffix}_solid.vtu", [ + ("cv", solid_state.cv), + ("dv", solid_state.dv), + ("grad_T", solid_grad_temperature), + ("rhs", solid_rhs), + ], overwrite=True) + + # Check gradients + + check_fluid_grad_t = \ + (fluid_grad_temperature[0] - grad_t_fluid_exact)/fluid_state.temperature + check_solid_grad_t = \ + (solid_grad_temperature[0] - grad_t_solid_exact)/solid_state.temperature + + assert actx.to_numpy(op.norm(dcoll, check_fluid_grad_t, np.inf)) < 1e-6 + assert actx.to_numpy(op.norm(dcoll, check_solid_grad_t, np.inf)) < 1e-6 + + # Check that steady-state solution has 0 RHS + + assert actx.to_numpy(op.norm(dcoll, fluid_rhs.energy/cv.energy, np.inf)) < 1e-6 + assert actx.to_numpy(op.norm(dcoll, solid_rhs.energy/wv.energy, np.inf)) < 1e-6 + + +@pytest.mark.parametrize("order", [1, 3, 5]) +# @pytest.mark.parametrize("use_overintegration", [False, True]) +@pytest.mark.parametrize("use_overintegration", [False]) +def test_multiphysics_coupled_fluid_wall_for_species_diffusion( + actx_factory, order, use_overintegration, visualize=False): + """Check the NS-coupled fluid/wall interface for species diffusion. + + Analytic solution prescribed as initial condition, then the RHS is assessed + to ensure that it is nearly zero. + """ + actx = actx_factory() + + dim = 2 + + nelems = 30 + global_mesh = get_box_mesh(dim, -1, 2, nelems) + + mgrp, = global_mesh.groups + x = global_mesh.vertices[0, mgrp.vertex_indices] + x_elem_avg = np.sum(x, axis=1)/x.shape[0] + volume_to_elements = { + "Fluid": np.where(x_elem_avg > 0)[0], + "Solid": np.where(x_elem_avg <= 0)[0]} + + from meshmode.mesh.processing import partition_mesh + volume_meshes = partition_mesh(global_mesh, volume_to_elements) + + dcoll = create_discretization_collection( + actx, volume_meshes, order=order, quadrature_order=2*order+1) + + if use_overintegration: + quadrature_tag = DISCR_TAG_QUAD + else: + quadrature_tag = None + + dd_vol_fluid = DOFDesc(VolumeDomainTag("Fluid"), DISCR_TAG_BASE) + dd_vol_solid = DOFDesc(VolumeDomainTag("Solid"), DISCR_TAG_BASE) + + fluid_nodes = actx.thaw(dcoll.nodes(dd=dd_vol_fluid)) + solid_nodes = actx.thaw(dcoll.nodes(dd=dd_vol_solid)) + + # Pyrometheus initialization + mech_input = get_mechanism_input("air_3sp") + cantera_soln = cantera.Solution(name="gas", yaml=mech_input) + pyro_obj = get_pyrometheus_wrapper_class_from_cantera( + cantera_soln, temperature_niter=3)(actx.np) + + nspecies = pyro_obj.num_species + + eos = PyrometheusMixture(pyro_obj, temperature_guess=300.0) + + fluid_diffusivity = 0.006 + solid_diffusivity = 0.001 + + # Gas model + + fluid_transport = SimpleTransport(viscosity=0.0, thermal_conductivity=0.0, + species_diffusivity=np.zeros(nspecies,) + fluid_diffusivity) + gas_model_fluid = GasModel(eos=eos, transport=fluid_transport) + + virgin_mass = 10.0 + + def enthalpy_func(temperature): + return 1000.0*temperature + + def heat_capacity_func(temperature): + return 1000.0 + + def thermal_conductivity_func(temperature=None): + return 0.2 + + def permeability_func(tau=None): + return 1.0e-10 + + def emissivity_func(tau=None): + return 1.0 + + def tortuosity_func(tau=None): + return 1.2 + + def volume_fraction_func(tau=None): + return 0.1 + + def decomposition_progress_func(mass): + return 1.0 + mass*0.0 + + my_material = PrescribedMaterialEOS(enthalpy_func, heat_capacity_func, + thermal_conductivity_func, volume_fraction_func, permeability_func, + emissivity_func, tortuosity_func, decomposition_progress_func) + solid_density = virgin_mass + solid_nodes[0]*0.0 + + base_transport = SimpleTransport(viscosity=0.0, thermal_conductivity=0.0, + species_diffusivity=np.zeros(nspecies,) + solid_diffusivity) + solid_transport = PorousWallTransport(base_transport=base_transport) + gas_model_solid = PorousFlowModel(eos=eos, transport=solid_transport, + wall_eos=my_material) + + solid_diffusivity = solid_diffusivity/tortuosity_func() + + epsilon = 0.9 + diff_ratio = fluid_diffusivity/solid_diffusivity + y_species_interface = 2.0*epsilon/(2.0*epsilon + diff_ratio) + + # Fluid cv + + zeros = fluid_nodes[0]*0.0 + y = make_obj_array([zeros, zeros, zeros]) + + fluid_grad_y0_exact = +y_species_interface/2.0 + fluid_grad_y1_exact = -fluid_grad_y0_exact + + y[0] = 1.0 + fluid_grad_y0_exact*(fluid_nodes[0] - 2.0) + y[1] = 1.0 - y[0] + y[2] = y[0]*0.0 + + temperature = 300.0 + zeros + mass = 1.2 + zeros + energy = mass*eos.get_internal_energy(temperature, y) + momentum = make_obj_array([0.0*fluid_nodes[0], 0.0*fluid_nodes[0]]) + species_mass = mass*y + + cv = make_conserved(dim=dim, mass=mass, energy=energy, momentum=momentum, + species_mass=species_mass) + fluid_state = make_fluid_state(cv=cv, gas_model=gas_model_fluid, + temperature_seed=temperature) + + # Solid cv + + zeros = solid_nodes[0]*0.0 + y = make_obj_array([zeros, zeros, zeros]) + y[0] = (1.0 - y_species_interface)*(solid_nodes[0] + 1.0) + y[1] = 1.0 - y[0] + y[2] = y[0]*0.0 + + solid_grad_y0_exact = (1.0 - y_species_interface) + solid_grad_y1_exact = -solid_grad_y0_exact + + solid_density = 10.0 + zeros + sample_init = PorousWallInitializer(density=1.2, temperature=300.0, + species=y, material_densities=solid_density) + + wv = sample_init(2, solid_nodes, gas_model_solid) + solid_state = make_fluid_state(cv=wv, gas_model=gas_model_solid, + material_densities=solid_density, temperature_seed=300.0) + + # Setup boundaries and interface + + fluid_boundaries = { + dd_vol_fluid.trace("-2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_fluid.trace("+2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_fluid.trace("+1").domain_tag: DummyBoundary(), + } + + solid_boundaries = { + dd_vol_solid.trace("-2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_solid.trace("+2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_solid.trace("-1").domain_tag: DummyBoundary() + } + + fluid_all_boundaries_no_grad, solid_all_boundaries_no_grad = \ + add_multiphysics_interface_bdries_no_grad( + dcoll, dd_vol_fluid, dd_vol_solid, + gas_model_fluid, gas_model_solid, + fluid_state, solid_state, + fluid_boundaries, solid_boundaries, + interface_noslip=True, interface_radiation=True) + + fluid_grad_cv = grad_cv_operator( + dcoll, gas_model_fluid, fluid_all_boundaries_no_grad, fluid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_fluid) + + fluid_grad_temperature = grad_t_operator( + dcoll, gas_model_fluid, fluid_all_boundaries_no_grad, fluid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_fluid) + + solid_grad_cv = grad_cv_operator( + dcoll, gas_model_solid, solid_all_boundaries_no_grad, solid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_solid) + + solid_grad_temperature = grad_t_operator( + dcoll, gas_model_solid, solid_all_boundaries_no_grad, solid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_solid) + + from mirgecom.fluid import species_mass_fraction_gradient + fluid_grad_y = species_mass_fraction_gradient(fluid_state.cv, fluid_grad_cv) + solid_grad_y = species_mass_fraction_gradient(solid_state.cv, solid_grad_cv) + + assert actx.to_numpy( + op.norm(dcoll, fluid_grad_y[0][0] - fluid_grad_y0_exact, np.inf)) < 1e-6 + assert actx.to_numpy( + op.norm(dcoll, fluid_grad_y[1][0] - fluid_grad_y1_exact, np.inf)) < 1e-6 + assert actx.to_numpy( + op.norm(dcoll, solid_grad_y[0][0] - solid_grad_y0_exact, np.inf)) < 1e-6 + assert actx.to_numpy( + op.norm(dcoll, solid_grad_y[1][0] - solid_grad_y1_exact, np.inf)) < 1e-6 + + check_fluid_grad_t = fluid_grad_temperature[0] # uniform temperature + check_solid_grad_t = solid_grad_temperature[0] # uniform temperature + + assert actx.to_numpy(op.norm(dcoll, check_fluid_grad_t, np.inf)) < 1e-6 + assert actx.to_numpy(op.norm(dcoll, check_solid_grad_t, np.inf)) < 1e-6 + + fluid_all_boundaries, solid_all_boundaries = \ + add_multiphysics_interface_bdries( + dcoll, dd_vol_fluid, dd_vol_solid, + gas_model_fluid, gas_model_solid, + fluid_state, solid_state, + fluid_grad_cv, solid_grad_cv, + fluid_grad_temperature, solid_grad_temperature, + fluid_boundaries, solid_boundaries, + interface_noslip=True, interface_radiation=True, + wall_emissivity=1.0, sigma=5.67e-8, ambient_temperature=300.0) + + fluid_rhs = ns_operator( + dcoll, gas_model_fluid, fluid_state, fluid_all_boundaries, + quadrature_tag=quadrature_tag, dd=dd_vol_fluid, + grad_cv=fluid_grad_cv, grad_t=fluid_grad_temperature, + inviscid_terms_on=False) + + solid_rhs = ns_operator( + dcoll, gas_model_solid, solid_state, solid_all_boundaries, + quadrature_tag=quadrature_tag, dd=dd_vol_solid, + grad_cv=solid_grad_cv, grad_t=solid_grad_temperature, + inviscid_terms_on=False) + + if visualize: + from grudge.shortcuts import make_visualizer + viz_fluid = make_visualizer(dcoll, order+3, volume_dd=dd_vol_fluid) + viz_solid = make_visualizer(dcoll, order+3, volume_dd=dd_vol_solid) + if use_overintegration: + viz_suffix = f"over_{order}" + else: + viz_suffix = f"{order}" + + viz_fluid.write_vtk_file( + f"multiphysics_coupled_species_{viz_suffix}_fluid.vtu", [ + ("rhoE", fluid_state.cv.energy), + ("rho", fluid_state.cv.mass), + ("Y0", fluid_state.cv.species_mass_fractions[0]), + ("Y1", fluid_state.cv.species_mass_fractions[1]), + ("Y2", fluid_state.cv.species_mass_fractions[2]), + ("Dij", fluid_state.tv.species_diffusivity[0]), + ("dv", fluid_state.dv), + ("grad_rho", fluid_grad_cv.mass), + ("grad_Y0", fluid_grad_y[0]), + ("grad_Y1", fluid_grad_y[1]), + ("grad_Y2", fluid_grad_y[2]), + ("rhs", fluid_rhs), + ], overwrite=True) + viz_solid.write_vtk_file( + f"multiphysics_coupled_species_{viz_suffix}_solid.vtu", [ + ("eps_rhoE", solid_state.cv.energy), + ("eps_rho", solid_state.cv.mass), + ("rho", solid_state.cv.mass/solid_state.wv.void_fraction), + ("Y0", solid_state.cv.species_mass_fractions[0]), + ("Y1", solid_state.cv.species_mass_fractions[1]), + ("Y2", solid_state.cv.species_mass_fractions[2]), + ("Dij", solid_state.tv.species_diffusivity[0]), + ("eps", solid_state.wv.void_fraction), + ("dv", solid_state.dv), + ("grad_rho", solid_grad_cv.mass), + ("grad_Y0", solid_grad_y[0]), + ("grad_Y1", solid_grad_y[1]), + ("grad_Y2", solid_grad_y[2]), + ("rhs", solid_rhs), + ], overwrite=True) + + # Check that steady-state solution has 0 RHS + + for i in range(0, nspecies): + assert \ + actx.to_numpy(op.norm(dcoll, fluid_rhs.species_mass[i], np.inf)) < 1e-10 + assert \ + actx.to_numpy(op.norm(dcoll, solid_rhs.species_mass[i], np.inf)) < 1e-10 + +# assert actx.to_numpy(op.norm(dcoll, fluid_rhs.energy, np.inf)) < 1e-3 +# assert actx.to_numpy(op.norm(dcoll, solid_rhs.energy, np.inf)) < 1e-3 + + fluid_energy_rhs = fluid_rhs.energy/fluid_state.cv.energy + solid_energy_rhs = solid_rhs.energy/solid_state.cv.energy + assert actx.to_numpy(op.norm(dcoll, fluid_energy_rhs, np.inf)) < 1e-9 + assert actx.to_numpy(op.norm(dcoll, solid_energy_rhs, np.inf)) < 1e-9 + + if __name__ == "__main__": import sys if len(sys.argv) > 1: diff --git a/test/test_wallmodel.py b/test/test_wallmodel.py new file mode 100644 index 000000000..1cf215ae0 --- /dev/null +++ b/test/test_wallmodel.py @@ -0,0 +1,340 @@ +__copyright__ = """Copyright (C) 2023 University of Illinois Board of Trustees""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import pytest +import cantera +import numpy as np +from pytools.obj_array import make_obj_array +import grudge.op as op +from mirgecom.simutil import get_box_mesh +from grudge.dof_desc import DOFDesc, VolumeDomainTag, DISCR_TAG_BASE, DISCR_TAG_QUAD +from mirgecom.discretization import create_discretization_collection +from mirgecom.eos import PyrometheusMixture +from mirgecom.transport import SimpleTransport +from mirgecom.fluid import make_conserved +from mirgecom.gas_model import GasModel, make_fluid_state +from mirgecom.boundary import AdiabaticNoslipWallBoundary +from mirgecom.navierstokes import ( + grad_t_operator, grad_cv_operator, ns_operator +) +from mirgecom.multiphysics.multiphysics_coupled_fluid_wall import ( + add_interface_boundaries_no_grad as add_multiphysics_interface_bdries_no_grad, + add_interface_boundaries as add_multiphysics_interface_bdries +) +from meshmode.array_context import ( # noqa + pytest_generate_tests_for_pyopencl_array_context + as pytest_generate_tests) +from mirgecom.wall_model import PorousFlowModel, PorousWallTransport +from mirgecom.materials.initializer import PorousWallInitializer +from mirgecom.mechanisms import get_mechanism_input +from mirgecom.thermochemistry import get_pyrometheus_wrapper_class_from_cantera + + +import logging +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize("order", [2, 3, 5]) +@pytest.mark.parametrize("use_overintegration", [False]) +@pytest.mark.parametrize("my_material", ["composite"]) +def test_wallmodel( + actx_factory, order, use_overintegration, my_material, visualize=True): + """Check the wall degradation model.""" + actx = actx_factory() + + dim = 2 + + nelems = 20 + global_mesh = get_box_mesh(2, -0.1, 0.1, nelems) + + mgrp, = global_mesh.groups + x = global_mesh.vertices[0, mgrp.vertex_indices] + x_elem_avg = np.sum(x, axis=1)/x.shape[0] + volume_to_elements = { + "Fluid": np.where(x_elem_avg > 0)[0], + "Solid": np.where(x_elem_avg <= 0)[0]} + + from meshmode.mesh.processing import partition_mesh + volume_meshes = partition_mesh(global_mesh, volume_to_elements) + + dcoll = create_discretization_collection( + actx, volume_meshes, order=order, quadrature_order=2*order+1) + quadrature_tag = DISCR_TAG_QUAD if use_overintegration else None + + dd_vol_fluid = DOFDesc(VolumeDomainTag("Fluid"), DISCR_TAG_BASE) + dd_vol_solid = DOFDesc(VolumeDomainTag("Solid"), DISCR_TAG_BASE) + + fluid_nodes = actx.thaw(dcoll.nodes(dd=dd_vol_fluid)) + solid_nodes = actx.thaw(dcoll.nodes(dd=dd_vol_solid)) + + fluid_zeros = actx.np.zeros_like(fluid_nodes[0]) + solid_zeros = actx.np.zeros_like(solid_nodes[0]) + + from grudge.discretization import filter_part_boundaries + from grudge.reductions import integral + solid_dd_list = filter_part_boundaries(dcoll, volume_dd=dd_vol_solid, + neighbor_volume_dd=dd_vol_fluid) + solid_normal = actx.thaw(dcoll.normal(solid_dd_list[0])) + + fluid_dd_list = filter_part_boundaries(dcoll, volume_dd=dd_vol_fluid, + neighbor_volume_dd=dd_vol_solid) + fluid_normal = actx.thaw(dcoll.normal(fluid_dd_list[0])) + + interface_zeros = actx.np.zeros_like(solid_normal[0]) + + integral_volume = integral(dcoll, dd_vol_solid, fluid_zeros + 1.0) + integral_surface = integral(dcoll, solid_dd_list[0], interface_zeros + 1.0) + + print(integral_volume) + print(integral_surface) + + # {{{ Pyrometheus initialization + + mech_input = get_mechanism_input("uiuc_8sp_phenol") + cantera_soln = cantera.Solution(name="gas", yaml=mech_input) + pyro_obj = get_pyrometheus_wrapper_class_from_cantera( + cantera_soln, temperature_niter=3)(actx.np) + + nspecies = pyro_obj.num_species + + x = make_obj_array([0.0 for i in range(nspecies)]) + x[cantera_soln.species_index("O2")] = 0.21 + x[cantera_soln.species_index("N2")] = 0.79 + + cantera_soln.TPX = 900.0, 101325.0, x + _, _, y = cantera_soln.TDY + + eos = PyrometheusMixture(pyro_obj, temperature_guess=900.0) + + # }}} + + # {{{ Initialize wall model + + if my_material == "fiber": + import mirgecom.materials.carbon_fiber as material_sample + material = material_sample.FiberEOS(char_mass=0., virgin_mass=160.) + decomposition = material_sample.Y3_Oxidation_Model(wall_material=material) + solid_density = 0.1*1600.0 + solid_zeros + + if my_material == "composite": + import mirgecom.materials.tacot as material_sample + material = material_sample.TacotEOS(char_mass=220., virgin_mass=280.) + decomposition = material_sample.Pyrolysis() + solid_density = np.empty((3,), dtype=object) + solid_density[0] = 30.0 + solid_zeros + solid_density[1] = 90.0 + solid_zeros + solid_density[2] = 160. + solid_zeros + + # }}} + + # {{{ Gas model + + fluid_transport = SimpleTransport(viscosity=0.0, thermal_conductivity=1.0, + species_diffusivity=np.zeros(nspecies,)) + gas_model_fluid = GasModel(eos=eos, transport=fluid_transport) + + base_transport = SimpleTransport(viscosity=0.0, thermal_conductivity=0.0, + species_diffusivity=np.zeros(nspecies,)) + solid_transport = PorousWallTransport(base_transport=base_transport) + gas_model_solid = PorousFlowModel(eos=eos, transport=solid_transport, + wall_eos=material) + + # }}} + + # {{{ Fluid cv + + # here, fluid temperature is given by a parabolic function: + a = -183708.0 + b = 36741.6 + c = 900.0 + temperature = a*fluid_nodes[0]**2 + b*fluid_nodes[0] + c + species_mass_fractions = y + fluid_zeros + mass = eos.get_density(pressure=101325.0, temperature=temperature, + species_mass_fractions=species_mass_fractions) + energy = mass*eos.get_internal_energy(temperature, species_mass_fractions) + momentum = make_obj_array([0.0*fluid_nodes[0], 0.0*fluid_nodes[0]]) + species_mass = mass*species_mass_fractions + + cv = make_conserved(dim=dim, mass=mass, energy=energy, momentum=momentum, + species_mass=species_mass) + fluid_state = make_fluid_state(cv=cv, gas_model=gas_model_fluid, + temperature_seed=temperature) + + # }}} + + # {{{ Solid cv + + temperature = 900.0 + solid_zeros + species_mass_fractions = y + solid_zeros + mass = eos.get_density(pressure=101325.0, temperature=temperature, + species_mass_fractions=species_mass_fractions) + sample_init = PorousWallInitializer(density=mass, temperature=temperature, + species=species_mass_fractions, material_densities=solid_density) + + wv = sample_init(2, solid_nodes, gas_model_solid) + solid_state = make_fluid_state(cv=wv, gas_model=gas_model_solid, + material_densities=solid_density, temperature_seed=900.0) + + # }}} + + # {{{ Setup boundaries and interface + + fluid_boundaries = { + dd_vol_fluid.trace("-2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_fluid.trace("+2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_fluid.trace("+1").domain_tag: AdiabaticNoslipWallBoundary(), + } + + solid_boundaries = { + dd_vol_solid.trace("-2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_solid.trace("+2").domain_tag: AdiabaticNoslipWallBoundary(), + dd_vol_solid.trace("-1").domain_tag: AdiabaticNoslipWallBoundary(), + } + + fluid_all_boundaries_no_grad, solid_all_boundaries_no_grad = \ + add_multiphysics_interface_bdries_no_grad( + dcoll, dd_vol_fluid, dd_vol_solid, gas_model_fluid, gas_model_solid, + fluid_state, solid_state, fluid_boundaries, solid_boundaries, + interface_noslip=True, interface_radiation=True) + + fluid_grad_cv = grad_cv_operator( + dcoll, gas_model_fluid, fluid_all_boundaries_no_grad, fluid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_fluid) + + fluid_grad_t = grad_t_operator( + dcoll, gas_model_fluid, fluid_all_boundaries_no_grad, fluid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_fluid) + + solid_grad_cv = grad_cv_operator( + dcoll, gas_model_solid, solid_all_boundaries_no_grad, solid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_solid) + + solid_grad_t = grad_t_operator( + dcoll, gas_model_solid, solid_all_boundaries_no_grad, solid_state, + quadrature_tag=quadrature_tag, dd=dd_vol_solid) + + tol = 1e-8 + + check_fluid_grad_t = fluid_grad_t[0] - (2.0*a*fluid_nodes[0] + b) + assert actx.to_numpy( + op.norm(dcoll, check_fluid_grad_t/fluid_state.temperature, np.inf)) < tol + assert actx.to_numpy( + op.norm(dcoll, fluid_grad_t[1]/fluid_state.temperature, np.inf)) < tol + assert actx.to_numpy( + op.norm(dcoll, solid_grad_t[0]/solid_state.temperature, np.inf)) < tol + assert actx.to_numpy( + op.norm(dcoll, solid_grad_t[1]/solid_state.temperature, np.inf)) < tol + + # }}} + + fluid_all_boundaries, solid_all_boundaries = \ + add_multiphysics_interface_bdries( + dcoll, dd_vol_fluid, dd_vol_solid, gas_model_fluid, gas_model_solid, + fluid_state, solid_state, fluid_grad_cv, solid_grad_cv, + fluid_grad_t, solid_grad_t, fluid_boundaries, solid_boundaries, + interface_noslip=True, interface_radiation=True, + wall_emissivity=1.0, sigma=5.67e-8, + ambient_temperature=300.0) + + fluid_rhs = ns_operator( + dcoll, gas_model_fluid, fluid_state, fluid_all_boundaries, + quadrature_tag=quadrature_tag, dd=dd_vol_fluid, + grad_cv=fluid_grad_cv, grad_t=fluid_grad_t, + inviscid_terms_on=False) + + solid_rhs = ns_operator( + dcoll, gas_model_solid, solid_state, solid_all_boundaries, + quadrature_tag=quadrature_tag, dd=dd_vol_solid, + grad_cv=solid_grad_cv, grad_t=solid_grad_t, + inviscid_terms_on=False) + + def blowing_velocity(cv, source): + + # volume integral of the source terms + integral_volume_source = integral(dcoll, dd_vol_solid, source) + + # restrict to coupled surface + surface_density = op.project(dcoll, dd_vol_solid, solid_dd_list[0], cv.mass) + + # surface integral of the density + integral_surface_density = integral(dcoll, solid_dd_list[0], surface_density) + + # boundary velocity + bndry_velocity = \ + integral_volume_source/integral_surface_density + interface_zeros + + return bndry_velocity, fluid_normal + + # ~~~ + if my_material == "composite": + solid_mass_rhs = decomposition.get_source_terms( + temperature=solid_state.temperature, + chi=solid_state.wv.material_densities) + + sample_source_gas = -sum(solid_mass_rhs) + + assert actx.to_numpy( + op.norm(dcoll, solid_mass_rhs[0] + 26.76761185399653, np.inf)) < tol + assert actx.to_numpy( + op.norm(dcoll, solid_mass_rhs[1] + 2.03565420370596, np.inf)) < tol + assert actx.to_numpy( + op.norm(dcoll, sample_source_gas - 28.80326605770249, np.inf)) < tol + + ref_bnd_velocity = 9.216295153230408 + bnd_velocity, normal = blowing_velocity(solid_state.cv, -sample_source_gas) + normal_velocity = bnd_velocity*normal[0] + assert actx.to_numpy( + op.norm(dcoll, normal_velocity - ref_bnd_velocity, np.inf)) < 1e-6 + + if visualize: + from grudge.shortcuts import make_visualizer + viz_fluid = make_visualizer(dcoll, 2*order+1, volume_dd=dd_vol_fluid) + viz_solid = make_visualizer(dcoll, 2*order+1, volume_dd=dd_vol_solid) + if use_overintegration: + viz_suffix = f"over_{order}" + else: + viz_suffix = f"{order}" + + i_O2 = cantera_soln.species_index("O2") # noqa N806 + i_N2 = cantera_soln.species_index("N2") # noqa N806 + + viz_fluid.write_vtk_file( + f"multiphysics_coupled_species_{viz_suffix}_fluid.vtu", [ + ("rho", fluid_state.cv.mass), + ("rhoE", fluid_state.cv.energy), + ("Y_O2", fluid_state.cv.species_mass_fractions[i_O2]), + ("Y_N2", fluid_state.cv.species_mass_fractions[i_N2]), + ("dv", fluid_state.dv), + ("grad_T", fluid_grad_t), + ("rhs", fluid_rhs), + ], overwrite=True) + viz_solid.write_vtk_file( + f"multiphysics_coupled_species_{viz_suffix}_solid.vtu", [ + ("eps_rho", solid_state.cv.mass), + ("eps_rhoE", solid_state.cv.energy), + ("rho", solid_state.cv.mass/solid_state.wv.void_fraction), + ("Y_O2", solid_state.cv.species_mass_fractions[i_O2]), + ("Y_N2", solid_state.cv.species_mass_fractions[i_N2]), + ("dv", solid_state.dv), + ("grad_T", solid_grad_t), + ("rhs", solid_rhs), + ], overwrite=True) From 5d93e2a94e93fc88c763bb1c22f6973cb5e9ccc5 Mon Sep 17 00:00:00 2001 From: Tulio Date: Tue, 10 Oct 2023 12:17:28 -0500 Subject: [PATCH 2/2] Remove lengthscale_minus --- mirgecom/multiphysics/multiphysics_coupled_fluid_wall.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mirgecom/multiphysics/multiphysics_coupled_fluid_wall.py b/mirgecom/multiphysics/multiphysics_coupled_fluid_wall.py index 6a510b6cb..74b48f733 100644 --- a/mirgecom/multiphysics/multiphysics_coupled_fluid_wall.py +++ b/mirgecom/multiphysics/multiphysics_coupled_fluid_wall.py @@ -1007,8 +1007,8 @@ def add_interface_boundaries( boundary_velocity=boundary_velocity, grad_cv_plus=grad_cv_tpair.ext, grad_t_plus=grad_temperature_tpair.ext, - lengthscales_minus=op.project(dcoll, fluid_dd, state_tpair.dd, - fluid_lengthscales), + # lengthscales_minus=op.project(dcoll, fluid_dd, state_tpair.dd, + # fluid_lengthscales), flux_penalty_amount=wall_penalty_amount) for state_tpair, grad_cv_tpair, grad_temperature_tpair in zip( state_inter_volume_trace_pairs[wall_dd, fluid_dd], @@ -1025,8 +1025,8 @@ def add_interface_boundaries( u_ambient=ambient_temperature, grad_cv_plus=grad_cv_tpair.ext, grad_t_plus=grad_temperature_tpair.ext, - lengthscales_minus=op.project(dcoll, wall_dd, state_tpair.dd, - wall_lengthscales), + # lengthscales_minus=op.project(dcoll, wall_dd, state_tpair.dd, + # wall_lengthscales), flux_penalty_amount=wall_penalty_amount) for state_tpair, grad_cv_tpair, grad_temperature_tpair in zip( state_inter_volume_trace_pairs[fluid_dd, wall_dd],