Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add coherent artifact plot functionality #123

Merged
merged 5 commits into from
Jan 1, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyglotaran_extras/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Pyglotaran extension package with convenience functionality such as plotting."""
from pyglotaran_extras.io.load_data import load_data
from pyglotaran_extras.io.setup_case_study import setup_case_study
from pyglotaran_extras.plotting.plot_coherent_artifact import plot_coherent_artifact
from pyglotaran_extras.plotting.plot_data import plot_data_overview
from pyglotaran_extras.plotting.plot_guidance import plot_guidance
from pyglotaran_extras.plotting.plot_irf_dispersion_center import plot_irf_dispersion_center
Expand All @@ -12,6 +13,7 @@
__all__ = [
"load_data",
"setup_case_study",
"plot_coherent_artifact",
"plot_data_overview",
"plot_overview",
"plot_simple_overview",
Expand Down
123 changes: 123 additions & 0 deletions pyglotaran_extras/plotting/plot_coherent_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Module containing coherent artifact plot functionality."""
from __future__ import annotations

from typing import TYPE_CHECKING
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from cycler import Cycler

from pyglotaran_extras.plotting.utils import abs_max
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none

if TYPE_CHECKING:
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes


def plot_coherent_artifact(
res: xr.Dataset,
*,
time_range: tuple[float, float] | None = None,
spectral: float = 0,
normalize: bool = False,
figsize: tuple[int, int] = (18, 7),
show_zero_line: bool = True,
cycler: Cycler | None = None,
title: str | None = "Coherent Artifact",
) -> tuple[Figure, Axes]:
"""Plot coherent artifact as IRF derivative components over time and IRFAS over spectral dim.

The IRFAS are the IRF (Instrument Response Function) Associated Spectra.

Parameters
----------
res: xr.Dataset
Result dataset from a pyglotaran optimization.
time_range: tuple[float, float] | None
Start and end time for the IRF derivative plot. Defaults to None which means that
the full time range is used
s-weigand marked this conversation as resolved.
Show resolved Hide resolved
spectral: float
Value of the spectral axis that should be used to select the data for the IRF derivative
plot this value does not need to be an exact existing value and only has a effect if the
s-weigand marked this conversation as resolved.
Show resolved Hide resolved
IRF has dispersion. Defaults to 0 which means that the IRF derivative plot at lowest
spectral value will be shown.
normalize: bool
Whether or not to normalize the IRF derivative plot.If the IRF derivative is normalized,
s-weigand marked this conversation as resolved.
Show resolved Hide resolved
the IRFAS is scaled with the reciprocal of the normalization to compensate for this.
Defaults to False.
figsize: tuple[int, int]
Size of the figure (N, M) in inches. Defaults to (18, 7)
s-weigand marked this conversation as resolved.
Show resolved Hide resolved
show_zero_line: bool
Whether or not to add a horizontal line at zero. Defaults to True.
cycler: Cycler | None
Plot style cycler to use. Defaults to None, which means that the matplotlib default style
will be used.
title: str | None
Title of the figure. Defaults to "Coherent Artifact".

Returns
-------
tuple[Figure, Axes]
Figure object which contains the plots and the Axes.
"""
fig, axes = plt.subplots(1, 2, figsize=figsize)
add_cycler_if_not_none(axes, cycler)

if (
"coherent_artifact_response" not in res
or "coherent_artifact_associated_spectra" not in res
):
warn(
UserWarning(f"Dataset does not contain coherent artifact data:\n {res.data_vars}"),
stacklevel=2,
)
return fig, axes

irf_max = abs_max(res.coherent_artifact_response, result_dims=("coherent_artifact_order"))
irfas_max = abs_max(
res.coherent_artifact_associated_spectra, result_dims=("coherent_artifact_order")
)
scales = np.sqrt(irfas_max * irf_max)
norm_factor = 1
irf_y_label = "amplitude"
irfas_y_label = "ΔA"

if normalize is True:
norm_factor = scales.max()
irf_y_label = f"normalized {irf_y_label}"

plot_slice_irf = (
res.coherent_artifact_response.sel(spectral=spectral, method="nearest")
/ irf_max
* scales
/ norm_factor
)
irf_sel_kwargs = (
{"time": slice(time_range[0], time_range[1])} if time_range is not None else {}
)
plot_slice_irf.sel(**irf_sel_kwargs).plot.line(x="time", ax=axes[0])
axes[0].set_title("IRF Derivatives")
axes[0].set_ylabel(f"{irf_y_label} (a.u.)")

plot_slice_irfas = res.coherent_artifact_associated_spectra / irfas_max * scales * norm_factor
plot_slice_irfas.plot.line(x="spectral", ax=axes[1])
axes[1].get_legend().remove()
axes[1].set_title("IRFAS")
axes[1].set_ylabel(f"{irfas_y_label} (mOD)")

if show_zero_line is True:
axes[0].axhline(0, color="k", linewidth=1)
axes[1].axhline(0, color="k", linewidth=1)

