Skip to content

Commit

Permalink
Remove unused _mask_var_with_weight_threshold() function
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Sep 4, 2024
1 parent bd872cf commit 1e84ee5
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 180 deletions.
135 changes: 1 addition & 134 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
import numpy as np
import pytest
import xarray as xr

from tests.fixtures import generate_dataset
from xcdat.utils import (
_validate_min_weight,
compare_datasets,
mask_var_with_weight_threshold,
str_to_bool,
)
from xcdat.utils import _validate_min_weight, compare_datasets, str_to_bool


class TestCompareDatasets:
Expand Down Expand Up @@ -112,132 +105,6 @@ def test_raises_error_if_str_is_not_a_python_bool(self):
str_to_bool("1")


class TestMaskVarWithWeightThreshold:
    """Tests for ``mask_var_with_weight_threshold``.

    Each test works on a 3x3x3 (time, lat, lon) subset of the generated
    dataset and checks the masked ``ts`` variable for spatial (lat/lon)
    weights and for temporal weights, at minimum weight thresholds of 0
    and 1.0.
    """

    @pytest.fixture(autouse=True)
    def setup(self):
        # Runs automatically before each test method in this class.
        self.ds = generate_dataset(
            decode_times=True, cf_compliant=False, has_bounds=True
        )

    def test_returns_mask_var_with_spatial_min_weight_of_100(self):
        """Mask a whole time step that has missing data when min_weight=1.0."""
        ds = self.ds.copy()
        ds = ds.isel({"time": slice(0, 3), "lat": slice(0, 3), "lon": slice(0, 3)})
        # Missing data at the first time step, for every lat at lon index 2.
        ds["ts"][0, :, 2] = np.nan

        # Function arguments.
        dv = ds["ts"].copy()
        weights = ds.spatial.get_weights(
            axis=["X", "Y"],
            lat_bounds=(-5.0, 5),
            lon_bounds=(-170, -120.1),
            data_var="ts",
        )

        result = mask_var_with_weight_threshold(dv, weights, min_weight=1.0)
        # Only the first time step contains NaNs, so only that time step is
        # expected to be fully masked.
        expected = xr.DataArray(
            data=np.array(
                [
                    [
                        [np.nan, np.nan, np.nan],
                        [np.nan, np.nan, np.nan],
                        [np.nan, np.nan, np.nan],
                    ],
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
                ]
            ),
            coords={"time": ds.time, "lat": ds.lat, "lon": ds.lon},
            dims=["time", "lat", "lon"],
        )

        xr.testing.assert_allclose(result, expected)

    def test_returns_mask_var_with_spatial_min_weight_of_0(self):
        """Leave complete data untouched when min_weight=0 (spatial weights)."""
        ds = self.ds.copy()
        ds = ds.isel({"time": slice(0, 3), "lat": slice(0, 3), "lon": slice(0, 3)})

        # Function arguments.
        dv = ds["ts"].copy()
        weights = ds.spatial.get_weights(
            axis=["X", "Y"],
            lat_bounds=(-5.0, 5),
            lon_bounds=(-170, -120.1),
            data_var="ts",
        )

        result = mask_var_with_weight_threshold(dv, weights, min_weight=0)
        expected = xr.DataArray(
            data=np.array(
                [
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
                ]
            ),
            coords={"time": ds.time, "lat": ds.lat, "lon": ds.lon},
            dims=["time", "lat", "lon"],
        )

        xr.testing.assert_allclose(result, expected)

    def test_returns_mask_var_with_temporal_min_weight_of_100(self):
        """Mask grid cells with any missing time step when min_weight=1.0."""
        ds = self.ds.copy()
        ds = ds.isel({"time": slice(0, 3), "lat": slice(0, 3), "lon": slice(0, 3)})
        # Missing data at the first time step, for every lat at lon index 2.
        ds["ts"][0, :, 2] = np.nan

        # Function arguments.
        dv = ds["ts"].copy()
        weights = xr.DataArray(
            name="time_wts",
            data=np.array([1.0, 1.0, 1.0]),
            dims="time",
            coords={"time": ds.time},
        )

        # The threshold must be 1.0 (100%) to match the test name; with a
        # threshold of 0 nothing is ever masked and the test would only
        # check that the input passes through unchanged.
        result = mask_var_with_weight_threshold(dv, weights, min_weight=1.0)
        # Cells at lon index 2 have one of three equally weighted time steps
        # missing (2/3 coverage), so every time step there is masked.
        # NOTE(review): assumes `_get_masked_weights` zeroes the weight of
        # NaN values, per its docstring -- confirm against the implementation.
        expected = xr.DataArray(
            data=np.array(
                [
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [np.nan, np.nan, np.nan]],
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [np.nan, np.nan, np.nan]],
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [np.nan, np.nan, np.nan]],
                ]
            ),
            coords={"lat": ds.lat, "lon": ds.lon, "time": ds.time},
            dims=["lat", "lon", "time"],
        )

        xr.testing.assert_allclose(result, expected)

    def test_returns_mask_var_with_temporal_min_weight_of_0(self):
        """Leave complete data untouched when min_weight=0 (temporal weights)."""
        ds = self.ds.copy()
        ds = ds.isel({"time": slice(0, 3), "lat": slice(0, 3), "lon": slice(0, 3)})

        # Function arguments.
        dv = ds["ts"].copy()
        weights = xr.DataArray(
            name="time_wts",
            data=np.array([1.0, 1.0, 1.0]),
            dims="time",
            coords={"time": ds.time},
        )

        result = mask_var_with_weight_threshold(dv, weights, min_weight=0)
        expected = xr.DataArray(
            data=np.array(
                [
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
                ]
            ),
            coords={"lat": ds.lat, "lon": ds.lon, "time": ds.time},
            dims=["lat", "lon", "time"],
        )

        xr.testing.assert_allclose(result, expected)


