Skip to content

Commit

Permalink
Update analysis workflow (#19)
Browse files Browse the repository at this point in the history
This PR updates the analytical workflow. It now gets the estimand
and ACE after applying each step of the workflow, then calculates the
deltas against the original ACE for all steps after the initial one.

It's set up to automatically run on all examples in the repo that have
data, so it's important now to properly get data in (e.g. in #17) that
are _fully_ documented (otherwise, these results can't be
contextualized)
  • Loading branch information
cthoyt authored Dec 14, 2023
1 parent b8df19f commit b6b1fab
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -899,3 +899,4 @@ FodyWeavers.xsd
# End of https://www.toptal.com/developers/gitignore/api/macos,linux,windows,python,jupyternotebooks,jetbrains,pycharm,vim,emacs,visualstudiocode,visualstudio

scratch/
src/eliater/case_studies.tsv
180 changes: 150 additions & 30 deletions src/eliater/api.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,46 @@
"""Implementation of Eliater workflow."""
"""Implementation of Eliater workflow.
import logging
To run the workflow and reproduce results on all examples in the
package, use ``python -m eliater.api``.
"""

import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import click
import pandas as pd

from eliater.discover_latent_nodes import remove_nuisance_variables
from eliater.examples import examples
from eliater.network_validation import add_ci_undirected_edges
from y0.algorithm.estimation import estimate_ace
from y0.algorithm.identify import identify_outcomes
from y0.dsl import Expression, Variable
from y0.graph import NxMixedGraph, _ensure_set
from y0.struct import CITest

from .discover_latent_nodes import remove_nuisance_variables

__all__ = [
"workflow",
"reproduce",
]

logger = logging.getLogger(__name__)
HERE = Path(__file__).parent.resolve()
RESULTS_PATH = HERE.joinpath("case_studies.tsv")

# Ignore all warnings
warnings.filterwarnings("ignore")


@dataclass
class Step:
    """Represents the state after a step in the workflow.

    One instance is created for the input graph and one after each
    subsequent transformation applied by :func:`workflow`.
    """

    # The graph as it stands after this step (the input graph for the initial step)
    graph: NxMixedGraph
    # The estimand identified on ``graph`` for the queried treatments and outcomes
    estimand: Expression
    # The average causal effect (ACE) estimated on ``graph`` from the data
    ace: float
    # This step's ACE minus the ACE of the input graph; 0.0 for the initial step
    ace_delta: float


def workflow(
Expand All @@ -26,23 +49,25 @@ def workflow(
treatments: Union[Variable, set[Variable]],
outcomes: Union[Variable, set[Variable]],
*,
conditions: Union[None, Variable, set[Variable]] = None,
ci_method: Optional[CITest] = None,
ci_significance_level: Optional[float] = None,
ace_bootstraps: int | None = None,
ace_significance_level: float | None = None,
) -> tuple[NxMixedGraph, Expression, float]:
) -> list[Step]:
"""Run the Eliater workflow.
This workflow has two parts:
This workflow has three parts:
1. Add undirected edges between d-separated nodes for which a data-driven conditional independency test fails
2. Remove nuissance variables.
2. Remove nuisance variables.
3. Estimates the average causal effect (ACE) of the treatments on outcomes
:param graph: An acyclic directed mixed graph
:param data: Data associated with nodes in the graph
:param treatments: The node or nodes that are treated
:param outcomes: The node or nodes that are outcomes
:param conditions: Conditions on the query (currently not implemented for all parts)
:param ci_method:
The conditional independency test to use. If None, defaults to
:data:`y0.struct.DEFAULT_CONTINUOUS_CI_TEST` for continuous data
Expand All @@ -52,33 +77,128 @@ def workflow(
the tested variables. If none, defaults to 0.01.
:param ace_bootstraps: The number of bootstraps for calculating the ACE. Defaults to 0 (i.e., not used by default)
:param ace_significance_level: The significance level for the ACE. Defaults to 0.05.
:returns: A triple with a modified graph, the estimand, and the ACE value.
:returns: A list of states, one per step, starting with the state of the unmodified input graph
:raises ValueError: If the graph becomes unidentifiable throughout the workflow
"""
logger.warning(
"TODO: add CI undirected edges with parameters %s, %s", ci_method, ci_significance_level
)
# graph = add_ci_undirected_edges(
# graph, data, method=ci_method, significance_level=ci_significance_level
# )
treatments = _ensure_set(treatments)
outcomes = _ensure_set(outcomes)
estimand = identify_outcomes(graph, treatments=treatments, outcomes=outcomes)
if estimand is None:

def _estimate_ace(_graph: NxMixedGraph) -> float:
return estimate_ace(
graph=_graph,
treatments=list(treatments),
outcomes=list(outcomes),
conditions=conditions,
data=data,
bootstraps=ace_bootstraps,
alpha=ace_significance_level,
)

def _identify(_graph: NxMixedGraph) -> Expression:
return identify_outcomes(
_graph, treatments=treatments, outcomes=outcomes, conditions=conditions
)

input_estimand = _identify(graph)
if input_estimand is None:
raise ValueError("input graph is not identifiable")
input_ace = _estimate_ace(graph)
initial = Step(graph=graph, estimand=input_estimand, ace=input_ace, ace_delta=0.0)

graph_1 = add_ci_undirected_edges(
graph, data, method=ci_method, significance_level=ci_significance_level
)
graph_1_estimand = _identify(graph_1)
if graph_1_estimand is None:
raise ValueError("not identifiable after adding CI edges")
graph_1_ace = _estimate_ace(graph_1)
graph_1_ace_delta = graph_1_ace - input_ace
step_1 = Step(
graph=graph_1, estimand=graph_1_estimand, ace=graph_1_ace, ace_delta=graph_1_ace_delta
)

# TODO extend this to consider condition variables
graph = remove_nuisance_variables(graph, treatments=treatments, outcomes=outcomes)
estimand = identify_outcomes(graph, treatments=treatments, outcomes=outcomes)
if not estimand:
graph_2 = remove_nuisance_variables(graph_1, treatments=treatments, outcomes=outcomes)
graph_2_estimand = _identify(graph_2)
if not graph_2_estimand:
raise ValueError("not identifiable after removing nuisance variables")

ace = estimate_ace(
graph=graph,
treatments=list(treatments),
outcomes=list(outcomes),
data=data,
bootstraps=ace_bootstraps,
alpha=ace_significance_level,
graph_2_ace = _estimate_ace(graph_2)
graph_2_ace_delta = graph_2_ace - input_ace
step_2 = Step(
graph=graph_2, estimand=graph_2_estimand, ace=graph_2_ace, ace_delta=graph_2_ace_delta
)
return graph, estimand, ace

return [initial, step_1, step_2]


@click.command()
def reproduce():
    """Run this function to generate the results for the paper.

    Iterates over every example in ``examples`` that either ships data or can
    generate it, runs :func:`workflow` on each of its single-treatment,
    single-outcome example queries, and collects one row per successful query
    with the node count, estimand, and ACE after each step (plus the ACE delta
    for every step after the initial one). The rows are written as a
    tab-separated file to ``RESULTS_PATH``.

    :returns: The data frame that was written to ``RESULTS_PATH``
    :raises ValueError: If no query on any example produced a result
    """
    click.echo("Make sure you're on the dev version of y0")
    rows = []
    columns = [
        "name",
        "treatments",
        "outcomes",
        "initial_nodes",
        "initial_estimand",
        "initial_ace",
        "step_1_nodes",
        "step_1_estimand",
        "step_1_ace",
        "step_1_ace_delta",
        "step_2_nodes",
        "step_2_estimand",
        "step_2_ace",
        "step_2_ace_delta",
    ]
    for example in examples:
        # Prefer data bundled with the example; otherwise simulate it with a
        # fixed seed so repeated runs use the same generated data.
        if example.data is not None:
            data = example.data
        elif example.generate_data is not None:
            data = example.generate_data(2000, seed=0)
        else:
            # Nothing to estimate an ACE from, so the example is skipped entirely
            continue

        for query in example.example_queries:
            click.echo(f"\n> {example.name}")
            if len(query.treatments) != 1 or len(query.outcomes) != 1:
                # Say which query is skipped and why, mirroring the failure
                # message below that also reports ``query.expression``.
                click.echo(
                    f"[{example.name}] skipping query: {query.expression} "
                    "(only queries with exactly one treatment and one outcome are run)"
                )
                continue

            # Deliberately broad: one failing query must not abort the whole
            # batch, so the error is reported and the loop moves on.
            try:
                steps = workflow(
                    graph=example.graph,
                    data=data,
                    treatments=query.treatments,
                    outcomes=query.outcomes,
                )
            except Exception as e:
                click.echo(f"Failed on query: {query.expression}")
                click.secho(f"{type(e).__name__}: {e}", fg="red")
                continue

            # Flatten the per-step results in the same order as ``columns``
            parts = []
            for i, step in enumerate(steps):
                parts.append(step.graph.directed.number_of_nodes())
                parts.append(step.estimand.to_y0())
                parts.append(round(step.ace, 4))
                if i > 0:
                    # The initial step has no delta column
                    parts.append(round(step.ace_delta, 4))
            rows.append(
                (
                    example.name,
                    ", ".join(sorted(t.name for t in query.treatments)),
                    ", ".join(sorted(o.name for o in query.outcomes)),
                    *parts,
                )
            )
    if not rows:
        raise ValueError("No examples available!")
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(RESULTS_PATH, sep="\t", index=False)
    click.echo(f"\nOutputting {len(rows)} results to {RESULTS_PATH}")
    return df


# Script entry point: running this module directly (``python -m eliater.api``)
# invokes the ``reproduce`` command.
if __name__ == "__main__":
    reproduce()
16 changes: 14 additions & 2 deletions src/eliater/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,37 @@

from .ecoli import ecoli_transcription_example
from .frontdoor_backdoor_discrete import (
single_mediator_with_multiple_confounders_nuisances_discrete_example,
single_mediator_with_multiple_confounders_nuisances_discrete_example as example_4,
)
from .sars import sars_cov_2_example
from .t_cell_signaling_pathway import t_cell_signaling_example
from ..frontdoor_backdoor.base import frontdoor_backdoor_example
from ..frontdoor_backdoor.example1 import multiple_mediators_single_confounder_example as example_1
from ..frontdoor_backdoor.example2 import example_2
from ..frontdoor_backdoor.example3 import (
multiple_mediators_confounders_nuisance_vars_example as example_3,
)

__all__ = [
"examples",
# actual examples
"ecoli_transcription_example",
"sars_cov_2_example",
"t_cell_signaling_example",
"single_mediator_with_multiple_confounders_nuisances_discrete_example",
"frontdoor_backdoor_example",
"example_1",
"example_2",
"example_3",
"example_4",
]

examples = [
ecoli_transcription_example,
sars_cov_2_example,
t_cell_signaling_example,
frontdoor_backdoor_example,
example_1,
example_2,
example_3,
example_4,
]

0 comments on commit b6b1fab

Please sign in to comment.