Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Element Result Saving #1534

Open
wants to merge 19 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 11 additions & 6 deletions glotaran/builtin/elements/baseline/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np

from glotaran.model.element import Element
from glotaran.model.element import ElementResult

if TYPE_CHECKING:
import xarray as xr
Expand Down Expand Up @@ -34,11 +35,15 @@ def calculate_matrix(
matrix = np.ones((model_axis.size, 1), dtype=np.float64)
return clp_label, matrix

def add_to_result_data(
def create_result(
self,
model: DataModel,
data: xr.Dataset,
as_global: bool = False,
):
if not as_global:
data["baseline"] = data.clp.sel(clp_label=self.clp_label())
global_dimension: str,
model_dimension: str,
amplitudes: xr.Dataset,
concentrations: xr.Dataset,
) -> ElementResult:
return ElementResult(
amplitudes={"baseline": amplitudes.sel(amplitude_label=self.clp_label())},
concentrations={},
)
171 changes: 67 additions & 104 deletions glotaran/builtin/elements/coherent_artifact/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@

import numba as nb
import numpy as np
import xarray as xr

from glotaran.builtin.items.activation import ActivationDataModel
from glotaran.builtin.items.activation import MultiGaussianActivation
from glotaran.builtin.items.activation import add_activation_to_result_data
from glotaran.model.data_model import DataModel # noqa: TCH001
from glotaran.model.element import Element
from glotaran.model.element import ElementResult
from glotaran.model.errors import GlotaranModelError
from glotaran.model.item import ParameterType # noqa: TCH001

if TYPE_CHECKING:
import xarray as xr

from glotaran.model.data_model import DataModel
from glotaran.typing.types import ArrayLike


Expand All @@ -37,117 +38,79 @@ def calculate_matrix( # type:ignore[override]
if not 1 <= self.order <= 3:
raise GlotaranModelError("Coherent artifact order must be between in [1,3]")

activations = [a for a in model.activation if isinstance(a, MultiGaussianActivation)]

matrices = []
activation_indices = []
for i, activation in enumerate(activations):
if self.label not in activation.compartments:
continue
activation_indices.append(i)
parameters = activation.parameters(global_axis)

matrix_shape = (model_axis.size, self.order)
index_dependent = any(isinstance(p, list) for p in parameters)
if index_dependent:
matrix_shape = (global_axis.size, *matrix_shape) # type:ignore[assignment]
matrix = np.zeros(matrix_shape, dtype=np.float64)
if index_dependent:
_calculate_coherent_artifact_matrix(
matrix,
np.array([ps[0].center for ps in parameters]), # type:ignore[index]
np.array(
[self.width or ps[0].width for ps in parameters] # type:ignore[index]
),
global_axis.size,
model_axis,
self.order,
)

else:
_calculate_coherent_artifact_matrix_on_index(
matrix,
parameters[0].center, # type:ignore[union-attr]
self.width or parameters[0].width, # type:ignore[union-attr]
model_axis,
self.order,
)
matrix *= activation.compartments[self.label] # type:ignore[arg-type]
matrices.append(matrix)

if not len(matrices):
activations = [
a
for a in model.activation
if isinstance(a, MultiGaussianActivation) and self.label in a.compartments
]

if not len(activations):
raise GlotaranModelError(
f'No (multi-)gaussian activation for coherent-artifact "{self.label}".'
)
if len(activations) > 1:
raise GlotaranModelError(
f'Coherent artifact "{self.label}" must be associated with exactly one activation.'
)
activation = activations[0]

parameters = activation.parameters(global_axis)

matrix_shape = (model_axis.size, self.order)
index_dependent = any(isinstance(p, list) for p in parameters)
if index_dependent:
matrix_shape = (global_axis.size, *matrix_shape) # type:ignore[assignment]
matrix = np.zeros(matrix_shape, dtype=np.float64)
if index_dependent:
_calculate_coherent_artifact_matrix(
matrix,
np.array([ps[0].center for ps in parameters]), # type:ignore[index]
np.array(
[self.width or ps[0].width for ps in parameters] # type:ignore[index]
),
global_axis.size,
model_axis,
self.order,
)

else:
_calculate_coherent_artifact_matrix_on_index(
matrix,
parameters[0].center, # type:ignore[union-attr]
self.width or parameters[0].width, # type:ignore[union-attr]
model_axis,
self.order,
)
matrix *= activation.compartments[self.label] # type:ignore[arg-type]

clp_axis = []
for i in activation_indices:
clp_axis += [f"{label}_activation_{i}" for label in self.compartments()]
return clp_axis, np.concatenate(matrices, axis=len(matrices[0].shape) - 1)
return self.compartments, matrix

@property
def compartments(self):
return [f"coherent_artifact_{self.label}_order_{i}" for i in range(1, self.order + 1)]

def add_to_result_data( # type:ignore[override]
def create_result(
self,
model: ActivationDataModel,
data: xr.Dataset,
as_global: bool = False,
):
add_activation_to_result_data(model, data)
if "coherent_artifact_order" in data.coords:
return

data_matrix = data.global_matrix if "global_matrix" in data else data.matrix
elements = [m for m in model.elements if isinstance(m, CoherentArtifactElement)]
nr_activations = data.gaussian_activation.size
matrices = []
estimations = []
for coherent_artifact in elements:
artifact_matrices = []
artifact_estimations = []
activation_indices = []
for i in range(nr_activations):
clp_axis = [
label
for label in data.clp_label.data
if label.startswith(f"coherent_artifact_{coherent_artifact.label}")
and label.endswith(f"_activation_{i}")
]
if not len(clp_axis):
continue
activation_indices.append(i)
order = [label.split("_activation_")[0].split("_order")[1] for label in clp_axis]

artifact_matrices.append(
data_matrix.sel(clp_label=clp_axis)
.rename(clp_label="coherent_artifact_order")
.assign_coords({"coherent_artifact_order": order})
)
if "global_matrix" not in data:
artifact_estimations.append(
data.clp.sel(clp_label=clp_axis)
.rename(clp_label="coherent_artifact_order")
.assign_coords({"coherent_artifact_order": order})
)
matrices.append(
xr.concat(artifact_matrices, dim="gaussian_activation").assign_coords(
{"gaussian_activation": activation_indices}
)
)
if "global_matrix" not in data:
estimations.append(
xr.concat(artifact_estimations, dim="gaussian_activation").assign_coords(
{"gaussian_activation": activation_indices}
)
)
data["coherent_artifact_response"] = xr.concat(
matrices, dim="coherent_artifact"
).assign_coords({"coherent_artifact": [m.label for m in elements]})
if "global_matrix" not in data:
data["coherent_artifact_associated_estimation"] = xr.concat(
estimations, dim="coherent_artifact"
).assign_coords({"coherent_artifact": [m.label for m in elements]})
global_dimension: str,
model_dimension: str,
amplitudes: xr.Dataset,
concentrations: xr.Dataset,
) -> ElementResult:
amplitude = (
amplitudes.sel(amplitude_label=self.compartments)
.rename(amplitude_label="coherent_artifact_order")
.assign_coords({"coherent_artifact_order": range(1, self.order + 1)})
)
concentration = (
concentrations.sel(amplitude_label=self.compartments)
.rename(amplitude_label="coherent_artifact_order")
.assign_coords({"coherent_artifact_order": range(1, self.order + 1)})
)
return ElementResult(
amplitudes={"coherent_artifact": amplitude},
concentrations={"coherent_artifact": concentration},
)


@nb.jit(nopython=True, parallel=False)
Expand Down
Loading
Loading