Skip to content

Commit

Permalink
imperative docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
adamrupe committed Sep 6, 2024
1 parent a594c17 commit 83c5ff6
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/y0/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"collapse_HCM"]

def HCM_from_lists(*, obs_subunits=[], unobs_subunits=[], obs_units=[], unobs_units=[], edges=[]):
"""Creates a hierarchical causal model from the given node and edge lists.
"""Create a hierarchical causal model from the given node and edge lists.
:param obs_subunits: a list of names for the observed subunit variables
:param unobs_subunits: a list of names for the unobserved subunit variables
Expand All @@ -41,7 +41,7 @@ def HCM_from_lists(*, obs_subunits=[], unobs_subunits=[], obs_units=[], unobs_un
return HCM

def get_observed(HCM):
"""Returns the set of observed variables (both unit and subunit) in the HCM."""
"""Return the set of observed variables (both unit and subunit) in the HCM."""
observed_nodes = set()
for node_name in HCM.nodes():
node = HCM.get_node(node_name)
Expand All @@ -50,33 +50,33 @@ def get_observed(HCM):
return observed_nodes

def get_unobserved(HCM):
"""Returns the set of unobserved variables (both unit and subunit) in the HCM."""
"""Return the set of unobserved variables (both unit and subunit) in the HCM."""
all_nodes = set(HCM.nodes())
return all_nodes - get_observed(HCM)

def get_subunits(HCM):
"""Returns the set of subunit variables in the HCM."""
"""Return the set of subunit variables in the HCM."""
return set(HCM.get_subgraph('cluster_subunits').nodes())

def get_units(HCM):
"""Returns the set of unit variables in the HCM."""
"""Return the set of unit variables in the HCM."""
subunits = get_subunits(HCM)
return set(HCM.nodes()) - subunits

def parents(HCM, node):
"""Returns the set of parent/predecessor variables of the given variable in the HCM."""
"""Return the set of parent/predecessor variables of the given variable in the HCM."""
parents = set(HCM.predecessors(node))
return parents

def _node_string(nodes):
"""Returns a formated string for use in creating Q variables for collapsed HCMs."""
"""Return a formated string for use in creating Q variables for collapsed HCMs."""
s = ""
for node in nodes:
s += node.get_name().lower() + ","
return s[: -1]

def create_Qvar(HCM, subunit_node):
"""Returns a y0 Variable for the unit-level Q variable of the given subunit variable in the HCM."""
"""Return a y0 Variable for the unit-level Q variable of the given subunit variable in the HCM."""
subunit_parents = parents(HCM, subunit_node) & get_subunits(HCM)
parent_str = _node_string(subunit_parents)
if parent_str == '':
Expand All @@ -86,7 +86,7 @@ def create_Qvar(HCM, subunit_node):
return Variable(Q_str)

def direct_unit_descendents(HCM, subunit_node):
"""Returns the set of direct unit descendents of the given subunit variable in the HCM."""
"""Return the set of direct unit descendents of the given subunit variable in the HCM."""
units = get_units(HCM)
subunits = get_subunits(HCM)
descendents = HCM.successors(subunit_node)
Expand All @@ -112,8 +112,8 @@ def direct_unit_descendents(HCM, subunit_node):
return duds

def collapse_HCM(HCM):
"""Returns a collapsed hierarchical causal model.
"""Return a collapsed hierarchical causal model.
:param HCM: pygraphviz AGraph of the hierarchical causal model to be collapsed
:returns: NxMixedGraph
"""
Expand Down

0 comments on commit 83c5ff6

Please sign in to comment.