From a22a1ce055926d3ccb3be48e7e0d80d355d00cfd Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Wed, 9 Aug 2023 23:28:38 +0200 Subject: [PATCH] Cleanup --- src/y0/algorithm/transport.py | 8 ++++---- src/y0/graph.py | 18 +----------------- 2 files changed, 5 insertions(+), 21 deletions(-) diff --git a/src/y0/algorithm/transport.py b/src/y0/algorithm/transport.py index 76ed928ab..8c1c950b2 100644 --- a/src/y0/algorithm/transport.py +++ b/src/y0/algorithm/transport.py @@ -25,6 +25,7 @@ TARGET_DOMAIN = Population("pi*") + # FIXME rename this def find_transport_nodes( *, @@ -45,12 +46,11 @@ def find_transport_nodes( surrogate_outcomes = {surrogate_outcomes} # Find the c_component with surrogate_outcomes - c_components = graph.get_c_components() c_component_surrogate_outcomes = set() - for index, component in enumerate(c_components): + for component in graph.get_c_components(): # Check if surrogate_outcomes is present in the current set if surrogate_outcomes.intersection(component): - c_component_surrogate_outcomes = c_component_surrogate_outcomes.union(component) + c_component_surrogate_outcomes.update(component) # subgraph where interventions in edges are removed interventions_overbar = graph.remove_in_edges(surrogate_interventions) @@ -147,7 +147,7 @@ def surrogate_to_transport( nodes=find_transport_nodes( surrogate_interventions=surrogate_interventions[domain], surrogate_outcomes=surrogate_outcomes[domain], - graph=graph + graph=graph, ), ) for domain in surrogate_outcomes diff --git a/src/y0/graph.py b/src/y0/graph.py index 3d7728cb4..2fd7cafae 100644 --- a/src/y0/graph.py +++ b/src/y0/graph.py @@ -13,14 +13,7 @@ from networkx.classes.reportviews import NodeView from networkx.utils import open_file -from .dsl import ( - CounterfactualVariable, - Intervention, - Transport, - Variable, - vmap_adj, - vmap_pairs, -) +from .dsl import CounterfactualVariable, Intervention, Variable, vmap_adj, vmap_pairs __all__ = [ "NxMixedGraph", @@ -163,15 +156,6 @@ def to_latent_variable_dag( tag=tag, ) - def get_transport_nodes(self) -> Set[Transport]: - """ - Returns - ------- - Transportability nodes. - - """ - return {t for t in self.nodes() if isinstance(t, Transport)} - @classmethod def from_latent_variable_dag(cls, graph: nx.DiGraph, tag: Optional[str] = None) -> NxMixedGraph: """Load a labeled DAG."""