Skip to content

Commit

Permalink
Add workflow for getting results for case studies
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Dec 4, 2023
1 parent 88e10a3 commit d44840c
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 22 deletions.
111 changes: 90 additions & 21 deletions src/eliater/api.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
"""Implementation of Eliater workflow."""

from pathlib import Path
from typing import Optional, Union

import click
import pandas as pd

from eliater.discover_latent_nodes import remove_nuisance_variables
from eliater.examples import examples
from eliater.network_validation import add_ci_undirected_edges
from y0.algorithm.estimation import estimate_ace
from y0.algorithm.identify import identify_outcomes
from y0.dsl import Expression, Variable
from y0.dsl import Variable
from y0.graph import NxMixedGraph, _ensure_set
from y0.struct import CITest

from .discover_latent_nodes import remove_nuisance_variables
from .network_validation import add_ci_undirected_edges

__all__ = [
"workflow",
"reproduce",
]

HERE = Path(__file__).parent.resolve()
RESULTS_PATH = HERE.joinpath("case_studies.tsv")


def workflow(
graph: NxMixedGraph,
Expand All @@ -28,7 +34,7 @@ def workflow(
ci_significance_level: Optional[float] = None,
ace_bootstraps: int | None = None,
ace_significance_level: float | None = None,
) -> tuple[NxMixedGraph, Expression, float]:
):
"""Run the Eliater workflow.
This workflow has three parts:
Expand All @@ -53,27 +59,90 @@ def workflow(
:returns: A triple with a modified graph, the estimand, and the ACE value.
:raises ValueError: If the graph becomes unidentifiable throughout the workflow
"""
graph = add_ci_undirected_edges(
graph, data, method=ci_method, significance_level=ci_significance_level
)
treatments = _ensure_set(treatments)
outcomes = _ensure_set(outcomes)

Check warning on line 63 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L62-L63

Added lines #L62 - L63 were not covered by tests
estimand = identify_outcomes(graph, treatments=treatments, outcomes=outcomes)
if estimand is None:

def _estimate_ace(_graph):
return estimate_ace(

Check warning on line 66 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L65-L66

Added lines #L65 - L66 were not covered by tests
graph=_graph,
treatments=list(treatments),
outcomes=list(outcomes),
data=data,
bootstraps=ace_bootstraps,
alpha=ace_significance_level,
)

input_estimand = identify_outcomes(graph, treatments=treatments, outcomes=outcomes)

Check warning on line 75 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L75

Added line #L75 was not covered by tests
if input_estimand is None:
raise ValueError("input graph is not identifiable")
input_ace = _estimate_ace(graph)

Check warning on line 78 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L77-L78

Added lines #L77 - L78 were not covered by tests

graph_1 = add_ci_undirected_edges(

Check warning on line 80 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L80

Added line #L80 was not covered by tests
graph, data, method=ci_method, significance_level=ci_significance_level
)
graph_1_estimand = identify_outcomes(graph_1, treatments=treatments, outcomes=outcomes)

Check warning on line 83 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L83

Added line #L83 was not covered by tests
if graph_1_estimand is None:
raise ValueError("not identifiable after adding CI edges")
graph_1_ace = _estimate_ace(graph_1)
graph_1_ace_delta = graph_1_ace - input_ace

Check warning on line 87 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L85-L87

Added lines #L85 - L87 were not covered by tests

# TODO extend this to consider condition variables
graph = remove_nuisance_variables(graph, treatments=treatments, outcomes=outcomes)
estimand = identify_outcomes(graph, treatments=treatments, outcomes=outcomes)
if not estimand:
graph_2 = remove_nuisance_variables(graph_1, treatments=treatments, outcomes=outcomes)
graph_2_estimand = identify_outcomes(graph_2, treatments=treatments, outcomes=outcomes)

Check warning on line 91 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L90-L91

Added lines #L90 - L91 were not covered by tests
if not graph_2_estimand:
raise ValueError("not identifiable after removing nuisance variables")
graph_2_ace = _estimate_ace(graph_2)
graph_2_ace_delta = graph_2_ace - input_ace

Check warning on line 95 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L93-L95

Added lines #L93 - L95 were not covered by tests

ace = estimate_ace(
graph=graph,
treatments=list(treatments),
outcomes=list(outcomes),
data=data,
bootstraps=ace_bootstraps,
alpha=ace_significance_level,
return (

Check warning on line 97 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L97

Added line #L97 was not covered by tests
input_estimand,
input_ace,
graph_1,
graph_1_estimand,
graph_1_ace,
graph_1_ace_delta,
graph_2,
graph_2_estimand,
graph_2_ace,
graph_2_ace_delta,
)
return graph, estimand, ace


def reproduce():
"""Run this function to generate the results for the paper."""
click.echo("Make sure you're on the dev version of y0")
rows = []

Check warning on line 114 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L113-L114

Added lines #L113 - L114 were not covered by tests
for example in examples:
if example.data is None:
continue

Check warning on line 117 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L117

Added line #L117 was not covered by tests
for query in example.example_queries:
if len(query.treatments) != 1 or len(query.outcomes) != 1:
click.echo(f"[{example.name}] skipping query:")
continue

Check warning on line 121 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L120-L121

Added lines #L120 - L121 were not covered by tests

try:
record = workflow(

Check warning on line 124 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L123-L124

Added lines #L123 - L124 were not covered by tests
graph=example.graph,
data=example.data,
treatments=query.treatments,
outcomes=query.outcomes,
)
except Exception as e:
click.echo(f"[{example.name}] failed on query: {query.expression}")
click.secho(str(e), fg="red")
continue

Check warning on line 133 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L130-L133

Added lines #L130 - L133 were not covered by tests
rows.append(
(
example.name,
", ".join(sorted(t.name for t in query.treatments)),
", ".join(sorted(o.name for o in query.outcomes)),
*record,
)
)
df = pd.DataFrame(rows)
df.to_csv(RESULTS_PATH, sep="\t", index=False)
return df

Check warning on line 144 in src/eliater/api.py

View check run for this annotation

Codecov / codecov/patch

src/eliater/api.py#L142-L144

Added lines #L142 - L144 were not covered by tests


if __name__ == "__main__":
reproduce()
6 changes: 5 additions & 1 deletion src/eliater/discover_latent_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,12 @@ def remove_nuisance_variables(
:param tag: The tag for which variables are latent
:return: the new graph after simplification
"""
rv = NxMixedGraph(
directed=graph.directed.copy(),
undirected=graph.undirected.copy(),
)
lv_dag = mark_nuisance_variables_as_latent(
graph=graph, treatments=treatments, outcomes=outcomes, tag=tag
graph=rv, treatments=treatments, outcomes=outcomes, tag=tag
)
simplified_latent_dag = simplify_latent_dag(lv_dag, tag=tag)
return NxMixedGraph.from_latent_variable_dag(simplified_latent_dag.graph, tag=tag)
Expand Down
8 changes: 8 additions & 0 deletions src/eliater/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,16 @@
from .t_cell_signaling_pathway import t_cell_signaling_example

__all__ = [
"examples",
# actual examples
"ecoli_transcription_example",
"sars_cov_2_example",
"t_cell_signaling_example",
"single_mediator_with_multiple_confounders_nuisances_discrete_example",
]

examples = [
ecoli_transcription_example,
sars_cov_2_example,
t_cell_signaling_example,
]

0 comments on commit d44840c

Please sign in to comment.