Skip to content

Commit

Permalink
raise when time col not in cluster cols (#170)
Browse files Browse the repository at this point in the history
* raise when time col not in cluster cols

* release 0141
  • Loading branch information
david26694 authored May 3, 2024
1 parent 414133a commit 16b6b5c
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 3 deletions.
7 changes: 7 additions & 0 deletions cluster_experiments/random_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ def __init__(
self.treatment_col = treatment_col
self.splitter_weights = splitter_weights
self.washover = washover or EmptyWashover()
self._check_clusters()

def _check_clusters(self):
    """Validate that the time column is one of the clustering columns.

    Reads ``self.time_col`` and ``self.cluster_cols``.

    Raises:
        AssertionError: if ``self.time_col`` is not in ``self.cluster_cols``.
    """
    # Raise explicitly rather than with an ``assert`` statement: asserts are
    # stripped when Python runs with -O, which would silently skip this
    # validation. AssertionError (and the exact message) is kept so code that
    # already catches it keeps working.
    if self.time_col not in self.cluster_cols:
        raise AssertionError(
            "in switchback splitters, time_col must be in cluster_cols"
        )

def _get_time_col_cluster(self, df: pd.DataFrame) -> pd.Series:
df = df.copy()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

setup(
name="cluster_experiments",
version="0.14.0",
version="0.14.1",
packages=find_packages(),
extras_require={
"dev": dev_packages,
Expand Down
2 changes: 2 additions & 0 deletions tests/power_analysis/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def binary_hypothesis_power_config():
"perturbator": "constant",
"splitter": "non_clustered",
"n_simulations": 50,
"seed": 220924,
}
return PowerAnalysis.from_dict(config)

Expand All @@ -72,6 +73,7 @@ def multivariate_hypothesis_power_config():
"splitter": "non_clustered",
"n_simulations": 50,
"treatments": ["A", "B", "C", "D", "E", "F", "G"],
"seed": 220924,
}
return PowerAnalysis.from_dict(config)

Expand Down
4 changes: 2 additions & 2 deletions tests/power_analysis/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_power_analysis_constant_perturbator_seed(df):
pw = PowerAnalysis.from_dict(config_dict)
powers.append(pw.power_analysis(df, average_effect=10))

assert np.var(np.asarray(powers)) == 0
assert np.isclose(np.var(np.asarray(powers)), 0, atol=1e-10)


def test_power_analysis_binary_perturbator_seed(df_binary):
Expand All @@ -33,4 +33,4 @@ def test_power_analysis_binary_perturbator_seed(df_binary):
pw = PowerAnalysis.from_dict(config_dict)
powers.append(pw.power_analysis(df_binary, average_effect=0.08))

assert np.var(np.asarray(powers)) == 0
assert np.isclose(np.var(np.asarray(powers)), 0, atol=1e-10)
25 changes: 25 additions & 0 deletions tests/splitter/test_switchback_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,28 @@ def test_raise_time_col_not_in_df():
perturbator=perturbator,
analysis=analysis,
)


def test_raise_time_col_not_in_df_splitter():
    """Expect an AssertionError when time_col is missing from cluster_cols.

    The splitter is built with ``cluster_cols=["city"]`` while
    ``time_col="activation_time"``; constructing it and/or assigning
    treatments is expected to fail with the switchback-specific message.
    """
    # Ten daily rows, all belonging to the same city.
    n_rows = 10
    frame = pd.DataFrame(
        {
            "activation_time": pd.date_range(
                start="2021-01-01", periods=n_rows, freq="D"
            ),
            "city": ["A"] * n_rows,
        }
    )
    expected_message = "in switchback splitters, time_col must be in cluster_cols"

    with pytest.raises(AssertionError, match=expected_message):
        switchback = SwitchbackSplitter(
            time_col="activation_time",
            cluster_cols=["city"],  # deliberately omits "activation_time"
            switch_frequency="6h",
        )
        _ = switchback.assign_treatment_df(frame)

0 comments on commit 16b6b5c

Please sign in to comment.