Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add strict type checking #240

Merged
merged 7 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/y0/algorithm/conditional_independencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable, Iterable, Sequence
from functools import partial
from itertools import combinations, groupby
from typing import Any

import networkx as nx
import pandas as pd
Expand Down Expand Up @@ -111,12 +112,15 @@ def test_conditional_independencies(
]


Policy = Callable[[DSeparationJudgement], Any]


def get_conditional_independencies(
graph: NxMixedGraph,
*,
policy=None,
policy: Policy | None = None,
max_conditions: int | None = None,
**kwargs,
**kwargs: Any,
) -> set[DSeparationJudgement]:
"""Get the conditional independencies from the given ADMG.

Expand All @@ -139,7 +143,9 @@ def get_conditional_independencies(
)


def minimal(judgements: Iterable[DSeparationJudgement], policy=None) -> set[DSeparationJudgement]:
def minimal(
judgements: Iterable[DSeparationJudgement], policy: Policy | None = None
) -> set[DSeparationJudgement]:
r"""Given some d-separations, reduces to a 'minimal' collection.

For independencies of the form $A \perp B | {C_1, C_2, ...}$, the minimal collection will
Expand Down
16 changes: 8 additions & 8 deletions src/y0/algorithm/counterfactual_transport/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def same_district(event: set[Variable], graph: NxMixedGraph) -> bool:
"""
if len(event) < 1:
return True
visited_districts: set[frozenset] = {
visited_districts: set[frozenset[Variable]] = {
graph.get_district(variable.get_base()) for variable in event
}
return len(visited_districts) == 1
Expand Down Expand Up @@ -1018,7 +1018,7 @@ def validate_inputs_for_transport_district_intervening_on_parents( # noqa:C901

def _no_intervention_variables_in_domain(
*, district: Collection[Variable], interventions: Collection[Variable]
):
) -> bool:
r"""Check that a district in a graph contains no intervention veriables.

Helper function for the transport_district_intervening_on_parents algorithm
Expand All @@ -1033,7 +1033,7 @@ def _no_intervention_variables_in_domain(

def _no_transportability_nodes_in_domain(
*, district: Collection[Variable], domain_graph: NxMixedGraph
):
) -> bool:
r"""Check that a district in a graph contains no transportability nodes.

Helper function for the transport_district_intervening_on_parents algorithm from
Expand Down Expand Up @@ -1701,11 +1701,11 @@ class UnconditionalCFTResult(NamedTuple):
expression: Expression
event: Event | None

def display(self):
def display(self) -> None:
"""Display this result."""
from IPython.display import display

display(event_to_probability(self.event))
display(event_to_probability(self.event)) # type:ignore
display(self.expression)


Expand Down Expand Up @@ -1741,7 +1741,7 @@ def _event_from_counterfactuals_strict(
return rv


def _event_base(variable):
def _event_base(variable: Variable) -> Variable:
if isinstance(variable, CounterfactualVariable):
return CounterfactualVariable(
name=variable.name,
Expand Down Expand Up @@ -2271,11 +2271,11 @@ class ConditionalCFTResult(NamedTuple):
expression: Expression
event: list[tuple[Variable, Intervention]] | None

def display(self):
def display(self) -> None:
"""Display this result."""
from IPython.display import display

display(event_to_probability(self.event))
display(event_to_probability(self.event)) # type:ignore
display(self.expression)


Expand Down
8 changes: 6 additions & 2 deletions src/y0/algorithm/estimation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from contextlib import redirect_stdout
from typing import cast

import pandas as pd

Expand Down Expand Up @@ -109,6 +110,9 @@ def ananke_average_causal_effect(
# care of that explicitly below
causal_effect = CausalEffect(ananke_graph, treatment.name, outcome.name)

return causal_effect.compute_effect(
data, estimator=estimator, n_bootstraps=bootstraps or 0, alpha=alpha or 0.05
return cast(
float,
causal_effect.compute_effect(
data, estimator=estimator, n_bootstraps=bootstraps or 0, alpha=alpha or 0.05
),
)
19 changes: 10 additions & 9 deletions src/y0/algorithm/estimation/estimators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Implementation of ACE estimators."""

from typing import Literal
from typing import Literal, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -74,7 +74,7 @@ def get_primal_ipw_point_estimate(
treatment_value=treatment_value,
outcome=outcome,
)
return np.mean(beta_primal).item()
return cast(float, np.mean(beta_primal).item())