#
if res.coords["coherent_artifact_order"][0] == 1:
axes[0].legend(
[f"{int(ax_label)-1}" for ax_label in res.coords["coherent_artifact_order"]],
title="coherent_artifact_order",
)
if title:
fig.suptitle(title, fontsize=16)
return fig, axes
39 changes: 34 additions & 5 deletions pyglotaran_extras/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Iterable
from warnings import warn

import numpy as np
Expand All @@ -10,7 +11,7 @@
from pyglotaran_extras.io.utils import result_dataset_mapping

if TYPE_CHECKING:
from typing import Iterable
from typing import Hashable

from cycler import Cycler
from matplotlib.axis import Axis
Expand Down Expand Up @@ -360,7 +361,7 @@ def get_shifted_traces(
return shift_time_axis_by_irf_location(traces, irf_location)


def add_cycler_if_not_none(axis: Axis, cycler: Cycler | None) -> None:
def add_cycler_if_not_none(axis: Axis | Axes, cycler: Cycler | None) -> None:
"""Add cycler to and axis if it is not None.

This is a convenience function that allow to opt out of using
Expand All @@ -370,10 +371,38 @@ def add_cycler_if_not_none(axis: Axis, cycler: Cycler | None) -> None:

Parameters
----------
axis: Axis
Axis to plot the data and fits on.
axis: Axis | Axes
Axis to plot on.
cycler: Cycler | None
Plot style cycler to use.
"""
if cycler is not None:
axis.set_prop_cycle(cycler)
# We can't use `Axis` in isinstance so we check for the np.ndarray attribute of `Axes`
if hasattr(axis, "flatten") is False:
axis = np.array([axis])
for ax in axis.flatten():
ax.set_prop_cycle(cycler)


def abs_max(
data: xr.DataArray, *, result_dims: Hashable | Iterable[Hashable] = ()
) -> xr.DataArray:
"""Calculate the absolute maximum values of ``data`` along all dims except ``result_dims``.

Parameters
----------
data: xr.DataArray
Data for which the absolute maximum should be calculated.
result_dims: Hashable | Iterable[Hashable]
Dimensions of ``data`` which should be preserved and part of the resulting DataArray.
Defaults to () which results in the absolute maximum of all values
s-weigand marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
xr.DataArray
Absolute maximum values of ``data`` with dimensions ``result_dims``.
"""
if not isinstance(result_dims, Iterable):
result_dims = (result_dims,)
reduce_dims = (dim for dim in data.dims if dim not in result_dims)
return np.abs(data).max(dim=reduce_dims)
46 changes: 43 additions & 3 deletions tests/plotting/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
"""Tests for pyglotaran_extras.plotting.utils"""
from __future__ import annotations

from typing import Hashable
from typing import Iterable

import matplotlib
import matplotlib.pyplot as plt
import pytest
import xarray as xr
from cycler import Cycler
from cycler import cycle

from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import abs_max
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none

matplotlib.use("Agg")
Expand All @@ -18,10 +23,45 @@
"cycler,expected_cycler",
((None, DEFAULT_CYCLER()), (PlotStyle().cycler, PlotStyle().cycler())),
)
def test_add_cycler_if_not_none(cycler: Cycler | None, expected_cycler: cycle):
"""Default cycler inf None and cycler otherwise"""
def test_add_cycler_if_not_none_single_axis(cycler: Cycler | None, expected_cycler: cycle):
"""Default cycler if None and cycler otherwise on a single axis"""
ax = plt.subplot()
add_cycler_if_not_none(ax, cycler)

for _ in range(10):
assert next(ax._get_lines.prop_cycler) == next(expected_cycler)
expected = next(expected_cycler)
assert next(ax._get_lines.prop_cycler) == expected


@pytest.mark.parametrize(
"cycler,expected_cycler",
((None, DEFAULT_CYCLER()), (PlotStyle().cycler, PlotStyle().cycler())),
)
def test_add_cycler_if_not_none_multiple_axes(cycler: Cycler | None, expected_cycler: cycle):
"""Default cycler if None and cycler otherwise on all axes"""
_, axes = plt.subplots(1, 2)
add_cycler_if_not_none(axes, cycler)

for _ in range(10):
expected = next(expected_cycler)
assert next(axes[0]._get_lines.prop_cycler) == expected
assert next(axes[1]._get_lines.prop_cycler) == expected


@pytest.mark.parametrize(
"result_dims, expected",
(
((), xr.DataArray(40)),
("dim1", xr.DataArray([20, 40], coords={"dim1": [1, 2]})),
("dim2", xr.DataArray([30, 40], coords={"dim2": [3, 4]})),
(("dim1",), xr.DataArray([20, 40], coords={"dim1": [1, 2]})),
(
("dim1", "dim2"),
xr.DataArray([[10, 20], [30, 40]], coords={"dim1": [1, 2], "dim2": [3, 4]}),
),
),
)
def test_abs_max(result_dims: Hashable | Iterable[Hashable], expected: xr.DataArray):
"""Result values are positive and dimensions are preserved if result_dims is not empty."""
data = xr.DataArray([[-10, 20], [-30, 40]], coords={"dim1": [1, 2], "dim2": [3, 4]})
assert abs_max(data, result_dims=result_dims).equals(expected)