Skip to content

Commit

Permalink
More cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Aug 9, 2023
1 parent bb90fc5 commit fbe09ce
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 39 deletions.
71 changes: 36 additions & 35 deletions src/y0/algorithm/transport.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""Implement of surrogate outcomes and transportability.
.. seealso:: https://arxiv.org/abs/1806.07172
"""
"""Implementation of surrogate outcomes and transportability from https://arxiv.org/abs/1806.07172."""

import logging
from copy import deepcopy
Expand All @@ -16,7 +13,6 @@
Population,
Product,
Sum,
Transport,
Variable,
)
from y0.graph import NxMixedGraph
Expand All @@ -29,21 +25,22 @@

TARGET_DOMAIN = Population("pi*")


def find_transport_vertices(
interventions: Union[Set[Variable], Variable],
# FIXME rename this
def find_transport_nodes(
*,
surrogate_interventions: Union[Set[Variable], Variable],
surrogate_outcomes: Union[Set[Variable], Variable],
graph: NxMixedGraph,
) -> Set[Variable]:
"""
Identify which vertices the transport vertices should point to.
:param interventions: The interventions performed in an experiment.
"""Identify which vertices the transport vertices should point to.
:param surrogate_interventions: The interventions performed in an experiment.
:param surrogate_outcomes: The outcomes observed in an experiment.
:param graph: The graph of the target domain.
:returns: A set of variables representing target domain nodes where transportability nodes should be added.
"""
if isinstance(interventions, Variable):
interventions = {interventions}
if isinstance(surrogate_interventions, Variable):
surrogate_interventions = {surrogate_interventions}
if isinstance(surrogate_outcomes, Variable):
surrogate_outcomes = {surrogate_outcomes}

Expand All @@ -56,12 +53,12 @@ def find_transport_vertices(
c_component_surrogate_outcomes = c_component_surrogate_outcomes.union(component)

# subgraph where interventions in edges are removed
interventions_overbar = graph.remove_in_edges(interventions)
interventions_overbar = graph.remove_in_edges(surrogate_interventions)
# Ancestors of surrogate_outcomes in interventions_overbar
Ancestors_surrogate_outcomes = interventions_overbar.ancestors_inclusive(surrogate_outcomes)

# Descendants of interventions in graph
Descendants_interventions = graph.descendants_inclusive(interventions)
Descendants_interventions = graph.descendants_inclusive(surrogate_interventions)

return (Descendants_interventions - surrogate_outcomes).union(
c_component_surrogate_outcomes - Ancestors_surrogate_outcomes
Expand All @@ -88,12 +85,13 @@ def get_transport_nodes(graph: NxMixedGraph) -> Set[Variable]:


def create_transport_diagram(
transport_nodes: Union[Set[Variable], Variable],
*,
nodes: Union[Set[Variable], Variable],
graph: NxMixedGraph,
) -> NxMixedGraph:
"""Create a NxMixedGraph identical to graph but with transport vertices added.
:param transport_nodes: Vertices which have transport nodes pointing to them.
:param nodes: Vertices which have transport nodes pointing to them.
:param graph: The graph of the target domain.
:returns: graph with transport vertices added
"""
Expand All @@ -106,7 +104,7 @@ def create_transport_diagram(
rv.add_directed_edge(u, v)
for u, v in graph.undirected.edges():
rv.add_undirected_edge(u, v)
for node in transport_nodes:
for node in nodes:
transport_node = transport_variable(node)
rv.add_directed_edge(transport_node, node)
return rv
Expand Down Expand Up @@ -136,17 +134,20 @@ def surrogate_to_transport(
:param target_outcomes: A set of target variables for causal effects.
:param target_interventions: A set of interventions for the target domain.
:param graph: The graph of the target domain.
:param intervention_outcome_pairs : A set of Experiments available in each domain.
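:param surrogate_outcomes: A mapping from each domain to the outcomes observed in that domain's experiment.
:param surrogate_interventions: A mapping from each domain to the interventions performed in that domain's experiment.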
:returns: An octuple representing the query transformation of a surrogate outcome query.
"""
if set(surrogate_outcomes) != set(surrogate_interventions):
raise ValueError("Inconsistent surrogate outcome and intervention domains")

transportability_diagrams = {
domain: create_transport_diagram(
graph,
find_transport_vertices(
surrogate_interventions[domain], surrogate_outcomes[domain], graph
graph=graph,
nodes=find_transport_nodes(
surrogate_interventions=surrogate_interventions[domain],
surrogate_outcomes=surrogate_outcomes[domain],
graph=graph
),
)
for domain in surrogate_outcomes
Expand All @@ -172,8 +173,8 @@ def trso_line1(
"""Return the probability in the case where no interventions are present.
:param target_outcomes: A set of nodes that comprise our target outcomes.
:param expression : The distribution in the current domain.
:param transportability_diagram : The graph with transport nodes in this domain.
:param expression: The distribution in the current domain.
:param transportability_diagram: The graph with transport nodes in this domain.
:returns: Sum over the probabilities of nodes other than target outcomes.
"""
return Sum.safe(expression, transportability_diagram.nodes() - target_outcomes)
Expand All @@ -188,9 +189,9 @@ def trso_line2(
"""Restrict the interventions and diagram to only include ancestors of target variables.
:param query: A transport query
:param probability : The distribution in the current domain.
:param domain : current domain
:param outcomes_ancestors : the ancestors of target variables in transportability_diagram
:param probability: The distribution in the current domain.
:param domain: current domain
:param outcomes_ancestors: the ancestors of target variables in transportability_diagram
:returns: Dictionary of modified trso inputs.
"""
new_query = deepcopy(query)
Expand All @@ -208,7 +209,7 @@ def trso_line3(query: TransportQuery, additional_interventions: Set[Variable]) -
"""
:param query: A transport query
:param additional_interventions : interventions to be added to target_interventions
:param additional_interventions: interventions to be added to target_interventions
:returns: dictionary of modified trso inputs.
"""
new_query = deepcopy(query)
Expand All @@ -224,8 +225,8 @@ def trso_line4(
"""Find the trso inputs for each C-component.
:param query: A transport query
:param domain : current domain
:param components : Set of c_components of transportability_diagram without target_interventions
:param domain: current domain
:param components: Set of c_components of transportability_diagram without target_interventions
:returns: Dictionary with components as keys and dictionary of modified trso inputs as values
"""
transportability_diagram = query.transportability_diagrams[domain]
Expand All @@ -245,7 +246,7 @@ def trso_line6(
"""Find the active interventions in each diagram, run trso with active interventions.
:param query: A transport query
:param domain : current domain
:param domain: current domain
:returns:
"""
transportability_diagram = query.transportability_diagrams[domain]
Expand Down Expand Up @@ -339,10 +340,7 @@ def trso(
- target_interventions_overbar.ancestors_inclusive(query.target_outcomes)
)
if additional_interventions:
new_query = trso_line3(
query,
additional_interventions,
)
new_query = trso_line3(query, additional_interventions)
return trso(
query=new_query,
active_interventions=active_interventions,
Expand Down Expand Up @@ -412,12 +410,15 @@ def trso(
active_interventions,
domain,
)

# line10
# FIXME why aren't results collated over all districts? then pick which one to return?
for district in districts:
if not districts_without_interventions.issubset(district):
continue
# district is C'; districts should be D[C'], but we chose to return a set of nodes instead of a subgraph
if len(active_interventions) == 0:
# FIXME is this even possible? doesn't line 6 check this and return something else?
new_available_interventions = dict()
elif any(
is_transport_node(node) for node in transportability_diagram.get_markov_pillow(district)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_algorithm/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from y0.algorithm.transport import (
TARGET_DOMAIN,
TransportQuery,
find_transport_vertices,
find_transport_nodes,
surrogate_to_transport,
transport,
trso,
Expand Down Expand Up @@ -92,15 +92,15 @@ class TestTransport(cases.GraphTestCase):

def test_find_transport_vertices(self):
expected = {X1, Y2}
actual = find_transport_vertices(X1, Y1, tikka_trso_figure_8)
actual = find_transport_nodes(X1, Y1, tikka_trso_figure_8)
self.assertEqual(actual, expected)
expected = {X2}
actual = find_transport_vertices({X2}, {Y2}, tikka_trso_figure_8)
actual = find_transport_nodes({X2}, {Y2}, tikka_trso_figure_8)
self.assertEqual(actual, expected)

# Test for multiple vertices in interventions and surrogate_outcomes
expected = {X1, X2, Y1}
actual = find_transport_vertices({X2, X1}, {Y2, W}, tikka_trso_figure_8)
actual = find_transport_nodes({X2, X1}, {Y2, W}, tikka_trso_figure_8)
self.assertEqual(actual, expected)

def test_surrogate_to_transport(self):
Expand Down

0 comments on commit fbe09ce

Please sign in to comment.