From b6b1fabe2053c16d18766a8eca323da56e0c8eb3 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Thu, 14 Dec 2023 11:51:56 +0100 Subject: [PATCH] Update analysis workflow (#19) This PR updates the analytical workflow. It now runs gets the estimand and ACE after applying each step of the workflow, then calculates the deltas against the original ACE for all steps after the initial one. It's set up to automatically run on all examples in the repo that have data, so it's important now to properly get data in (e.g. in #17) that are _fully_ documented (otherwise, these results can't be contextualized) --- .gitignore | 1 + src/eliater/api.py | 180 +++++++++++++++++++++++++------ src/eliater/examples/__init__.py | 16 ++- 3 files changed, 165 insertions(+), 32 deletions(-) diff --git a/.gitignore b/.gitignore index fefaf5f..de56800 100644 --- a/.gitignore +++ b/.gitignore @@ -899,3 +899,4 @@ FodyWeavers.xsd # End of https://www.toptal.com/developers/gitignore/api/macos,linux,windows,python,jupyternotebooks,jetbrains,pycharm,vim,emacs,visualstudiocode,visualstudio scratch/ +src/eliater/case_studies.tsv diff --git a/src/eliater/api.py b/src/eliater/api.py index 6134373..39f89cc 100644 --- a/src/eliater/api.py +++ b/src/eliater/api.py @@ -1,23 +1,46 @@ -"""Implementation of Eliater workflow.""" +"""Implementation of Eliater workflow. -import logging +To run the workflow and reproduce results on all examples in the +package, use ``python -m eliater.api``. +""" + +import warnings +from dataclasses import dataclass +from pathlib import Path from typing import Optional, Union +import click import pandas as pd +from eliater.discover_latent_nodes import remove_nuisance_variables +from eliater.examples import examples +from eliater.network_validation import add_ci_undirected_edges from y0.algorithm.estimation import estimate_ace from y0.algorithm.identify import identify_outcomes from y0.dsl import Expression, Variable from y0.graph import NxMixedGraph, _ensure_set from y0.struct import CITest -from .discover_latent_nodes import remove_nuisance_variables - __all__ = [ "workflow", + "reproduce", ] -logger = logging.getLogger(__name__) +HERE = Path(__file__).parent.resolve() +RESULTS_PATH = HERE.joinpath("case_studies.tsv") + +# Ignore all warnings +warnings.filterwarnings("ignore") + + +@dataclass +class Step: + """Represents the state after a step in the workflow.""" + + graph: NxMixedGraph + estimand: Expression + ace: float + ace_delta: float def workflow( @@ -26,23 +49,25 @@ def workflow( treatments: Union[Variable, set[Variable]], outcomes: Union[Variable, set[Variable]], *, + conditions: Union[None, Variable, set[Variable]] = None, ci_method: Optional[CITest] = None, ci_significance_level: Optional[float] = None, ace_bootstraps: int | None = None, ace_significance_level: float | None = None, -) -> tuple[NxMixedGraph, Expression, float]: +) -> list[Step]: """Run the Eliater workflow. - This workflow has two parts: + This workflow has three parts: 1. Add undirected edges between d-separated nodes for which a data-driven conditional independency test fails - 2. Remove nuissance variables. + 2. Remove nuisance variables. 3. Estimates the average causal effect (ACE) of the treatments on outcomes :param graph: An acyclic directed mixed graph :param data: Data associated with nodes in the graph :param treatments: The node or nodes that are treated :param outcomes: The node or nodes that are outcomes + :param conditions: Conditions on the query (currently not implemented for all parts) :param ci_method: The conditional independency test to use. If None, defaults to :data:`y0.struct.DEFAULT_CONTINUOUS_CI_TEST` for continuous data @@ -52,33 +77,128 @@ def workflow( the tested variables. If none, defaults to 0.01. :param ace_bootstraps: The number of bootstraps for calculating the ACE. Defaults to 0 (i.e., not used by default) :param ace_significance_level: The significance level for the ACE. Defaults to 0.05. - :returns: A triple with a modified graph, the estimand, and the ACE value. + :returns: A set of states after each step :raises ValueError: If the graph becomes unidentifiable throughout the workflow """ - logger.warning( - "TODO: add CI undirected edges with parameters %s, %s", ci_method, ci_significance_level - ) - # graph = add_ci_undirected_edges( - # graph, data, method=ci_method, significance_level=ci_significance_level - # ) treatments = _ensure_set(treatments) outcomes = _ensure_set(outcomes) - estimand = identify_outcomes(graph, treatments=treatments, outcomes=outcomes) - if estimand is None: + + def _estimate_ace(_graph: NxMixedGraph) -> float: + return estimate_ace( + graph=_graph, + treatments=list(treatments), + outcomes=list(outcomes), + conditions=conditions, + data=data, + bootstraps=ace_bootstraps, + alpha=ace_significance_level, + ) + + def _identify(_graph: NxMixedGraph) -> Expression: + return identify_outcomes( + _graph, treatments=treatments, outcomes=outcomes, conditions=conditions + ) + + input_estimand = _identify(graph) + if input_estimand is None: + raise ValueError("input graph is not identifiable") + input_ace = _estimate_ace(graph) + initial = Step(graph=graph, estimand=input_estimand, ace=input_ace, ace_delta=0.0) + + graph_1 = add_ci_undirected_edges( + graph, data, method=ci_method, significance_level=ci_significance_level + ) + graph_1_estimand = _identify(graph_1) + if graph_1_estimand is None: raise ValueError("not identifiable after adding CI edges") + graph_1_ace = _estimate_ace(graph_1) + graph_1_ace_delta = graph_1_ace - input_ace + step_1 = Step( + graph=graph_1, estimand=graph_1_estimand, ace=graph_1_ace, ace_delta=graph_1_ace_delta + ) - # TODO extend this to consider condition variables - graph = remove_nuisance_variables(graph, treatments=treatments, outcomes=outcomes) - estimand = identify_outcomes(graph, treatments=treatments, outcomes=outcomes) - if not estimand: + graph_2 = remove_nuisance_variables(graph_1, treatments=treatments, outcomes=outcomes) + graph_2_estimand = _identify(graph_2) + if not graph_2_estimand: raise ValueError("not identifiable after removing nuisance variables") - - ace = estimate_ace( - graph=graph, - treatments=list(treatments), - outcomes=list(outcomes), - data=data, - bootstraps=ace_bootstraps, - alpha=ace_significance_level, + graph_2_ace = _estimate_ace(graph_2) + graph_2_ace_delta = graph_2_ace - input_ace + step_2 = Step( + graph=graph_2, estimand=graph_2_estimand, ace=graph_2_ace, ace_delta=graph_2_ace_delta ) - return graph, estimand, ace + + return [initial, step_1, step_2] + + +@click.command() +def reproduce(): + """Run this function to generate the results for the paper.""" + click.echo("Make sure you're on the dev version of y0") + rows = [] + columns = [ + "name", + "treatments", + "outcomes", + "initial_nodes", + "initial_estimand", + "initial_ace", + "step_1_nodes", + "step_1_estimand", + "step_1_ace", + "step_1_ace_delta", + "step_2_nodes", + "step_2_estimand", + "step_2_ace", + "step_2_ace_delta", + ] + for example in examples: + if example.data is not None: + data = example.data + elif example.generate_data is not None: + data = example.generate_data(2000, seed=0) + else: + continue + + for query in example.example_queries: + click.echo(f"\n> {example.name}") + if len(query.treatments) != 1 or len(query.outcomes) != 1: + click.echo(f"[{example.name}] skipping query:") + continue + + try: + steps = workflow( + graph=example.graph, + data=data, + treatments=query.treatments, + outcomes=query.outcomes, + ) + except Exception as e: + click.echo(f"Failed on query: {query.expression}") + click.secho(f"{type(e).__name__}: {e}", fg="red") + continue + + parts = [] + for i, step in enumerate(steps): + parts.append(step.graph.directed.number_of_nodes()) + parts.append(step.estimand.to_y0()) + parts.append(round(step.ace, 4)) + if i > 0: + parts.append(round(step.ace_delta, 4)) + rows.append( + ( + example.name, + ", ".join(sorted(t.name for t in query.treatments)), + ", ".join(sorted(o.name for o in query.outcomes)), + *parts, + ) + ) + if not rows: + raise ValueError("No examples available!") + df = pd.DataFrame(rows, columns=columns) + df.to_csv(RESULTS_PATH, sep="\t", index=False) + click.echo(f"\nOutputting {len(rows)} results to {RESULTS_PATH}") + return df + + +if __name__ == "__main__": + reproduce() diff --git a/src/eliater/examples/__init__.py b/src/eliater/examples/__init__.py index 501defb..e098ebd 100644 --- a/src/eliater/examples/__init__.py +++ b/src/eliater/examples/__init__.py @@ -2,11 +2,16 @@ from .ecoli import ecoli_transcription_example from .frontdoor_backdoor_discrete import ( - single_mediator_with_multiple_confounders_nuisances_discrete_example, + single_mediator_with_multiple_confounders_nuisances_discrete_example as example_4, ) from .sars import sars_cov_2_example from .t_cell_signaling_pathway import t_cell_signaling_example +from ..frontdoor_backdoor.base import frontdoor_backdoor_example +from ..frontdoor_backdoor.example1 import multiple_mediators_single_confounder_example as example_1 from ..frontdoor_backdoor.example2 import example_2 +from ..frontdoor_backdoor.example3 import ( + multiple_mediators_confounders_nuisance_vars_example as example_3, +) __all__ = [ "examples", @@ -14,13 +19,20 @@ "ecoli_transcription_example", "sars_cov_2_example", "t_cell_signaling_example", - "single_mediator_with_multiple_confounders_nuisances_discrete_example", + "frontdoor_backdoor_example", + "example_1", "example_2", + "example_3", + "example_4", ] examples = [ ecoli_transcription_example, sars_cov_2_example, t_cell_signaling_example, + frontdoor_backdoor_example, + example_1, example_2, + example_3, + example_4, ]