def get_beta_primal(
Expand Down Expand Up @@ -194,23 +194,21 @@ def get_beta_primal(
return beta_primal


def fit_binary_model(data, formula, weights=None) -> GLM:
def fit_binary_model(data: pd.DataFrame, formula: str) -> GLM:
"""Fit a binary general linear model."""
return GLM.from_formula(
formula,
data=data,
family=Binomial(),
freq_weights=weights,
).fit()


def fit_continuous_glm(data, formula, weights=None) -> GLM:
def fit_continuous_glm(data: pd.DataFrame, formula: str) -> GLM:
"""Fit a continuous general linear model."""
return GLM.from_formula(
formula,
data=data,
family=Gaussian(),
freq_weights=weights,
).fit()


Expand All @@ -231,7 +229,10 @@ def get_state_space_map(data: pd.DataFrame) -> dict[Variable, Literal["binary",


def _log_odd_ratio(point_estimate_t1: float, point_estimate_t0: float) -> float:
return np.log(
(point_estimate_t1 / (1 - point_estimate_t1))
/ (point_estimate_t0 / (1 - point_estimate_t0))
return cast(
float,
np.log(
(point_estimate_t1 / (1 - point_estimate_t1))
/ (point_estimate_t0 / (1 - point_estimate_t0))
),
)
13 changes: 8 additions & 5 deletions src/y0/algorithm/estimation/linear_scm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utilities for structural causal models (SCMs)."""

from statistics import fmean
from typing import cast

import pandas as pd
import sympy
Expand All @@ -15,6 +16,8 @@
"evaluate_lscm",
]

EvalRv = dict[sympy.Symbol, sympy.core.numbers.Rational]


def get_single_door(
graph: NxMixedGraph, data: pd.DataFrame
Expand Down Expand Up @@ -42,7 +45,7 @@ def get_single_door(
return rv


def evaluate_admg(graph, data: pd.DataFrame):
def evaluate_admg(graph: NxMixedGraph, data: pd.DataFrame) -> EvalRv:
"""Evaluate an acyclic directed mixed graph (ADMG)."""
params = {
sympy_nested("\\beta", source, target): mean
Expand All @@ -54,19 +57,19 @@ def evaluate_admg(graph, data: pd.DataFrame):

def evaluate_lscm(
linear_scm: dict[Variable, sympy.Expr], params: dict[sympy.Symbol, float]
) -> dict[sympy.Symbol, sympy.core.numbers.Rational]:
) -> EvalRv:
"""Assign values to the parameters and return variable assignments dictionary."""
expressions: dict[sympy.Symbol, sympy.Expr] = {
variable.to_sympy(): expression for variable, expression in linear_scm.items()
}
eqns = [sympy.Eq(lhs.subs(params), rhs.subs(params)) for lhs, rhs in expressions.items()]
return sympy.solve(eqns, list(expressions))
return cast(EvalRv, sympy.solve(eqns, list(expressions)))


def _main():
def _main() -> None:
import warnings

from y0.examples import examples
from y0.examples import examples # type:ignore[attr-defined]

warnings.filterwarnings("ignore")

Expand Down
16 changes: 10 additions & 6 deletions src/y0/algorithm/identify/cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Iterable
from itertools import combinations
from typing import cast
from typing import Any, cast

from y0.dsl import (
CounterfactualVariable,
Expand All @@ -27,7 +27,7 @@
class World(frozenset[Intervention]):
"""A set of interventions corresponding to a "world"."""

def __contains__(self, item) -> bool:
def __contains__(self, item: Any) -> bool:
if not isinstance(item, Intervention):
raise TypeError(
f"can not check if non-intervention is in a world: ({type(item)}) {item}"
Expand All @@ -53,7 +53,9 @@ def has_same_function(node1: Variable, node2: Variable) -> bool:
) == is_not_self_intervened(node2)


def nodes_attain_same_value(graph: NxMixedGraph, event: Event, a: Variable, b: Variable) -> bool: # noqa:C901
def nodes_attain_same_value( # noqa:C901
graph: NxMixedGraph, event: Event, a: Variable, b: Variable
) -> bool:
"""Check if the two nodes attain the same value."""
if a == b:
return True
Expand Down Expand Up @@ -402,8 +404,10 @@ def stitch_counterfactual_and_dopplegangers(
return _both_ways(rv)


def _both_ways(s):
rv = set()
def _both_ways(
s: Iterable[tuple[CounterfactualVariable, CounterfactualVariable]],
) -> set[tuple[CounterfactualVariable, CounterfactualVariable]]:
rv: set[tuple[CounterfactualVariable, CounterfactualVariable]] = set()
for a, b in s:
rv.add((b, a))
return rv
Expand All @@ -414,7 +418,7 @@ def stitch_counterfactual_and_doppleganger_neighbors(
) -> set[tuple[CounterfactualVariable, CounterfactualVariable]]:
"""Stitch together a counterfactual variable with the dopplegangers of its neighbors in each world."""
rv = {
frozenset({u @ world_1, v @ world_2})
(u @ world_1, v @ world_2)
for world_1, world_2 in combinations(worlds, 2)
for u in graph.nodes()
for v in graph.undirected.neighbors(u)
Expand Down
11 changes: 6 additions & 5 deletions src/y0/algorithm/identify/id_star.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
Expression,
Intervention,
One,
P,
Probability,
Product,
Sum,
Expand Down Expand Up @@ -225,10 +224,12 @@ def get_events_of_each_district(graph: NxMixedGraph, event: Event) -> DistrictIn
}


def get_events_of_district(graph, district, event) -> Event:
def get_events_of_district(
graph: NxMixedGraph, district: Collection[Variable], event: Event
) -> Event:
"""Create new events by intervening each node on the Markov pillow of the district.

If the node in in the original event, then the value of the new event is the same as the original event.
If the node in the original event, then the value of the new event is the same as the original event.

:param graph: an NxMixedGraph
:param district: a district of the graph
Expand Down Expand Up @@ -310,6 +311,6 @@ def id_star_line_9(cf_graph: NxMixedGraph) -> Probability:
interventions = get_cf_interventions(cf_graph.nodes())
bases = [node.get_base() for node in cf_graph.nodes()]
if len(interventions) > 0:
return P[interventions](bases)
return Probability.safe(bases, interventions=interventions)
else:
return P(bases)
return Probability.safe(bases)
7 changes: 4 additions & 3 deletions src/y0/algorithm/separation/sigma_separation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Implementation of sigma-separation from [forre2018]_."""

from collections.abc import Iterable, Sequence
from typing import cast

import networkx as nx
from more_itertools import triplewise
Expand Down Expand Up @@ -138,11 +139,11 @@ def _triple_helper(
)


def _has_either_edge(graph: NxMixedGraph, u, v) -> bool:
return graph.directed.has_edge(u, v) or graph.undirected.has_edge(u, v)
def _has_either_edge(graph: NxMixedGraph, u: Variable, v: Variable) -> bool:
return cast(bool, graph.directed.has_edge(u, v)) or cast(bool, graph.undirected.has_edge(u, v))


def _only_directed_edge(graph, u, v) -> bool:
def _only_directed_edge(graph: NxMixedGraph, u: Variable, v: Variable) -> bool:
return graph.directed.has_edge(u, v) and not graph.undirected.has_edge(u, v)


Expand Down
11 changes: 6 additions & 5 deletions src/y0/algorithm/taheri_design.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import textwrap
from collections.abc import Collection, Iterable
from io import StringIO
from pathlib import Path
from typing import NamedTuple

Expand Down Expand Up @@ -246,7 +247,7 @@ def iterate_lvdags(
yv = graph.copy()
for node in inducible_nodes:
yv.nodes[node][tag] = node in induced_latents
yield induced_latents, inducible_nodes - induced_latents, yv # type:ignore
yield induced_latents, inducible_nodes - induced_latents, yv


def draw_results(
Expand Down Expand Up @@ -294,7 +295,7 @@ def draw_results(
fig.savefig(_path, dpi=400)


def print_results(results: list[Result], file=None) -> None:
def print_results(results: list[Result], file: StringIO | None = None) -> None:
"""Print a set of results."""
rows = [
(
Expand All @@ -314,12 +315,12 @@ def print_results(results: list[Result], file=None) -> None:


@click.command()
@verbose_option
def main():
@verbose_option # type:ignore
def main() -> None:
"""Run the algorithm on the IGF graph with the PI3K/Erk example."""
import pystow

from y0.examples import igf_example
from y0.examples import igf_example # type:ignore

results = taheri_design_dag(igf_example.graph.directed, cause="PI3K", effect="Erk", stop=3)
# print_results(results)
Expand Down
4 changes: 3 additions & 1 deletion src/y0/algorithm/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,9 @@ def activate_domain_and_interventions(
raise NotImplementedError(f"Unhandled expression type: {type(expression)}")


def all_transports_d_separated(graph, target_interventions, target_outcomes) -> bool:
def all_transports_d_separated(
graph: NxMixedGraph, target_interventions: set[Variable], target_outcomes: set[Variable]
) -> bool:
"""Check if all target_interventions are d-separated from target_outcomes.

:param graph: The graph with transport nodes in this domain.
Expand Down
2 changes: 1 addition & 1 deletion src/y0/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

@click.group()
@click.version_option()
def main():
def main() -> None:
"""CLI for y0."""


Expand Down
Loading
Loading