Skip to content

Commit

Permalink
Add washover input validation (#156)
Browse files Browse the repository at this point in the history
* added error messages for wrong input

* update matplotlip requirement for apple silico

* revert washover file

* add input warnings

* added docstring and adjusted error messages

* added tests

* moved df creation up and included comments

* updated version  to 13.0
  • Loading branch information
pinsacco authored Feb 28, 2024
1 parent 02a0cc4 commit 0885537
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
33 changes: 33 additions & 0 deletions cluster_experiments/washover.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,35 @@
class Washover(ABC):
"""Abstract class to model washovers in the switchback splitter."""

def _validate_columns(
self,
df: pd.DataFrame,
truncated_time_col: str,
cluster_cols: List[str],
original_time_col: str,
):
"""Validate that all the columns required for the washover are present in the dataframe.
Args:
df (pd.DataFrame): Input dataframe.
truncated_time_col (str): Name of the truncated time column.
cluster_cols (List[str]): List of clusters of experiment.
original_time_col (str): Name of the original time column.
Returns:
None: This method does not return any data; it only performs validation.
"""
if original_time_col not in df.columns:
raise ValueError(
f"{original_time_col = } is not in the dataframe columns and/or not specified as an input."
)
if truncated_time_col not in cluster_cols:
raise ValueError(f"{truncated_time_col = } is not in the cluster columns.")
for col in cluster_cols:
if col not in df.columns:
raise ValueError(f"{col = } cluster is not in the dataframe columns.")

@abstractmethod
def washover(
self,
Expand Down Expand Up @@ -165,6 +194,10 @@ def generate_data(start_time, end_time, treatment):
if original_time_col
else _original_time_column(truncated_time_col)
)

# Validate columns
self._validate_columns(df, truncated_time_col, cluster_cols, original_time_col)

# Cluster columns that do not involve time
non_time_cols = list(set(cluster_cols) - set([truncated_time_col]))
# For each cluster, we need to check if treatment has changed wrt last time
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@
"jinja2<3.1.0",
"mkdocs-jupyter==0.22.0",
"plotnine==0.8.0",
"matplotlib==3.4.3",
"matplotlib>=3.4.3",
]

dev_packages = test_packages + util_packages + docs_packages

setup(
name="cluster_experiments",
version="0.12.0",
version="0.13.0",
packages=find_packages(),
extras_require={
"dev": dev_packages,
Expand Down
51 changes: 51 additions & 0 deletions tests/splitter/test_washover.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from datetime import timedelta

import pandas as pd
import pytest

from cluster_experiments import SwitchbackSplitter
Expand Down Expand Up @@ -139,3 +140,53 @@ class Cfg:
assert not washover_df.query("time >= '2022-01-01 01:00:00'").equals(
out_df.query("time >= '2022-01-01 01:00:00'")
)


def test_truncated_time_not_in_cluster_cols():
msg = "is not in the cluster columns."
df = pd.DataFrame(columns=["time_bin", "city", "time", "treatment"])

# Check that the truncated_time_col is also included in the cluster_cols,
# An error is raised because "time_bin" is not in the cluster_cols
with pytest.raises(ValueError, match=msg):

ConstantWashover(washover_time_delta=timedelta(minutes=30)).washover(
df=df,
truncated_time_col="time_bin",
cluster_cols=["city"],
original_time_col="time",
treatment_col="treatment",
)


def test_missing_original_time_col():
msg = "columns and/or not specified as an input."
df = pd.DataFrame(columns=["time_bin", "city", "treatment"])

# Check that the original_time_col is specifed as an input and in the dataframe columns
# An error is raised because "time" is not specified as an input for the washover
with pytest.raises(ValueError, match=msg):

ConstantWashover(washover_time_delta=timedelta(minutes=30)).washover(
df=df,
truncated_time_col="time_bin",
cluster_cols=["city", "time_bin"],
treatment_col="treatment",
)


def test_cluster_cols_missing_in_df():
msg = "cluster is not in the dataframe columns."
df = pd.DataFrame(columns=["time_bin", "time", "treatment"])

# Check that all the cluster_cols are in the dataframe columns
# An error is raised because "city" is not in the dataframe columns
with pytest.raises(ValueError, match=msg):

ConstantWashover(washover_time_delta=timedelta(minutes=30)).washover(
df=df,
truncated_time_col="time_bin",
cluster_cols=["city", "time_bin"],
original_time_col="time",
treatment_col="treatment",
)

0 comments on commit 0885537

Please sign in to comment.