Skip to content

Commit

Permalink
Api update (#227)
Browse files Browse the repository at this point in the history
* Revert "Release 0.2.0"
This reverts commit 060647c.
* updated public api
* removed unused second outputs from operators
  • Loading branch information
astanziola authored Dec 18, 2023
1 parent 060647c commit 7c2381d
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 21 deletions.
6 changes: 1 addition & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

## [0.2.0] - 2023-12-18
### Fixed
- Fixed arguments error in helmholtz notebook

Expand Down Expand Up @@ -99,8 +97,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Pml for 1D and 3D simulations.
- Plotting functions of `jwave.utils` now work with both `Field`s and arrays.

[Unreleased]: https://github.com/ucl-bug/jwave/compare/0.2.0...master
[0.2.0]: https://github.com/ucl-bug/jwave/compare/0.1.5...0.2.0
[Unreleased]: https://github.com/ucl-bug/jwave/compare/0.1.5...master
[0.1.5]: https://github.com/ucl-bug/jwave/compare/0.1.4...0.1.5
[0.1.4]: https://github.com/ucl-bug/jwave/compare/0.1.3...0.1.4
[0.1.3]: https://github.com/ucl-bug/jwave/compare/0.1.2...0.1.3
Expand All @@ -111,4 +108,3 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
[0.0.3]: https://github.com/ucl-bug/jwave/compare/0.0.2...0.0.3
[0.0.2]: https://github.com/ucl-bug/jwave/compare/0.0.1...0.0.2
[0.0.1]: https://github.com/ucl-bug/jwave/releases/tag/0.0.1

46 changes: 45 additions & 1 deletion jwave/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,53 @@
# License along with j-Wave. If not, see <https://www.gnu.org/licenses/>.

# nopycln: file
from jaxdf.discretization import *
from jaxdf import (
operator,
Continuous,
Domain,
FiniteDifferences,
FourierSeries,
Field,
Linear,
OnGrid
)

from .acoustics import (
angular_spectrum,
born_iteration,
born_series,
db2neper,
helmholtz_solver_verbose,
helmholtz_solver,
helmholtz,
homogeneous_helmholtz_green,
laplacian_with_pml,
mass_conservation_rhs,
momentum_conservation_rhs,
pml,
pressure_from_density,
rayleigh_integral,
scale_source_helmholtz,
scattering_potential,
simulate_wave_propagation,
spectral,
wave_propagation_symplectic_step,
wavevector,
TimeWavePropagationSettings,
)
from .geometry import (
BLISensors,
DistributedTransducer,
Medium,
Sensors,
Sources,
TimeAxis,
TimeHarmonicSource,
)

from jwave import acoustics as ac
from jwave import geometry as geometry
from jwave import logger as logger
from jwave import phantoms as phantoms
from jwave import signal_processing as signal_processing
from jwave import utils as utils
31 changes: 28 additions & 3 deletions jwave/acoustics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,31 @@
# License along with j-Wave. If not, see <https://www.gnu.org/licenses/>.

# nopycln: file
from .operators import *
from .time_harmonic import *
from .time_varying import *
from .conversion import db2neper
from .operators import (
helmholtz,
laplacian_with_pml,
scale_source_helmholtz,
wavevector,
)
from .time_harmonic import (
angular_spectrum,
born_iteration,
born_series,
helmholtz_solver,
helmholtz_solver_verbose,
homogeneous_helmholtz_green,
rayleigh_integral,
scattering_potential
)
from .time_varying import (
mass_conservation_rhs,
momentum_conservation_rhs,
pressure_from_density,
simulate_wave_propagation,
wave_propagation_symplectic_step,
TimeWavePropagationSettings,
)

from . import spectral
from . import pml
12 changes: 6 additions & 6 deletions jwave/acoustics/time_harmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def body_fun(carry):

out_field = _cbs_unnorm_units(out_field, _conversion)

return out_field, None
return out_field


@operator
Expand Down Expand Up @@ -377,7 +377,7 @@ def born_iteration(field: Field,
G = homogeneous_helmholtz_green(V1 + src, k0=k0, epsilon=epsilon)
V2 = scattering_potential(field - G, k_sq, k0=k0, epsilon=epsilon)

return field - (1j / epsilon) * V2, params
return field - (1j / epsilon) * V2


@operator
Expand All @@ -401,7 +401,7 @@ def scattering_potential(field: Field,

k = k_sq - k0**2 - 1j * epsilon
out = field * k
return out, params
return out


@operator
Expand Down Expand Up @@ -430,7 +430,7 @@ def homogeneous_helmholtz_green(field: FourierSeries,
u_fft = jnp.fft.fftn(u)
Gu_fft = g_fourier * u_fft
Gu = jnp.fft.ifftn(Gu_fft)
return field.replace_params(Gu), params
return field.replace_params(Gu)


@operator
Expand Down Expand Up @@ -500,7 +500,7 @@ def direc_exp_term(x, y, z):
# Weights of the Rayleigh integral
weights = jax.vmap(jax.vmap(direc_exp_term, in_axes=(0, 0, 0)),
in_axes=(0, 0, 0))(R[..., 0], R[..., 1], R[..., 2])
return jnp.sum(weights * pressure.on_grid) * area, None
return jnp.sum(weights * pressure.on_grid) * area


@operator
Expand Down Expand Up @@ -560,7 +560,7 @@ def helm_func(u):
)[0]
elif method == "bicgstab":
out = bicgstab(helm_func, source, guess, tol=tol, maxiter=maxiter)[0]
return -1j * omega * out, None
return -1j * omega * out


def helmholtz_solver_verbose(
Expand Down
6 changes: 2 additions & 4 deletions jwave/acoustics/time_varying.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def __init__(
self.smooth_initial = smooth_initial


default_time_wave_prop_settings = TimeWavePropagationSettings()


def _shift_rho(rho0, direction, dx):
if isinstance(rho0, OnGrid):
Expand Down Expand Up @@ -382,7 +380,7 @@ def simulate_wave_propagation(
medium: Medium[OnGrid],
time_axis: TimeAxis,
*,
settings: TimeWavePropagationSettings = default_time_wave_prop_settings,
settings: TimeWavePropagationSettings = TimeWavePropagationSettings(),
sources=None,
sensors=None,
u0=None,
Expand Down Expand Up @@ -533,7 +531,7 @@ def simulate_wave_propagation(
medium: Medium[FourierSeries],
time_axis: TimeAxis,
*,
settings: TimeWavePropagationSettings = default_time_wave_prop_settings,
settings: TimeWavePropagationSettings = TimeWavePropagationSettings(),
sources=None,
sensors=None,
u0=None,
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jwave"
version = "0.2.0"
version = "0.1.5"
description = "Fast and differentiable acoustic simulations in JAX."
authors = [
"Antonio Stanziola <[email protected]>",
Expand Down Expand Up @@ -108,9 +108,12 @@ split_before_logical_operator = true

[tool.pytest.ini_options]
addopts = """\
--doctest-modules \
--doctest-modules\
"""

[tool.pytest_env]
CUDA_VISIBLE_DEVICES = ""

[tool.coverage.report]
exclude_lines = [
'if TYPE_CHECKING:',
Expand Down

0 comments on commit 7c2381d

Please sign in to comment.