From bb90fc5ea995217fb33aa2693f939463130bf8b0 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Wed, 9 Aug 2023 23:19:32 +0200 Subject: [PATCH] Cleanup --- src/y0/algorithm/transport.py | 318 ++++++++++++------------- tests/test_algorithm/test_transport.py | 20 +- 2 files changed, 162 insertions(+), 176 deletions(-) diff --git a/src/y0/algorithm/transport.py b/src/y0/algorithm/transport.py index fb28f24db..8d485655b 100644 --- a/src/y0/algorithm/transport.py +++ b/src/y0/algorithm/transport.py @@ -1,17 +1,19 @@ """Implement of surrogate outcomes and transportability. -..seealso:: https://arxiv.org/abs/1806.07172 +.. seealso:: https://arxiv.org/abs/1806.07172 """ +import logging from copy import deepcopy from dataclasses import dataclass -from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Tuple, Union +from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Tuple, Union, cast from y0.algorithm.conditional_independencies import are_d_separated from y0.dsl import ( + CounterfactualVariable, Expression, + Intervention, Population, - PopulationProbability, Product, Sum, Transport, @@ -22,6 +24,9 @@ __all__ = [ "transport", ] + +logger = logging.getLogger(__name__) + TARGET_DOMAIN = Population("pi*") @@ -63,37 +68,51 @@ def find_transport_vertices( ) +TRANSPORT_PREFIX = "T_" + + +def transport_variable(variable: Variable) -> Variable: + if isinstance(variable, (CounterfactualVariable, Intervention)): + raise TypeError + return Variable(TRANSPORT_PREFIX + variable.name) + + +def is_transport_node(node: Variable) -> bool: + return not isinstance(node, (CounterfactualVariable, Intervention)) and node.name.startswith( + TRANSPORT_PREFIX + ) + + +def get_transport_nodes(graph: NxMixedGraph) -> Set[Variable]: + return {node for node in graph if is_transport_node(node)} + + def create_transport_diagram( - transport_vertices: Union[Set[Variable], Variable], + transport_nodes: Union[Set[Variable], Variable], graph: NxMixedGraph, ) -> NxMixedGraph: - """ - Create a NxMixedGraph identical to graph but with transport vertices added. - :param transport_vertices: Vertices which have transport nodes pointing to them. + """Create a NxMixedGraph identical to graph but with transport vertices added. + + :param transport_nodes: Vertices which have transport nodes pointing to them. :param graph: The graph of the target domain. :returns: graph with transport vertices added """ - # TODO we discussed the possibility of using a dictionary with needed nodes - # instead of creating a graph for each diagram. - transportability_diagram = NxMixedGraph() + # instead of creating a graph for each diagram. + rv = NxMixedGraph() for node in graph.nodes(): - transportability_diagram.add_node(node) + rv.add_node(node) for u, v in graph.directed.edges(): - transportability_diagram.add_directed_edge(u, v) + rv.add_directed_edge(u, v) for u, v in graph.undirected.edges(): - transportability_diagram.add_undirected_edge(u, v) - - for vertex in transport_vertices: - # TODO Make this a true Transport instead of a Variable - # T_vertex = Transport(vertex) - T_vertex = Variable("T" + vertex.to_text()) - transportability_diagram.add_node(T_vertex) - transportability_diagram.add_directed_edge(T_vertex, vertex) - - return transportability_diagram + rv.add_undirected_edge(u, v) + for node in transport_nodes: + transport_node = transport_variable(node) + rv.add_directed_edge(transport_node, node) + return rv +@dataclass class TransportQuery: target_interventions: Set[Variable] target_outcomes: Set[Variable] @@ -122,6 +141,7 @@ def surrogate_to_transport( """ if set(surrogate_outcomes) != set(surrogate_interventions): raise ValueError("Inconsistent surrogate outcome and intervention domains") + transportability_diagrams = { domain: create_transport_diagram( graph, @@ -146,66 +166,50 @@ def surrogate_to_transport( def trso_line1( target_outcomes: Set[Variable], - probability: PopulationProbability, + expression: Expression, transportability_diagram: NxMixedGraph, -) -> Sum: - """ - Return the probability in the case where no interventions are present. +) -> Expression: + """Return the probability in the case where no interventions are present. + :param target_outcomes: A set of nodes that comprise our target outcomes. - :param probability : The distribution in the current domain. + :param expression : The distribution in the current domain. :param transportability_diagram : The graph with transport nodes in this domain. :returns: Sum over the probabilities of nodes other than target outcomes. - """ - return Sum.safe(probability, transportability_diagram.nodes() - target_outcomes) + return Sum.safe(expression, transportability_diagram.nodes() - target_outcomes) def trso_line2( query: TransportQuery, probability: Expression, domain: Variable, - outcomes_anc: Set[Variable], + outcomes_ancestors: Set[Variable], ) -> Tuple[TransportQuery, Expression]: - """ - Restrict the interventions and diagram to only include ancestors of target variables. - :param target_outcomes: A set of target variables for causal effects. - :param target_interventions: A set of interventions for the target domain. + """Restrict the interventions and diagram to only include ancestors of target variables. + + :param query: A transport query :param probability : The distribution in the current domain. - :param active_interventions : which interventions are currently active :param domain : current domain - :param transportability_diagrams : Dictionary of all available transportability diagrams - :param available_interventions : A dictionary of sets of Experiments available in each domain. - :param outcomes_anc : the ancestors of target variables in transportability_diagram + :param outcomes_ancestors : the ancestors of target variables in transportability_diagram :returns: Dictionary of modified trso inputs. - """ new_query = deepcopy(query) - new_query.target_interventions.intersection_update(outcomes_anc) + new_query.target_interventions.intersection_update(outcomes_ancestors) new_query.transportability_diagrams[domain] = new_query.transportability_diagrams[ domain - ].subgraph(outcomes_anc) + ].subgraph(outcomes_ancestors) new_expression = Sum.safe( - probability, new_query.transportability_diagrams[domain].nodes() - outcomes_anc + probability, new_query.transportability_diagrams[domain].nodes() - outcomes_ancestors ) - return (new_query, new_expression) + return new_query, new_expression -def trso_line3( - query: TransportQuery, - additional_interventions: Set[Variable], -) -> TransportQuery: +def trso_line3(query: TransportQuery, additional_interventions: Set[Variable]) -> TransportQuery: """ - :param target_outcomes: A set of target variables for causal effects. - :param target_interventions: A set of interventions for the target domain. - :param probability : The distribution in the current domain. - :param active_interventions : which interventions are currently active - :param domain : current domain - :param transportability_diagrams : Dictionary of all available transportability diagrams - :param available_interventions : A dictionary of sets of Experiments available in each domain. + :param query: A transport query :param additional_interventions : interventions to be added to target_interventions :returns: dictionary of modified trso inputs. - """ new_query = deepcopy(query) new_query.target_interventions.update(additional_interventions) @@ -217,18 +221,12 @@ def trso_line4( domain: Variable, components: Set[FrozenSet[Variable]], ) -> Dict[FrozenSet[Variable], TransportQuery]: - """Find the trso inputs for each c-component. + """Find the trso inputs for each C-component. - :param target_outcomes: A set of target variables for causal effects. - :param target_interventions: A set of interventions for the target domain. - :param probability : The distribution in the current domain. - :param active_interventions : which interventions are currently active + :param query: A transport query :param domain : current domain - :param transportability_diagrams : Dictionary of all available transportability diagrams - :param available_interventions : A dictionary of sets of Experiments available in each domain. :param components : Set of c_components of transportability_diagram without target_interventions :returns: Dictionary with components as keys and dictionary of modified trso inputs as values - """ transportability_diagram = query.transportability_diagrams[domain] rv = {} @@ -242,47 +240,40 @@ def trso_line4( def trso_line6( query: TransportQuery, - active_interventions: Set[Variable], domain: Variable, ) -> Dict[Population, Tuple[TransportQuery, Set[Variable]]]: """Find the active interventions in each diagram, run trso with active interventions. - :param target_outcomes: A set of target variables for causal effects. - :param target_interventions: A set of interventions for the target domain. - :param probability : The distribution in the current domain. - :param active_interventions : which interventions are currently active + :param query: A transport query :param domain : current domain - :param transportability_diagrams : Dictionary of all available transportability diagrams - :param available_interventions : A dictionary of sets of Experiments available in each domain. - :return List of Dictionary of modified trso inputs - + :returns: """ transportability_diagram = query.transportability_diagrams[domain] expressions = {} for loop_domain, loop_transportability_diagram in query.transportability_diagrams.items(): if not query.available_interventions[loop_domain].intersection(query.target_interventions): continue - transportability_nodes = loop_transportability_diagram.get_transport_nodes() + + transportability_nodes = get_transport_nodes(loop_transportability_diagram) diagram_without_interventions = loop_transportability_diagram.remove_in_edges( query.target_interventions ) - if not all( are_d_separated( diagram_without_interventions, - node, + transportability_node, outcome, conditions=query.target_interventions, ) - for node in transportability_nodes + for transportability_node in transportability_nodes for outcome in query.target_outcomes ): continue + new_query = deepcopy(query) new_query.target_interventions = ( query.target_interventions - query.available_interventions[loop_domain] ) - new_query.domain = loop_domain new_query.transportability_diagrams[domain] = transportability_diagram.subgraph( transportability_diagram.nodes() @@ -296,45 +287,54 @@ def trso_line6( return expressions +def trso_line9(query, expression, active_interventions, domain) -> Expression: + pass + + +def trso_line10( + query, expression, active_interventions, domain, district, new_available_interventions +) -> Expression: + pass + + # TODO Tikka paper says that topological ordering is available globaly -# TODO some functions need transportability_diagrams while others need transportability_diagram def trso( query: TransportQuery, active_interventions: Set[Variable], domain: Population, - probability: Expression, -) -> Expression: + expression: Expression, +) -> Optional[Expression]: # Check that domain is in query.domains # check that query.surrogate_interventions keys are equals to domains # check that query.transportability_diagrams keys are equal to domains transportability_diagram = query.transportability_diagrams[domain] # line 1 if not query.target_interventions: - return trso_line1(query.target_outcomes, probability, transportability_diagram) + return trso_line1(query.target_outcomes, expression, transportability_diagram) # line 2 - outcomes_anc = transportability_diagram.ancestors_inclusive(query.target_outcomes) - if transportability_diagram.nodes() - outcomes_anc: - new_query, new_probability = trso_line2( + outcome_ancestors = transportability_diagram.ancestors_inclusive(query.target_outcomes) + if transportability_diagram.nodes() - outcome_ancestors: + new_query, new_expression = trso_line2( query, - probability, + expression, domain, - outcomes_anc, + outcome_ancestors, ) return trso( query=new_query, active_interventions=active_interventions, domain=domain, - probability=new_probability, + expression=new_expression, ) # line 3 - + # TODO give meaningful name to this variable target_interventions_overbar = transportability_diagram.remove_in_edges( query.target_interventions ) additional_interventions = ( - transportability_diagram.nodes() + cast(set[Variable], transportability_diagram.nodes()) - query.target_interventions - target_interventions_overbar.ancestors_inclusive(query.target_outcomes) ) @@ -347,7 +347,7 @@ def trso( query=new_query, active_interventions=active_interventions, domain=domain, - probability=probability, + expression=expression, ) # line 4 @@ -363,87 +363,77 @@ def trso( return Sum.safe( Product.safe( - [ - trso( - query=trso_line4input, - active_interventions=active_interventions, - domain=domain, - probability=probability, - ) - for trso_line4input in trso_line4inputs.values() - ], + trso( + query=trso_line4input, + active_interventions=active_interventions, + domain=domain, + expression=expression, + ) + for trso_line4input in trso_line4inputs.values() ), transportability_diagram.nodes() - query.target_interventions.union(query.target_outcomes), ) - # line 5 - else: - # line 6 - if not active_interventions: - trso_inputs = trso_line6( - query, - active_interventions, - domain, + # line 6 + if not active_interventions: + subqueries = trso_line6(query, domain) + expressions = {} + for loop_domain, (loop_query, loop_active_interventions) in subqueries.items(): + loop_expression = trso( + query=loop_query, + active_interventions=loop_active_interventions, + domain=loop_domain, + expression=expression, ) - expressions = {} - for loop_domain, (loop_query, loop_active_interventions) in trso_inputs.items(): - expressionk = trso( - query=loop_query, - active_interventions=loop_active_interventions, - domain=loop_domain, - probability=probability, - ) - # line7 - if expressionk: - expressions[loop_domain] = expressionk - # return expressionk - if len(expressions) == 1: - return list(expressions.values())[0] - elif len(expressions) > 1: - # What if more than 1 expression doesn't fail? - # Is it non-deterministic or can we prove it will be length 1? - return list(expressions.values())[0] - # line8 - districts = transportability_diagram.get_c_components() - # line 11, return fail - if len(districts) <= 1: + # line7 + if loop_expression is not None: + expressions[loop_domain] = loop_expression + if len(expressions) == 1: + return list(expressions.values())[0] + elif len(expressions) > 1: + logger.warning("more than one expression were non-none") + # What if more than 1 expression doesn't fail? + # Is it non-deterministic or can we prove it will be length 1? + return list(expressions.values())[0] + + # line8 + districts = transportability_diagram.get_c_components() + # line 11, return fail + if len(districts) <= 1: + return None + # line 8, i.e. len(districts)>1 + + # line9 + if districts_without_interventions in districts: + return trso_line9( + query, + expression, + active_interventions, + domain, + ) + # line10 + for district in districts: + if not districts_without_interventions.issubset(district): + continue + # district is C' districts should be D[C'], but we chose to return set of nodes instead of subgraph + if len(active_interventions) == 0: + new_available_interventions = dict() + elif any( + is_transport_node(node) for node in transportability_diagram.get_markov_pillow(district) + ): return None - # line 8, i.e. len(districts)>1 - - # line9 - if districts_without_interventions in districts: - return trso_line9( - query, - probability, - active_interventions, - domain, - ) - # line10 - for district in districts: - if not districts_without_interventions.issubset(district): - continue - # district is C' districts should be D[C'], but we chose to return set of nodes instead of subgraph - if len(active_interventions) == 0: - new_available_interventions = dict() - elif any( - isinstance( - t, Transport - ) # TODO this doesn't match how we create transportability_diagram - for t in transportability_diagram.get_markov_pillow(district) - ): - return None - else: - new_available_interventions = query.available_interventions - - return trso_line10( - query, - probability, - active_interventions, - domain, - district, - new_available_interventions, - ) + else: + new_available_interventions = query.available_interventions + + return trso_line10( + query, + expression, + active_interventions, + domain, + district, + new_available_interventions, + ) def transport( diff --git a/tests/test_algorithm/test_transport.py b/tests/test_algorithm/test_transport.py index 5470a87b4..5610ddac9 100644 --- a/tests/test_algorithm/test_transport.py +++ b/tests/test_algorithm/test_transport.py @@ -117,8 +117,7 @@ def test_surrogate_to_transport(self): experiment_outcomes=experiment_outcomes, experiment_interventions=experiment_interventions, ) - target_domain = Variable("pi*") - domains = [Variable("pi1"), Variable("pi2")] + domains = [Pi1, Pi2] experiment_interventions, experiment_surrogate_outcomes = zip(*available_experiments) experiments_in_target_domain = set() @@ -133,7 +132,7 @@ def test_surrogate_to_transport(self): transportability_diagrams, tikka_trso_figure_8, domains, - target_domain, + TARGET_DOMAIN, experiment_interventions, experiments_in_target_domain, ) @@ -145,9 +144,8 @@ def test_trso_line1(self): outcomes = {Y1, Y2} interventions = {} active_interventions = {} - domain = Variable("pi*") domain_graph = tikka_trso_figure_8 - prob = PP[domain](*list(domain_graph.nodes())) + prob = PP[TARGET_DOMAIN](*list(domain_graph.nodes())) available_interventions = [{X2}, {X1}] expected = Sum.safe(prob, {W, X1, X2, Z}) @@ -228,8 +226,7 @@ def test_trso_line3(self): target_outcomes = {Y} active_interventions = {} available_interventions = {X} - domain = Variable("pi*") - prob = PP[domain](*list(transportability_diagram.nodes())) + prob = PP[TARGET_DOMAIN](*list(transportability_diagram.nodes())) target_interventions_overbar = transportability_diagram.remove_in_edges( target_interventions ) @@ -244,7 +241,7 @@ def test_trso_line3(self): target_interventions=target_interventions.union(additional_interventions), probability=prob, active_interventions=active_interventions, - domain=domain, + domain=TARGET_DOMAIN, transportability_diagram=transportability_diagram, available_interventions=available_interventions, ) @@ -265,9 +262,8 @@ def test_trso_line3(self): def test_trso_line4(self): target_outcomes = {Y1, Y2} target_interventions = {X1, X2} - domain = Variable("pi*") transportability_diagram = tikka_trso_figure_8 - prob = PP[domain](*list(transportability_diagram.nodes())) + prob = PP[TARGET_DOMAIN](*list(transportability_diagram.nodes())) active_interventions = {} available_interventions = {X1, X2} districts_without_interventions = transportability_diagram.subgraph( @@ -280,7 +276,7 @@ def test_trso_line4(self): target_interventions={X1, X2, Z, W, Y1}, probability=prob, active_interventions=active_interventions, - domain=domain, + domain=TARGET_DOMAIN, transportability_diagram=transportability_diagram, available_interventions=available_interventions, ), @@ -309,7 +305,7 @@ def test_trso_line4(self): target_interventions, prob, active_interventions, - domain, + TARGET_DOMAIN, transportability_diagram, available_interventions, districts_without_interventions,