From d44840c0baab47b5c70e1aeb581f19375b3624d2 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Mon, 4 Dec 2023 18:18:04 +0100 Subject: [PATCH] Add workflow for getting results for case studies cc @srtaheri --- src/eliater/api.py | 111 ++++++++++++++++++++++----- src/eliater/discover_latent_nodes.py | 6 +- src/eliater/examples/__init__.py | 8 ++ 3 files changed, 103 insertions(+), 22 deletions(-) diff --git a/src/eliater/api.py b/src/eliater/api.py index bceda25..82e01df 100644 --- a/src/eliater/api.py +++ b/src/eliater/api.py @@ -1,22 +1,28 @@ """Implementation of Eliater workflow.""" +from pathlib import Path from typing import Optional, Union +import click import pandas as pd +from eliater.discover_latent_nodes import remove_nuisance_variables +from eliater.examples import examples +from eliater.network_validation import add_ci_undirected_edges from y0.algorithm.estimation import estimate_ace from y0.algorithm.identify import identify_outcomes -from y0.dsl import Expression, Variable +from y0.dsl import Variable from y0.graph import NxMixedGraph, _ensure_set from y0.struct import CITest -from .discover_latent_nodes import remove_nuisance_variables -from .network_validation import add_ci_undirected_edges - __all__ = [ "workflow", + "reproduce", ] +HERE = Path(__file__).parent.resolve() +RESULTS_PATH = HERE.joinpath("case_studies.tsv") + def workflow( graph: NxMixedGraph, @@ -28,7 +34,7 @@ def workflow( ci_significance_level: Optional[float] = None, ace_bootstraps: int | None = None, ace_significance_level: float | None = None, -) -> tuple[NxMixedGraph, Expression, float]: +): """Run the Eliater workflow. This workflow has three parts: @@ -53,27 +59,90 @@ def workflow( :returns: A triple with a modified graph, the estimand, and the ACE value. :raises ValueError: If the graph becomes unidentifiable throughout the workflow """ - graph = add_ci_undirected_edges( - graph, data, method=ci_method, significance_level=ci_significance_level - ) treatments = _ensure_set(treatments) outcomes = _ensure_set(outcomes) - estimand = identify_outcomes(graph, treatments=treatments, outcomes=outcomes) - if estimand is None: + + def _estimate_ace(_graph): + return estimate_ace( + graph=_graph, + treatments=list(treatments), + outcomes=list(outcomes), + data=data, + bootstraps=ace_bootstraps, + alpha=ace_significance_level, + ) + + input_estimand = identify_outcomes(graph, treatments=treatments, outcomes=outcomes) + if input_estimand is None: + raise ValueError("input graph is not identifiable") + input_ace = _estimate_ace(graph) + + graph_1 = add_ci_undirected_edges( + graph, data, method=ci_method, significance_level=ci_significance_level + ) + graph_1_estimand = identify_outcomes(graph_1, treatments=treatments, outcomes=outcomes) + if graph_1_estimand is None: raise ValueError("not identifiable after adding CI edges") + graph_1_ace = _estimate_ace(graph_1) + graph_1_ace_delta = graph_1_ace - input_ace # TODO extend this to consider condition variables - graph = remove_nuisance_variables(graph, treatments=treatments, outcomes=outcomes) - estimand = identify_outcomes(graph, treatments=treatments, outcomes=outcomes) - if not estimand: + graph_2 = remove_nuisance_variables(graph_1, treatments=treatments, outcomes=outcomes) + graph_2_estimand = identify_outcomes(graph_2, treatments=treatments, outcomes=outcomes) + if not graph_2_estimand: raise ValueError("not identifiable after removing nuisance variables") + graph_2_ace = _estimate_ace(graph_2) + graph_2_ace_delta = graph_2_ace - input_ace - ace = estimate_ace( - graph=graph, - treatments=list(treatments), - outcomes=list(outcomes), - data=data, - bootstraps=ace_bootstraps, - alpha=ace_significance_level, + return ( + input_estimand, + input_ace, + graph_1, + graph_1_estimand, + graph_1_ace, + graph_1_ace_delta, + graph_2, + graph_2_estimand, + graph_2_ace, + graph_2_ace_delta, ) - return graph, estimand, ace + + +def reproduce(): + """Run this function to generate the results for the paper.""" + click.echo("Make sure you're on the dev version of y0") + rows = [] + for example in examples: + if example.data is None: + continue + for query in example.example_queries: + if len(query.treatments) != 1 or len(query.outcomes) != 1: + click.echo(f"[{example.name}] skipping query:") + continue + + try: + record = workflow( + graph=example.graph, + data=example.data, + treatments=query.treatments, + outcomes=query.outcomes, + ) + except Exception as e: + click.echo(f"[{example.name}] failed on query: {query.expression}") + click.secho(str(e), fg="red") + continue + rows.append( + ( + example.name, + ", ".join(sorted(t.name for t in query.treatments)), + ", ".join(sorted(o.name for o in query.outcomes)), + *record, + ) + ) + df = pd.DataFrame(rows) + df.to_csv(RESULTS_PATH, sep="\t", index=False) + return df + + +if __name__ == "__main__": + reproduce() diff --git a/src/eliater/discover_latent_nodes.py b/src/eliater/discover_latent_nodes.py index 5592092..b582293 100644 --- a/src/eliater/discover_latent_nodes.py +++ b/src/eliater/discover_latent_nodes.py @@ -171,8 +171,12 @@ def remove_nuisance_variables( :param tag: The tag for which variables are latent :return: the new graph after simplification """ + rv = NxMixedGraph( + directed=graph.directed.copy(), + undirected=graph.undirected.copy(), + ) lv_dag = mark_nuisance_variables_as_latent( - graph=graph, treatments=treatments, outcomes=outcomes, tag=tag + graph=rv, treatments=treatments, outcomes=outcomes, tag=tag ) simplified_latent_dag = simplify_latent_dag(lv_dag, tag=tag) return NxMixedGraph.from_latent_variable_dag(simplified_latent_dag.graph, tag=tag) diff --git a/src/eliater/examples/__init__.py b/src/eliater/examples/__init__.py index d3f6672..2d08321 100644 --- a/src/eliater/examples/__init__.py +++ b/src/eliater/examples/__init__.py @@ -8,8 +8,16 @@ from .t_cell_signaling_pathway import t_cell_signaling_example __all__ = [ + "examples", + # actual examples "ecoli_transcription_example", "sars_cov_2_example", "t_cell_signaling_example", "single_mediator_with_multiple_confounders_nuisances_discrete_example", ] + +examples = [ + ecoli_transcription_example, + sars_cov_2_example, + t_cell_signaling_example, +]