Skip to content

Commit

Permalink
added setting class for time domain simulations
Browse files Browse the repository at this point in the history
  • Loading branch information
astanziola committed Nov 24, 2023
1 parent e1cb452 commit cda7cb9
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 75 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

### Changed
- `Medium` objects are now `JaxDFModule`s, which is based on `equinox` modules. It is also a [parametric module for dispatching operators](https://beartype.github.io/plum/parametric.html), meaning that there's a type difference betwee `Medium[FourierSeries]` and `Medium[FiniteDifferences]`, for example.
- The settings of time domain acoustic simulations are now set using a `TimeWavePropagationSettings`. This also includes an attribute to explicity set the reference sound speed.

### Added
- Added a logger in `jwave.logger`
Expand Down
163 changes: 106 additions & 57 deletions jwave/acoustics/time_varying.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@
# You should have received a copy of the GNU Lesser General Public
# License along with j-Wave. If not, see <https://www.gnu.org/licenses/>.

from typing import Dict, Tuple, TypeVar, Union
from typing import Callable, Dict, Tuple, TypeVar, Union

import equinox as eqx
import numpy as np
from jax import checkpoint as jax_checkpoint
from jax import numpy as jnp
from jax.lax import scan
from jaxdf import Field, operator
from jaxdf.discretization import FourierSeries, Linear, OnGrid
from jaxdf.operators import (diag_jacobian, functional, shift_operator,
sum_over_dims)
from jaxdf.mods import JaxDFModule
from jaxdf.operators import diag_jacobian, shift_operator, sum_over_dims

from jwave.acoustics.spectral import kspace_op
from jwave.geometry import Medium, Sources, TimeAxis
Expand All @@ -34,6 +35,55 @@
Any = TypeVar("Any")


class TimeWavePropagationSettings(JaxDFModule):
"""
TimeWavePropagationSettings configures the settings for
time domain wave solvers. This class serves as a container
for settings that influence how wave propagation is
simulated.
!!! example
```python
>>> settings = TimeWavePropagationSettings(
... c_ref = lambda m: m.min_sound_speed)
>>> print(settings.checkpoint)
True
```
"""

c_ref: Callable = eqx.field(static=True)
checkpoint: bool = eqx.field(static=True)
smooth_initial: bool = eqx.field(static=True)

def __init__(
self,
c_ref: Callable = lambda m: m.max_sound_speed,
checkpoint: bool = True,
smooth_initial: bool = True,
):
"""
Initializes a new instance of the TimeWavePropagationSettings class.
Args:
c_ref (Callable, static): A callable that determines
the reference speed of the wave solver. This is a
expected to be a function that takes the `medium`
variable and returns the reference sound speed
checkpoint (bool, static): Flag indicating whether to
use checkpointing to save memory during backpropagation.
Defaults to True.
smooth_initial (bool, static): Flag to determine
whether to smooth initial pressure and velocity
fields. Defaults to True.
"""
self.c_ref = c_ref
self.checkpoint = checkpoint
self.smooth_initial = smooth_initial


default_time_wave_prop_settings = TimeWavePropagationSettings()


def _shift_rho(rho0, direction, dx):
if isinstance(rho0, OnGrid):
rho0_params = rho0.params[..., 0]
Expand Down Expand Up @@ -250,40 +300,6 @@ def pressure_from_density(rho: Field, medium: Medium, *, params=None) -> Field:
return (c0**2) * rho_sum


def ongrid_wave_prop_params(
medium: OnGrid,
time_axis: TimeAxis,
*args,
**kwargs,
):
# Check which elements of medium are a field
x = [
x for x in [medium.sound_speed, medium.density, medium.attenuation]
if isinstance(x, Field)
][0]

dt = time_axis.dt
c_ref = functional(medium.sound_speed)(jnp.amax)

# Making PML on grid for rho and u
def make_pml(staggering=0.0):
pml_grid = td_pml_on_grid(medium,
dt,
c0=c_ref,
dx=medium.domain.dx[0],
coord_shift=staggering)
pml = x.replace_params(pml_grid)
return pml

pml_rho = make_pml()
pml_u = make_pml(staggering=0.5)

return {
"pml_rho": pml_rho,
"pml_u": pml_u,
}


@operator
def wave_propagation_symplectic_step(
p: Linear,
Expand Down Expand Up @@ -322,18 +338,54 @@ def wave_propagation_symplectic_step(
return [p, u, rho]


@operator
def ongrid_wave_prop_params(
medium: OnGrid,
time_axis: TimeAxis,
*,
settings: TimeWavePropagationSettings,
**kwargs,
):
# Check which elements of medium are a field
x = [
x for x in [medium.sound_speed, medium.density, medium.attenuation]
if isinstance(x, Field)
][0]

dt = time_axis.dt

# Use settings to determine reference sound speed
c_ref = settings.c_ref(medium)

# Making PML on grid for rho and u
def make_pml(staggering=0.0):
pml_grid = td_pml_on_grid(medium,
dt,
c0=c_ref,
dx=medium.domain.dx[0],
coord_shift=staggering)
pml = x.replace_params(pml_grid)
return pml

pml_rho = make_pml()
pml_u = make_pml(staggering=0.5)

return {
"pml_rho": pml_rho,
"pml_u": pml_u,
"c_ref": c_ref,
}


@operator(init_params=ongrid_wave_prop_params)
def simulate_wave_propagation(
medium: Medium[OnGrid],
time_axis: TimeAxis,
*,
settings: TimeWavePropagationSettings = default_time_wave_prop_settings,
sources=None,
sensors=None,
u0=None,
p0=None,
checkpoint: bool = True,
max_unroll_checkpoint: int = 10,
smooth_initial=True,
params=None,
):
r"""Simulate the wave propagation operator.
Expand Down Expand Up @@ -371,12 +423,9 @@ def simulate_wave_propagation(
# Setup parameters
output_steps = jnp.arange(0, time_axis.Nt, 1)
dt = time_axis.dt
c_ref = functional(medium.sound_speed)(jnp.amax)

if params == None:
params = ongrid_wave_prop_params(medium, time_axis)

# Get parameters
c_ref = params["c_ref"]
pml_rho = params["pml_rho"]
pml_u = params["pml_u"]

Expand All @@ -394,7 +443,7 @@ def simulate_wave_propagation(
if p0 is None:
p0 = pml_rho.replace_params(jnp.zeros(shape_one))
else:
if smooth_initial:
if settings.smooth_initial:
p0_params = p0.params[..., 0]
p0_params = jnp.expand_dims(smooth(p0_params), -1)
p0 = p0.replace_params(p0_params)
Expand Down Expand Up @@ -433,7 +482,7 @@ def scan_fun(fields, n):
p = pressure_from_density(rho, medium)
return [p, u, rho], sensors(p, u, rho)

if checkpoint:
if settings.checkpoint:
scan_fun = jax_checkpoint(scan_fun)

logger.debug("Starting simulation using generic OnGrid code")
Expand All @@ -445,11 +494,14 @@ def scan_fun(fields, n):
def fourier_wave_prop_params(
medium: Medium[FourierSeries],
time_axis: TimeAxis,
*args,
*,
settings: TimeWavePropagationSettings,
**kwargs,
):
dt = time_axis.dt
c_ref = functional(medium.sound_speed)(jnp.amax)

# Use settings to determine reference sound speed
c_ref = settings.c_ref(medium)

# Making PML on grid for rho and u
def make_pml(staggering=0.0):
Expand All @@ -471,6 +523,7 @@ def make_pml(staggering=0.0):
"pml_rho": pml_rho,
"pml_u": pml_u,
"fourier": fourier,
"c_ref": c_ref
}


Expand All @@ -479,13 +532,11 @@ def simulate_wave_propagation(
medium: Medium[FourierSeries],
time_axis: TimeAxis,
*,
settings: TimeWavePropagationSettings = default_time_wave_prop_settings,
sources=None,
sensors=None,
u0=None,
p0=None,
checkpoint: bool = True,
max_unroll_checkpoint: int = 10,
smooth_initial=True,
params=None,
):
r"""Simulates the wave propagation operator using the PSTD method. This
Expand Down Expand Up @@ -524,11 +575,9 @@ def simulate_wave_propagation(
# Setup parameters
output_steps = jnp.arange(0, time_axis.Nt, 1)
dt = time_axis.dt
c_ref = functional(medium.sound_speed)(jnp.amax)
if params == None:
params = fourier_wave_prop_params(medium, time_axis)

# Get parameters
c_ref = params["c_ref"]
pml_rho = params["pml_rho"]
pml_u = params["pml_u"]

Expand All @@ -546,7 +595,7 @@ def simulate_wave_propagation(
if p0 is None:
p0 = pml_rho.replace_params(jnp.zeros(shape_one))
else:
if smooth_initial:
if settings.smooth_initial:
p0_params = p0.params[..., 0]
p0_params = jnp.expand_dims(smooth(p0_params), -1)
p0 = p0.replace_params(p0_params)
Expand Down Expand Up @@ -592,7 +641,7 @@ def scan_fun(fields, n):
return [p, u, rho], sensors(p, u, rho)

# Define the scanning function according to the checkpoint type
if checkpoint:
if settings.checkpoint:
scan_fun = jax_checkpoint(scan_fun)

logger.debug("Starting simulation using FourierSeries code")
Expand Down
78 changes: 78 additions & 0 deletions jwave/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,84 @@ def __init_type_parameter__(self, t: type):
f"The type parameter of a Medium object must be a subclass of Field. Got {t}"
)

@property
def max_sound_speed(self):
"""
Calculate and return the maximum sound speed.
This property uses the `sound_speed` method/function and applies the `amax`
function from JAX's numpy (jnp) library to find the maximum sound speed value.
Returns:
The maximum sound speed value.
"""
return functional(self.sound_speed)(jnp.amax)

@property
def min_sound_speed(self):
"""
Calculate and return the minimum sound speed.
This property uses the `sound_speed` method/function and applies the `amin`
function from JAX's numpy (jnp) library to find the minimum sound speed value.
Returns:
The minimum sound speed value.
"""
return functional(self.sound_speed)(jnp.amin)

@property
def max_density(self):
"""
Calculate and return the maximum density.
This property uses the `density` method/function and applies the `amax`
function from JAX's numpy (jnp) library to find the maximum density value.
Returns:
The maximum density value.
"""
return functional(self.density)(jnp.amax)

@property
def min_density(self):
"""
Calculate and return the minimum density.
This property uses the `density` method/function and applies the `amin`
function from JAX's numpy (jnp) library to find the minimum density value.
Returns:
The minimum density value.
"""
return functional(self.density)(jnp.amin)

@property
def max_attenuation(self):
"""
Calculate and return the maximum attenuation.
This property uses the `attenuation` method/function and applies the `amax`
function from JAX's numpy (jnp) library to find the maximum attenuation value.
Returns:
The maximum attenuation value.
"""
return functional(self.attenuation)(jnp.amax)

@property
def min_attenuation(self):
"""
Calculate and return the minimum attenuation.
This property uses the `attenuation` method/function and applies the `amin`
function from JAX's numpy (jnp) library to find the minimum attenuation value.
Returns:
The minimum attenuation value.
"""
return functional(self.attenuation)(jnp.amin)

@classmethod
def __infer_type_parameter__(self, *args, **kwargs):
"""Inter the type parameter from the arguments. Defaults to FourierSeries if
Expand Down
Loading

0 comments on commit cda7cb9

Please sign in to comment.