Skip to content

Commit

Permalink
Update _get_weights() tests
Browse files Browse the repository at this point in the history
- Check if sum of each weight group equals 1.0
- Update `_get_weights()` docs to remove validation portion
  • Loading branch information
tomvothecoder committed Sep 4, 2024
1 parent 84a7e10 commit e77631e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
44 changes: 44 additions & 0 deletions tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2217,6 +2217,26 @@ def test_weighted_daily_departures_drops_leap_days_with_matching_calendar(self):
xr.testing.assert_identical(result, expected)


def _check_each_weight_group_adds_up_to_1(ds: xr.Dataset, weights: xr.DataArray):
"""Check that the sum of the weights in each group adds up to 1.0 (or 100%).
Parameters
----------
ds : xr.Dataset
The dataset with the temporal accessor class attached.
weights : xr.DataArray
The weights to check, produced by the `_get_weights` method.
"""
time_lengths = ds.time_bnds[:, 1] - ds.time_bnds[:, 0]
time_lengths = time_lengths.astype(np.float64)

grouped_time_lengths = ds.temporal._group_data(time_lengths)

actual_sum = ds.temporal._group_data(weights).sum().values
expected_sum = np.ones(len(grouped_time_lengths.groups))
np.testing.assert_allclose(actual_sum, expected_sum)


class Test_GetWeights:
class TestWeightsForAverageMode:
@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -2289,6 +2309,8 @@ def test_weights_for_yearly_averages(self):
)
assert np.allclose(result, expected)

_check_each_weight_group_adds_up_to_1(ds, result)

def test_weights_for_monthly_averages(self):
ds = self.ds.copy()

Expand Down Expand Up @@ -2352,6 +2374,8 @@ def test_weights_for_monthly_averages(self):
)
assert np.allclose(result, expected)

_check_each_weight_group_adds_up_to_1(ds, result)

class TestWeightsForGroupAverageMode:
@pytest.fixture(autouse=True)
def setup(self):
Expand Down Expand Up @@ -2423,6 +2447,8 @@ def test_weights_for_yearly_averages(self):
)
assert np.allclose(result, expected)

_check_each_weight_group_adds_up_to_1(ds, result)

def test_weights_for_monthly_averages(self):
ds = self.ds.copy()

Expand Down Expand Up @@ -2469,6 +2495,8 @@ def test_weights_for_monthly_averages(self):
expected = np.ones(15)
assert np.allclose(result, expected)

_check_each_weight_group_adds_up_to_1(ds, result)

def test_weights_for_seasonal_averages_with_DJF_and_drop_incomplete_seasons(
self,
):
Expand Down Expand Up @@ -2631,6 +2659,8 @@ def test_weights_for_seasonal_averages_with_DJF_and_drop_incomplete_seasons(
)
assert np.allclose(result, expected, equal_nan=True)

_check_each_weight_group_adds_up_to_1(ds, result)

def test_weights_for_seasonal_averages_with_JFD(self):
ds = self.ds.copy()

Expand Down Expand Up @@ -2702,6 +2732,8 @@ def test_weights_for_seasonal_averages_with_JFD(self):
)
assert np.allclose(result, expected)

_check_each_weight_group_adds_up_to_1(ds, result)

def test_custom_season_time_series_weights(self):
ds = self.ds.copy()

Expand Down Expand Up @@ -2780,6 +2812,8 @@ def test_custom_season_time_series_weights(self):
)
assert np.allclose(result, expected)

_check_each_weight_group_adds_up_to_1(ds, result)

def test_weights_for_daily_averages(self):
ds = self.ds.copy()

Expand Down Expand Up @@ -2873,6 +2907,8 @@ def test_weights_for_hourly_averages(self):
expected = np.ones(15)
assert np.allclose(result, expected)

_check_each_weight_group_adds_up_to_1(ds, result)

class TestWeightsForClimatologyMode:
@pytest.fixture(autouse=True)
def setup(self):
Expand Down Expand Up @@ -3035,6 +3071,8 @@ def test_weights_for_seasonal_climatology_with_DJF(self):

assert np.allclose(result, expected, equal_nan=True)

_check_each_weight_group_adds_up_to_1(ds, result)

def test_weights_for_seasonal_climatology_with_JFD(self):
ds = self.ds.copy()

Expand Down Expand Up @@ -3101,6 +3139,8 @@ def test_weights_for_seasonal_climatology_with_JFD(self):
)
assert np.allclose(result, expected, equal_nan=True)

_check_each_weight_group_adds_up_to_1(ds, result)

def test_weights_for_annual_climatology(self):
ds = self.ds.copy()

Expand Down Expand Up @@ -3165,6 +3205,8 @@ def test_weights_for_annual_climatology(self):
)
assert np.allclose(result, expected)

_check_each_weight_group_adds_up_to_1(ds, result)

def test_weights_for_daily_climatology(self):
ds = self.ds.copy()

Expand Down Expand Up @@ -3227,6 +3269,8 @@ def test_weights_for_daily_climatology(self):
)
assert np.allclose(result, expected)

_check_each_weight_group_adds_up_to_1(ds, result)


class Test_Averager:
# NOTE: This private method is tested because it is more redundant to
Expand Down
3 changes: 0 additions & 3 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,9 +1221,6 @@ def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:
divided by the total sum of the time lengths in its group to get its
corresponding weight.
The sum of the weights for each group is validated to ensure it equals
1.0.
Parameters
----------
time_bounds : xr.DataArray
Expand Down

0 comments on commit e77631e

Please sign in to comment.