Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Aug 9, 2023
1 parent fbe09ce commit a22a1ce
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 21 deletions.
8 changes: 4 additions & 4 deletions src/y0/algorithm/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

TARGET_DOMAIN = Population("pi*")


# FIXME rename this
def find_transport_nodes(
*,
Expand All @@ -45,12 +46,11 @@ def find_transport_nodes(
surrogate_outcomes = {surrogate_outcomes}

# Find the c_component with surrogate_outcomes
c_components = graph.get_c_components()
c_component_surrogate_outcomes = set()
for index, component in enumerate(c_components):
for component in graph.get_c_components():
# Check if surrogate_outcomes is present in the current set
if surrogate_outcomes.intersection(component):
c_component_surrogate_outcomes = c_component_surrogate_outcomes.union(component)
c_component_surrogate_outcomes.update(component)

# subgraph where interventions in edges are removed
interventions_overbar = graph.remove_in_edges(surrogate_interventions)
Expand Down Expand Up @@ -147,7 +147,7 @@ def surrogate_to_transport(
nodes=find_transport_nodes(
surrogate_interventions=surrogate_interventions[domain],
surrogate_outcomes=surrogate_outcomes[domain],
graph=graph
graph=graph,
),
)
for domain in surrogate_outcomes
Expand Down
18 changes: 1 addition & 17 deletions src/y0/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,7 @@
from networkx.classes.reportviews import NodeView
from networkx.utils import open_file

from .dsl import (
CounterfactualVariable,
Intervention,
Transport,
Variable,
vmap_adj,
vmap_pairs,
)
from .dsl import CounterfactualVariable, Intervention, Variable, vmap_adj, vmap_pairs

__all__ = [
"NxMixedGraph",
Expand Down Expand Up @@ -163,15 +156,6 @@ def to_latent_variable_dag(
tag=tag,
)

def get_transport_nodes(self) -> Set[Transport]:
"""
Returns
-------
Transportability nodes.
"""
return {t for t in self.nodes() if isinstance(t, Transport)}

@classmethod
def from_latent_variable_dag(cls, graph: nx.DiGraph, tag: Optional[str] = None) -> NxMixedGraph:
"""Load a labeled DAG."""
Expand Down

0 comments on commit a22a1ce

Please sign in to comment.