From fbe09ce7380d97c75145250a8cef4bdc5fcb5e86 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Wed, 9 Aug 2023 23:26:35 +0200 Subject: [PATCH] More cleanup --- src/y0/algorithm/transport.py | 71 +++++++++++++------------- tests/test_algorithm/test_transport.py | 8 +-- 2 files changed, 40 insertions(+), 39 deletions(-) diff --git a/src/y0/algorithm/transport.py b/src/y0/algorithm/transport.py index 8d485655b..76ed928ab 100644 --- a/src/y0/algorithm/transport.py +++ b/src/y0/algorithm/transport.py @@ -1,7 +1,4 @@ -"""Implement of surrogate outcomes and transportability. - -.. seealso:: https://arxiv.org/abs/1806.07172 -""" +"""Implement of surrogate outcomes and transportability from https://arxiv.org/abs/1806.07172.""" import logging from copy import deepcopy @@ -16,7 +13,6 @@ Population, Product, Sum, - Transport, Variable, ) from y0.graph import NxMixedGraph @@ -29,21 +25,22 @@ TARGET_DOMAIN = Population("pi*") - -def find_transport_vertices( - interventions: Union[Set[Variable], Variable], +# FIXME rename this +def find_transport_nodes( + *, + surrogate_interventions: Union[Set[Variable], Variable], surrogate_outcomes: Union[Set[Variable], Variable], graph: NxMixedGraph, ) -> Set[Variable]: - """ - Identify which vertices the transport vertices should point to. - :param interventions: The interventions performed in an experiment. + """Identify which vertices the transport vertices should point to. + + :param surrogate_interventions: The interventions performed in an experiment. :param surrogate_outcomes: The outcomes observed in an experiment. :param graph: The graph of the target domain. :returns: A set of variables representing target domain nodes where transportability nodes should be added. """ - if isinstance(interventions, Variable): - interventions = {interventions} + if isinstance(surrogate_interventions, Variable): + surrogate_interventions = {surrogate_interventions} if isinstance(surrogate_outcomes, Variable): surrogate_outcomes = {surrogate_outcomes} @@ -56,12 +53,12 @@ def find_transport_vertices( c_component_surrogate_outcomes = c_component_surrogate_outcomes.union(component) # subgraph where interventions in edges are removed - interventions_overbar = graph.remove_in_edges(interventions) + interventions_overbar = graph.remove_in_edges(surrogate_interventions) # Ancestors of surrogate_outcomes in interventions_overbar Ancestors_surrogate_outcomes = interventions_overbar.ancestors_inclusive(surrogate_outcomes) # Descendants of interventions in graph - Descendants_interventions = graph.descendants_inclusive(interventions) + Descendants_interventions = graph.descendants_inclusive(surrogate_interventions) return (Descendants_interventions - surrogate_outcomes).union( c_component_surrogate_outcomes - Ancestors_surrogate_outcomes @@ -88,12 +85,13 @@ def get_transport_nodes(graph: NxMixedGraph) -> Set[Variable]: def create_transport_diagram( - transport_nodes: Union[Set[Variable], Variable], + *, + nodes: Union[Set[Variable], Variable], graph: NxMixedGraph, ) -> NxMixedGraph: """Create a NxMixedGraph identical to graph but with transport vertices added. - :param transport_nodes: Vertices which have transport nodes pointing to them. + :param nodes: Vertices which have transport nodes pointing to them. :param graph: The graph of the target domain. :returns: graph with transport vertices added """ @@ -106,7 +104,7 @@ def create_transport_diagram( rv.add_directed_edge(u, v) for u, v in graph.undirected.edges(): rv.add_undirected_edge(u, v) - for node in transport_nodes: + for node in nodes: transport_node = transport_variable(node) rv.add_directed_edge(transport_node, node) return rv @@ -136,7 +134,8 @@ def surrogate_to_transport( :param target_outcomes: A set of target variables for causal effects. :param target_interventions: A set of interventions for the target domain. :param graph: The graph of the target domain. - :param intervention_outcome_pairs : A set of Experiments available in each domain. + :param surrogate_outcomes: + :param surrogate_interventions: :returns: An octuple representing the query transformation of a surrogate outcome query. """ if set(surrogate_outcomes) != set(surrogate_interventions): @@ -144,9 +143,11 @@ def surrogate_to_transport( transportability_diagrams = { domain: create_transport_diagram( - graph, - find_transport_vertices( - surrogate_interventions[domain], surrogate_outcomes[domain], graph + graph=graph, + nodes=find_transport_nodes( + surrogate_interventions=surrogate_interventions[domain], + surrogate_outcomes=surrogate_outcomes[domain], + graph=graph ), ) for domain in surrogate_outcomes @@ -172,8 +173,8 @@ def trso_line1( """Return the probability in the case where no interventions are present. :param target_outcomes: A set of nodes that comprise our target outcomes. - :param expression : The distribution in the current domain. - :param transportability_diagram : The graph with transport nodes in this domain. + :param expression: The distribution in the current domain. + :param transportability_diagram: The graph with transport nodes in this domain. :returns: Sum over the probabilities of nodes other than target outcomes. """ return Sum.safe(expression, transportability_diagram.nodes() - target_outcomes) @@ -188,9 +189,9 @@ def trso_line2( """Restrict the interventions and diagram to only include ancestors of target variables. :param query: A transport query - :param probability : The distribution in the current domain. - :param domain : current domain - :param outcomes_ancestors : the ancestors of target variables in transportability_diagram + :param probability: The distribution in the current domain. + :param domain: current domain + :param outcomes_ancestors: the ancestors of target variables in transportability_diagram :returns: Dictionary of modified trso inputs. """ new_query = deepcopy(query) @@ -208,7 +209,7 @@ def trso_line3(query: TransportQuery, additional_interventions: Set[Variable]) - """ :param query: A transport query - :param additional_interventions : interventions to be added to target_interventions + :param additional_interventions: interventions to be added to target_interventions :returns: dictionary of modified trso inputs. """ new_query = deepcopy(query) @@ -224,8 +225,8 @@ def trso_line4( """Find the trso inputs for each C-component. :param query: A transport query - :param domain : current domain - :param components : Set of c_components of transportability_diagram without target_interventions + :param domain: current domain + :param components: Set of c_components of transportability_diagram without target_interventions :returns: Dictionary with components as keys and dictionary of modified trso inputs as values """ transportability_diagram = query.transportability_diagrams[domain] @@ -245,7 +246,7 @@ def trso_line6( """Find the active interventions in each diagram, run trso with active interventions. :param query: A transport query - :param domain : current domain + :param domain: current domain :returns: """ transportability_diagram = query.transportability_diagrams[domain] @@ -339,10 +340,7 @@ def trso( - target_interventions_overbar.ancestors_inclusive(query.target_outcomes) ) if additional_interventions: - new_query = trso_line3( - query, - additional_interventions, - ) + new_query = trso_line3(query, additional_interventions) return trso( query=new_query, active_interventions=active_interventions, @@ -412,12 +410,15 @@ def trso( active_interventions, domain, ) + # line10 + # FIXME why aren't results collated over all districts? then pick which one to return? for district in districts: if not districts_without_interventions.issubset(district): continue # district is C' districts should be D[C'], but we chose to return set of nodes instead of subgraph if len(active_interventions) == 0: + # FIXME is this even possible? doesn't line 6 check this and return something else? new_available_interventions = dict() elif any( is_transport_node(node) for node in transportability_diagram.get_markov_pillow(district) diff --git a/tests/test_algorithm/test_transport.py b/tests/test_algorithm/test_transport.py index 5610ddac9..78ea5b775 100644 --- a/tests/test_algorithm/test_transport.py +++ b/tests/test_algorithm/test_transport.py @@ -6,7 +6,7 @@ from y0.algorithm.transport import ( TARGET_DOMAIN, TransportQuery, - find_transport_vertices, + find_transport_nodes, surrogate_to_transport, transport, trso, @@ -92,15 +92,15 @@ class TestTransport(cases.GraphTestCase): def test_find_transport_vertices(self): expected = {X1, Y2} - actual = find_transport_vertices(X1, Y1, tikka_trso_figure_8) + actual = find_transport_nodes(X1, Y1, tikka_trso_figure_8) self.assertEqual(actual, expected) expected = {X2} - actual = find_transport_vertices({X2}, {Y2}, tikka_trso_figure_8) + actual = find_transport_nodes({X2}, {Y2}, tikka_trso_figure_8) self.assertEqual(actual, expected) # Test for multiple vertices in interventions and surrogate_outcomes expected = {X1, X2, Y1} - actual = find_transport_vertices({X2, X1}, {Y2, W}, tikka_trso_figure_8) + actual = find_transport_nodes({X2, X1}, {Y2, W}, tikka_trso_figure_8) self.assertEqual(actual, expected) def test_surrogate_to_transport(self):