class TestValidateMinWeight:
def test_pass_None_returns_0(self):
result = _validate_min_weight(None)
Expand Down
46 changes: 0 additions & 46 deletions xcdat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
from typing import Dict, List, Optional, Union

import numpy as np
import xarray as xr
from dask.array.core import Array

Expand Down Expand Up @@ -135,51 +134,6 @@ def _if_multidim_dask_array_then_load(
return None


def mask_var_with_weight_threshold(
    dv: xr.DataArray, weights: xr.DataArray, min_weight: float
) -> xr.DataArray:
    """Replace values with ``np.nan`` where data coverage is below a threshold.

    Coverage is the fraction of the total weight that belongs to non-missing
    values of ``dv``, computed over all dimensions of ``weights``. This guards
    against skewed weighted calculations when data availability is uneven,
    for example when one season of a time series has significantly more
    missing data than the others, which can distort climatologies.

    Parameters
    ----------
    dv : xr.DataArray
        The weighted variable.
    weights : xr.DataArray
        A DataArray containing either the regional or temporal weights used
        for weighted averaging. ``weights`` must include the same axis
        dimensions and dimensional sizes as the data variable.
    min_weight : float
        Fraction of data coverage (i.e., weight) required for values to be
        kept. Value must range from 0 to 1.

    Returns
    -------
    xr.DataArray
        The variable with the minimum weight threshold applied.
    """
    # Weights with missing values of ``dv`` contributing zero.
    available_weights = _get_masked_weights(dv, weights)

    # Reduce over every dimension the weights carry.
    reduce_dims = weights.dims
    total_weight = weights.sum(dim=reduce_dims)

    # Fraction of the total weight that is backed by actual data.
    coverage = available_weights.sum(dim=reduce_dims) / total_weight

    # Keep values whose coverage meets the threshold; NaN out the rest.
    meets_threshold = coverage >= min_weight
    masked = xr.where(meets_threshold, dv, np.nan, keep_attrs=True)
    masked.name = dv.name

    return masked


def _get_masked_weights(dv: xr.DataArray, weights: xr.DataArray) -> xr.DataArray:
"""Get weights with missing data (`np.nan`) receiving no weight (zero).
Expand Down

0 comments on commit 1e84ee5

Please sign in to comment.