Skip to content

Commit

Permalink
Merge pull request #381 from pymc-labs/refactor
Browse files Browse the repository at this point in the history
Major code refactor to unify quasi experiment classes
  • Loading branch information
drbenvincent authored Aug 22, 2024
2 parents c80ed16 + e0b0847 commit e55f23b
Show file tree
Hide file tree
Showing 51 changed files with 8,746 additions and 8,223 deletions.
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ We recommend that your contribution complies with the following guidelines befor
make doctest
```

- Doctest can also be run directly via pytest, which can be helpful to run only specific tests during development. The following commands run all doctests, only doctests in the pymc_models module, and only the doctests for the `ModelBuilder` class in pymc_models:
- Doctest can also be run directly via pytest, which can be helpful to run only specific tests during development. The following commands run all doctests, only doctests in the pymc_models module, and only the doctests for the `PyMCModel` class in pymc_models:

```bash
pytest --doctest-modules causalpy/
pytest --doctest-modules causalpy/pymc_models.py
pytest --doctest-modules causalpy/pmyc_models.py::causalpy.pymc_models.ModelBuilder
pytest --doctest-modules causalpy/pmyc_models.py::causalpy.pymc_models.PyMCModel
```

- To indicate a work in progress please mark the PR as `draft`. Drafts may be useful to (1) indicate you are working on something to avoid duplicated work, (2) request broad review of functionality or API, or (3) seek collaborators.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ df = (
)

# Run the analysis
result = cp.pymc_experiments.RegressionDiscontinuity(
result = cp.RegressionDiscontinuity(
df,
formula="all ~ 1 + age + treated",
running_variable_name="age",
Expand Down
30 changes: 25 additions & 5 deletions causalpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,38 @@
# limitations under the License.
import arviz as az

from causalpy import pymc_experiments, pymc_models, skl_experiments, skl_models
import causalpy.pymc_experiments as pymc_experiments # to be deprecated
import causalpy.pymc_models as pymc_models
import causalpy.skl_experiments as skl_experiments # to be deprecated
import causalpy.skl_models as skl_models
from causalpy.skl_models import create_causalpy_compatible_class
from causalpy.version import __version__

from .data import load_data
from .experiments.diff_in_diff import DifferenceInDifferences
from .experiments.instrumental_variable import InstrumentalVariable
from .experiments.inverse_propensity_weighting import InversePropensityWeighting
from .experiments.prepostfit import InterruptedTimeSeries, SyntheticControl
from .experiments.prepostnegd import PrePostNEGD
from .experiments.regression_discontinuity import RegressionDiscontinuity
from .experiments.regression_kink import RegressionKink

az.style.use("arviz-darkgrid")

__all__ = [
"pymc_experiments",
"__version__",
"DifferenceInDifferences",
"create_causalpy_compatible_class",
"InstrumentalVariable",
"InterruptedTimeSeries",
"InversePropensityWeighting",
"load_data",
"PrePostNEGD",
"pymc_experiments", # to be deprecated
"pymc_models",
"skl_experiments",
"RegressionDiscontinuity",
"RegressionKink",
"skl_experiments", # to be deprecated
"skl_models",
"load_data",
"__version__",
"SyntheticControl",
]
174 changes: 0 additions & 174 deletions causalpy/data_validation.py

This file was deleted.

13 changes: 13 additions & 0 deletions causalpy/experiments/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
80 changes: 80 additions & 0 deletions causalpy/experiments/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for quasi experimental designs.
"""

from abc import abstractmethod

from sklearn.base import RegressorMixin

from causalpy.pymc_models import PyMCModel
from causalpy.skl_models import create_causalpy_compatible_class


class BaseExperiment:
"""Base class for quasi experimental designs."""

supports_bayes: bool
supports_ols: bool

def __init__(self, model=None):
# Ensure we've made any provided Scikit Learn model (as identified as being type
# RegressorMixin) compatible with CausalPy by appending our custom methods.
if isinstance(model, RegressorMixin):
model = create_causalpy_compatible_class(model)

if model is not None:
self.model = model

if isinstance(self.model, PyMCModel) and not self.supports_bayes:
raise ValueError("Bayesian models not supported.")

if isinstance(self.model, RegressorMixin) and not self.supports_ols:
raise ValueError("OLS models not supported.")

if self.model is None:
raise ValueError("model not set or passed.")

@property
def idata(self):
"""Return the InferenceData object of the model. Only relevant for PyMC models."""
return self.model.idata

def print_coefficients(self, round_to=None):
"""Ask the model to print its coefficients."""
self.model.print_coefficients(self.labels, round_to)

def plot(self, *args, **kwargs) -> tuple:
"""Plot the model.
Internally, this function dispatches to either `bayesian_plot` or `ols_plot`
depending on the model type.
"""
if isinstance(self.model, PyMCModel):
return self.bayesian_plot(*args, **kwargs)
elif isinstance(self.model, RegressorMixin):
return self.ols_plot(*args, **kwargs)
else:
raise ValueError("Unsupported model type")

@abstractmethod
def bayesian_plot(self, *args, **kwargs):
"""Abstract method for plotting the model."""
raise NotImplementedError("bayesian_plot method not yet implemented")

@abstractmethod
def ols_plot(self, *args, **kwargs):
"""Abstract method for plotting the model."""
raise NotImplementedError("ols_plot method not yet implemented")
Loading

0 comments on commit e55f23b

Please sign in to comment.