From e77631eb93092ca8502ca902ee013d23dae03ec4 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Wed, 4 Sep 2024 12:20:39 -0700 Subject: [PATCH] Update `_get_weights()` tests - Check if sum of each weight group equals 1.0 - Update `_get_weights()` docs to remove validation portion --- tests/test_temporal.py | 44 ++++++++++++++++++++++++++++++++++++++++++ xcdat/temporal.py | 3 --- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 4385cfab..e5489b1b 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -2217,6 +2217,26 @@ def test_weighted_daily_departures_drops_leap_days_with_matching_calendar(self): xr.testing.assert_identical(result, expected) +def _check_each_weight_group_adds_up_to_1(ds: xr.Dataset, weights: xr.DataArray): + """Check that the sum of the weights in each group adds up to 1.0 (or 100%). + + Parameters + ---------- + ds : xr.Dataset + The dataset with the temporal accessor class attached. + weights : xr.DataArray + The weights to check, produced by the `_get_weights` method. + """ + time_lengths = ds.time_bnds[:, 1] - ds.time_bnds[:, 0] + time_lengths = time_lengths.astype(np.float64) + + grouped_time_lengths = ds.temporal._group_data(time_lengths) + + actual_sum = ds.temporal._group_data(weights).sum().values + expected_sum = np.ones(len(grouped_time_lengths.groups)) + np.testing.assert_allclose(actual_sum, expected_sum) + + class Test_GetWeights: class TestWeightsForAverageMode: @pytest.fixture(autouse=True) @@ -2289,6 +2309,8 @@ def test_weights_for_yearly_averages(self): ) assert np.allclose(result, expected) + _check_each_weight_group_adds_up_to_1(ds, result) + def test_weights_for_monthly_averages(self): ds = self.ds.copy() @@ -2352,6 +2374,8 @@ def test_weights_for_monthly_averages(self): ) assert np.allclose(result, expected) + _check_each_weight_group_adds_up_to_1(ds, result) + class TestWeightsForGroupAverageMode: @pytest.fixture(autouse=True) def setup(self): @@ -2423,6 +2447,8 @@ def test_weights_for_yearly_averages(self): ) assert np.allclose(result, expected) + _check_each_weight_group_adds_up_to_1(ds, result) + def test_weights_for_monthly_averages(self): ds = self.ds.copy() @@ -2469,6 +2495,8 @@ def test_weights_for_monthly_averages(self): expected = np.ones(15) assert np.allclose(result, expected) + _check_each_weight_group_adds_up_to_1(ds, result) + def test_weights_for_seasonal_averages_with_DJF_and_drop_incomplete_seasons( self, ): @@ -2631,6 +2659,8 @@ def test_weights_for_seasonal_averages_with_DJF_and_drop_incomplete_seasons( ) assert np.allclose(result, expected, equal_nan=True) + _check_each_weight_group_adds_up_to_1(ds, result) + def test_weights_for_seasonal_averages_with_JFD(self): ds = self.ds.copy() @@ -2702,6 +2732,8 @@ def test_weights_for_seasonal_averages_with_JFD(self): ) assert np.allclose(result, expected) + _check_each_weight_group_adds_up_to_1(ds, result) + def test_custom_season_time_series_weights(self): ds = self.ds.copy() @@ -2780,6 +2812,8 @@ def test_custom_season_time_series_weights(self): ) assert np.allclose(result, expected) + _check_each_weight_group_adds_up_to_1(ds, result) + def test_weights_for_daily_averages(self): ds = self.ds.copy() @@ -2873,6 +2907,8 @@ def test_weights_for_hourly_averages(self): expected = np.ones(15) assert np.allclose(result, expected) + _check_each_weight_group_adds_up_to_1(ds, result) + class TestWeightsForClimatologyMode: @pytest.fixture(autouse=True) def setup(self): @@ -3035,6 +3071,8 @@ def test_weights_for_seasonal_climatology_with_DJF(self): assert np.allclose(result, expected, equal_nan=True) + _check_each_weight_group_adds_up_to_1(ds, result) + def test_weights_for_seasonal_climatology_with_JFD(self): ds = self.ds.copy() @@ -3101,6 +3139,8 @@ def test_weights_for_seasonal_climatology_with_JFD(self): ) assert np.allclose(result, expected, equal_nan=True) + _check_each_weight_group_adds_up_to_1(ds, result) + def test_weights_for_annual_climatology(self): ds = self.ds.copy() @@ -3165,6 +3205,8 @@ def test_weights_for_annual_climatology(self): ) assert np.allclose(result, expected) + _check_each_weight_group_adds_up_to_1(ds, result) + def test_weights_for_daily_climatology(self): ds = self.ds.copy() @@ -3227,6 +3269,8 @@ def test_weights_for_daily_climatology(self): ) assert np.allclose(result, expected) + _check_each_weight_group_adds_up_to_1(ds, result) + class Test_Averager: # NOTE: This private method is tested because it is more redundant to diff --git a/xcdat/temporal.py b/xcdat/temporal.py index eaf69cea..2c58b4f1 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1221,9 +1221,6 @@ def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray: divided by the total sum of the time lengths in its group to get its corresponding weight. - The sum of the weights for each group is validated to ensure it equals - 1.0. - Parameters ---------- time_bounds : xr.DataArray