From 16b6b5cec3db87b94deb317179d5b2ce88628621 Mon Sep 17 00:00:00 2001 From: David Masip Date: Fri, 3 May 2024 10:16:04 +0200 Subject: [PATCH] raise when time col not in cluster cols (#170) * raise when time col not in cluster cols * release 0141 --- cluster_experiments/random_splitter.py | 7 ++++++ setup.py | 2 +- tests/power_analysis/test_multivariate.py | 2 ++ tests/power_analysis/test_seed.py | 4 ++-- tests/splitter/test_switchback_splitter.py | 25 ++++++++++++++++++++++ 5 files changed, 37 insertions(+), 3 deletions(-) diff --git a/cluster_experiments/random_splitter.py b/cluster_experiments/random_splitter.py index b8c3404..439225b 100644 --- a/cluster_experiments/random_splitter.py +++ b/cluster_experiments/random_splitter.py @@ -173,6 +173,13 @@ def __init__( self.treatment_col = treatment_col self.splitter_weights = splitter_weights self.washover = washover or EmptyWashover() + self._check_clusters() + + def _check_clusters(self): + """Check if time_col is in cluster_cols""" + assert ( + self.time_col in self.cluster_cols + ), "in switchback splitters, time_col must be in cluster_cols" def _get_time_col_cluster(self, df: pd.DataFrame) -> pd.Series: df = df.copy() diff --git a/setup.py b/setup.py index 5db5bd9..dc8a811 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( name="cluster_experiments", - version="0.14.0", + version="0.14.1", packages=find_packages(), extras_require={ "dev": dev_packages, diff --git a/tests/power_analysis/test_multivariate.py b/tests/power_analysis/test_multivariate.py index 9af3319..640f62c 100644 --- a/tests/power_analysis/test_multivariate.py +++ b/tests/power_analysis/test_multivariate.py @@ -60,6 +60,7 @@ def binary_hypothesis_power_config(): "perturbator": "constant", "splitter": "non_clustered", "n_simulations": 50, + "seed": 220924, } return PowerAnalysis.from_dict(config) @@ -72,6 +73,7 @@ def multivariate_hypothesis_power_config(): "splitter": "non_clustered", "n_simulations": 50, "treatments": ["A", "B", "C", "D", "E", "F", "G"], + "seed": 220924, } return PowerAnalysis.from_dict(config) diff --git a/tests/power_analysis/test_seed.py b/tests/power_analysis/test_seed.py index aff1c2d..842996f 100644 --- a/tests/power_analysis/test_seed.py +++ b/tests/power_analysis/test_seed.py @@ -22,7 +22,7 @@ def test_power_analysis_constant_perturbator_seed(df): pw = PowerAnalysis.from_dict(config_dict) powers.append(pw.power_analysis(df, average_effect=10)) - assert np.var(np.asarray(powers)) == 0 + assert np.isclose(np.var(np.asarray(powers)), 0, atol=1e-10) def test_power_analysis_binary_perturbator_seed(df_binary): @@ -33,4 +33,4 @@ def test_power_analysis_binary_perturbator_seed(df_binary): pw = PowerAnalysis.from_dict(config_dict) powers.append(pw.power_analysis(df_binary, average_effect=0.08)) - assert np.var(np.asarray(powers)) == 0 + assert np.isclose(np.var(np.asarray(powers)), 0, atol=1e-10) diff --git a/tests/splitter/test_switchback_splitter.py b/tests/splitter/test_switchback_splitter.py index 5542a90..3c88594 100644 --- a/tests/splitter/test_switchback_splitter.py +++ b/tests/splitter/test_switchback_splitter.py @@ -93,3 +93,28 @@ def test_raise_time_col_not_in_df(): perturbator=perturbator, analysis=analysis, ) + + +def test_raise_time_col_not_in_df_splitter(): + with pytest.raises( + AssertionError, + match="in switchback splitters, time_col must be in cluster_cols", + ): + data = pd.DataFrame( + { + "activation_time": pd.date_range( + start="2021-01-01", periods=10, freq="D" + ), + "city": ["A" for _ in range(10)], + } + ) + time_col = "activation_time" + switch_frequency = "6h" + cluster_cols = ["city"] + + splitter = SwitchbackSplitter( + time_col=time_col, + cluster_cols=cluster_cols, + switch_frequency=switch_frequency, + ) + _ = splitter.assign_treatment_df(data)