diff --git a/src/y0/r_utils.py b/src/y0/r_utils.py index 0744c1b8..df40cd46 100644 --- a/src/y0/r_utils.py +++ b/src/y0/r_utils.py @@ -1,16 +1,18 @@ """General utilities for :mod:`rpy2`.""" +from __future__ import annotations + import logging from collections.abc import Callable, Iterable from functools import lru_cache, wraps from typing import Any, TypeVar, cast -from rpy2.robjects.packages import importr, isinstalled +from rpy2.robjects.packages import InstalledPackage, InstalledSTPackage, importr, isinstalled from rpy2.robjects.vectors import StrVector from .dsl import Variable -__all__ = ["uses_r"] +__all__ = ["uses_r", "prepare_renv", "prepare_default_renv"] logger = logging.getLogger(__name__) @@ -26,10 +28,11 @@ Func = Callable[..., T] -def prepare_renv(requirements: Iterable[str]) -> None: +def prepare_renv(requirements: Iterable[str]) -> list[InstalledSTPackage | InstalledPackage]: """Ensure the given R packages are installed. - :param requirements: A list of R packages to ensure are installed + :param requirements: A list of R package names to ensure are installed + :returns: A list of R packages .. seealso:: https://rpy2.github.io/doc/v3.4.x/html/introduction.html#installing-packages """ @@ -46,8 +49,7 @@ def prepare_renv(requirements: Iterable[str]) -> None: logger.warning("installing R packages: %s", uninstalled_requirements) utils.install_packages(StrVector(uninstalled_requirements)) - for requirement in requirements: - importr(requirement) + return [importr(requirement) for requirement in requirements] @lru_cache(maxsize=1) diff --git a/tests/test_algorithm/test_counterfactual_transportability.py b/tests/test_algorithm/test_counterfactual_transportability.py index 1086bea9..51c08c5c 100644 --- a/tests/test_algorithm/test_counterfactual_transportability.py +++ b/tests/test_algorithm/test_counterfactual_transportability.py @@ -460,7 +460,7 @@ def test_7(self): test7_in = W @ +X test7_out = {W @ +X} result = get_ancestors_of_counterfactual(event=test7_in, graph=figure_2a_graph) - logger.warning("In test_7: result = " + str(result)) + logger.debug("In test_7: result = " + str(result)) self.assertTrue(variable in test7_out for variable in result) @@ -506,7 +506,7 @@ def test_inconsistent_1(self): """ event = [(Y @ -X, -Y), (Y @ -X, +Y)] result = simplify(event=event, graph=figure_2a_graph) - logger.warning("Result for test_inconsistent_1 is " + str(result)) + logger.debug("Result for test_inconsistent_1 is " + str(result)) self.assertIsNone(simplify(event=event, graph=figure_2a_graph)) def test_inconsistent_2(self): @@ -630,11 +630,11 @@ def test_line_2_1(self): nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y) nonreflexive_variable_to_value_mappings[Y @ -X].add(+Y) - logger.warning( + logger.debug( "In test_line_2_1: nonreflexive_variable_to_value_mappings = " + str(nonreflexive_variable_to_value_mappings) ) - logger.warning( + logger.debug( "In test_line_2_1: reflexive_variable_to_value_mappings = " + str(reflexive_variable_to_value_mappings) ) @@ -652,11 +652,11 @@ def test_line_2_2(self): nonreflexive_variable_to_value_mappings = defaultdict(set) nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y) nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y) - logger.warning( + logger.debug( "In test_line_2_2: nonreflexive_variable_to_value_mappings = " + str(nonreflexive_variable_to_value_mappings) ) - logger.warning( + logger.debug( "In test_line_2_2: reflexive_variable_to_value_mappings = " + str(reflexive_variable_to_value_mappings) ) @@ -674,11 +674,11 @@ def test_line_2_10(self): nonreflexive_variable_to_value_mappings = defaultdict(set) nonreflexive_variable_to_value_mappings[Y @ -X].add(None) nonreflexive_variable_to_value_mappings[Y @ -X].add(None) - logger.warning( + logger.debug( "In test_line_2_10: nonreflexive_variable_to_value_mappings = " + str(nonreflexive_variable_to_value_mappings) ) - logger.warning( + logger.debug( "In test_line_2_10: reflexive_variable_to_value_mappings = " + str(reflexive_variable_to_value_mappings) ) @@ -695,11 +695,11 @@ def test_line_2_3(self): reflexive_variable_to_value_mappings[Y @ -Y].add(+Y) nonreflexive_variable_to_value_mappings = defaultdict(set) - logger.warning( + logger.debug( "In test_line_2_3: nonreflexive_variable_to_value_mappings = " + str(nonreflexive_variable_to_value_mappings) ) - logger.warning( + logger.debug( "In test_line_2_3: reflexive_variable_to_value_mappings = " + str(reflexive_variable_to_value_mappings) ) @@ -721,11 +721,11 @@ def test_line_2_11(self): reflexive_variable_to_value_mappings[Y @ -Y].add(None) nonreflexive_variable_to_value_mappings = defaultdict(set) - logger.warning( + logger.debug( "In test_line_2_11: nonreflexive_variable_to_value_mappings = " + str(nonreflexive_variable_to_value_mappings) ) - logger.warning( + logger.debug( "In test_line_2_11: reflexive_variable_to_value_mappings = " + str(reflexive_variable_to_value_mappings) ) @@ -743,11 +743,11 @@ def test_line_2_4(self): nonreflexive_variable_to_value_mappings = defaultdict(set) - logger.warning( + logger.debug( "In test_line_2_4: nonreflexive_variable_to_value_mappings = " + str(nonreflexive_variable_to_value_mappings) ) - logger.warning( + logger.debug( "In test_line_2_4: reflexive_variable_to_value_mappings = " + str(reflexive_variable_to_value_mappings) ) @@ -764,11 +764,11 @@ def test_line_2_5(self): reflexive_variable_to_value_mappings[Y @ +Y].add(-Y) nonreflexive_variable_to_value_mappings = defaultdict(set) - logger.warning( + logger.debug( "In test_line_2_5: nonreflexive_variable_to_value_mappings = " + str(nonreflexive_variable_to_value_mappings) ) - logger.warning( + logger.debug( "In test_line_2_5: reflexive_variable_to_value_mappings = " + str(reflexive_variable_to_value_mappings) ) @@ -786,11 +786,11 @@ def test_line_2_6(self): nonreflexive_variable_to_value_mappings = defaultdict(set) nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y) nonreflexive_variable_to_value_mappings[Y @ -Z].add(+Y) - logger.warning( + logger.debug( "In test_line_2_6: nonreflexive_variable_to_value_mappings = " + str(nonreflexive_variable_to_value_mappings) ) - logger.warning( + logger.debug( "In test_line_2_6: reflexive_variable_to_value_mappings = " + str(reflexive_variable_to_value_mappings) ) @@ -810,11 +810,11 @@ def test_line_2_7(self): nonreflexive_variable_to_value_mappings = defaultdict(set) - logger.warning( + logger.debug( "In test_line_2_7: nonreflexive_variable_to_value_mappings = " + str(nonreflexive_variable_to_value_mappings) ) - logger.warning( + logger.debug( "In test_line_2_7: reflexive_variable_to_value_mappings = " + str(reflexive_variable_to_value_mappings) ) @@ -841,11 +841,11 @@ def test_line_2_8(self): nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y) nonreflexive_variable_to_value_mappings[Y @ -X].add(None) - logger.warning( + logger.debug( "In test_line_2_8: nonreflexive_variable_to_value_mappings = " + str(nonreflexive_variable_to_value_mappings) ) - logger.warning( + logger.debug( "In test_line_2_8: reflexive_variable_to_value_mappings = " + str(reflexive_variable_to_value_mappings) ) @@ -869,11 +869,11 @@ def test_line_2_9(self): nonreflexive_variable_to_value_mappings = defaultdict(set) nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y) - logger.warning( + logger.debug( "In test_line_2_9: nonreflexive_variable_to_value_mappings = " + str(nonreflexive_variable_to_value_mappings) ) - logger.warning( + logger.debug( "In test_line_2_9: reflexive_variable_to_value_mappings = " + str(reflexive_variable_to_value_mappings) ) @@ -1777,7 +1777,7 @@ def test_transport_district_intervening_on_parents_2(self): ] domain_data = [({X}, PP[Pi1](W, X, Y, Z)), (set(), PP[Pi2](W, X, Y, Z))] expected_result = PP[Pi2](X | Z) * PP[Pi2](Z) - logger.warning( + logger.debug( "In test_transport_district_intervening_on_parents_2: expected_result is " + expected_result.to_latex() ) @@ -2485,8 +2485,8 @@ def test_transport_unconditional_counterfactual_query_1(self): domain_graphs=domain_graphs, domain_data=domain_data, ) - logger.warning("Result_expr = " + result_expr.to_latex()) - logger.warning("Result_event = " + str(result_event)) + logger.debug("Result_expr = " + result_expr.to_latex()) + logger.debug("Result_event = " + str(result_event)) self.assert_expr_equal(expected_result, result_expr) def test_transport_unconditional_counterfactual_query_2(self): @@ -2550,8 +2550,8 @@ def test_transport_unconditional_counterfactual_query_3(self): domain_graphs=domain_graphs, domain_data=domain_data, ) - logger.warning("Result_expr = " + result_expr.to_latex()) - logger.warning("Result_event = " + str(result_event)) + logger.debug("Result_expr = " + result_expr.to_latex()) + logger.debug("Result_event = " + str(result_event)) self.assert_expr_equal(expected_result, result_expr) self.assertCountEqual(event, result_event) # Test sending variables with a value of None into this algorithm @@ -3185,10 +3185,10 @@ def test_transport_conditional_counterfactual_query_1(self): domain_graphs=self.example_1_domain_graphs, domain_data=domain_data, ) - logger.warning("expected_result_expr = " + expected_result_expr.to_latex()) - logger.warning("expected_result_event = " + str(expected_result_event)) - logger.warning("Result_expr = " + result_expr.to_latex()) - logger.warning("Result_event = " + str(result_event)) + logger.debug("expected_result_expr = " + expected_result_expr.to_latex()) + logger.debug("expected_result_event = " + str(expected_result_event)) + logger.debug("Result_expr = " + result_expr.to_latex()) + logger.debug("Result_event = " + str(result_event)) self.assert_expr_equal(expected_result_expr, result_expr) self.assertCountEqual(expected_result_event, result_event) @@ -3239,10 +3239,10 @@ def test_transport_conditional_counterfactual_query_2(self): domain_graphs=self.example_2_domain_graphs, domain_data=domain_data, ) - logger.warning("expected_result_expr = " + expected_result_expr.to_latex()) - logger.warning("expected_result_event = " + str(expected_result_event)) - logger.warning("Result_expr = " + result_expr.to_latex()) - logger.warning("Result_event = " + str(result_event)) + logger.debug("expected_result_expr = " + expected_result_expr.to_latex()) + logger.debug("expected_result_event = " + str(expected_result_event)) + logger.debug("Result_expr = " + result_expr.to_latex()) + logger.debug("Result_event = " + str(result_event)) self.assert_expr_equal(expected_result_expr, result_expr) self.assertCountEqual(expected_result_event, result_event) @@ -3588,10 +3588,10 @@ def test_transport_conditional_counterfactual_query_5(self): domain_graphs=self.example_1_domain_graphs, domain_data=domain_data, ) - logger.warning("expected_result_expr = " + expected_result_expr.to_latex()) - logger.warning("expected_result_event = " + str(expected_result_event)) - logger.warning("Result_expr = " + result_expr.to_latex()) - logger.warning("Result_event = " + str(result_event)) + logger.debug("expected_result_expr = " + expected_result_expr.to_latex()) + logger.debug("expected_result_event = " + str(expected_result_event)) + logger.debug("Result_expr = " + result_expr.to_latex()) + logger.debug("Result_event = " + str(result_event)) self.assert_expr_equal(expected_result_expr, result_expr) self.assertCountEqual(expected_result_event, result_event) @@ -3621,10 +3621,10 @@ def test_transport_conditional_counterfactual_query_6(self): domain_graphs=self.example_1_domain_graphs, domain_data=domain_data, ) - logger.warning("expected_result_expr = " + expected_result_expr.to_latex()) - logger.warning("expected_result_event = " + str(expected_result_event)) - logger.warning("Result_expr = " + result_expr.to_latex()) - logger.warning("Result_event = " + str(result_event)) + logger.debug("expected_result_expr = " + expected_result_expr.to_latex()) + logger.debug("expected_result_event = " + str(expected_result_event)) + logger.debug("Result_expr = " + result_expr.to_latex()) + logger.debug("Result_event = " + str(result_event)) self.assert_expr_equal(expected_result_expr, result_expr) self.assertCountEqual(expected_result_event, result_event) @@ -6036,8 +6036,8 @@ def test_merge_frozen_sets_linked_by_bidirectional_edges(self): input_sets=test_2_inputs, graph=graph_2 ) expected_result_2 = {frozenset([W, Y]), frozenset([X]), frozenset([W1]), frozenset([R, Z])} - logger.warning(str(expected_result_2)) - logger.warning(str(result_2)) + logger.debug(str(expected_result_2)) + logger.debug(str(result_2)) self.assertSetEqual(result_2, expected_result_2) graph_3 = NxMixedGraph.from_edges(directed=[], undirected=[(W, X), (X, Y), (R, Z), (W1, Y)]) result_3 = _merge_frozen_sets_linked_by_bidirectional_edges( diff --git a/tests/test_algorithm/test_falsification.py b/tests/test_algorithm/test_falsification.py index f444939a..a89bb53a 100644 --- a/tests/test_algorithm/test_falsification.py +++ b/tests/test_algorithm/test_falsification.py @@ -1,6 +1,7 @@ """Test falsification of testable implications given a graph.""" import unittest +import warnings import numpy as np import pandas as pd @@ -19,7 +20,8 @@ def test_discrete_graph_falsifications(self): for method in [None, *get_conditional_independence_tests()]: if method == "pearson": continue - with self.subTest(method=method): + with self.subTest(method=method), warnings.catch_warnings(): + warnings.simplefilter(action="ignore", category=FutureWarning) issues = get_graph_falsifications( asia_example.graph, asia_example.data, method=method ) diff --git a/tests/test_algorithm/test_tian_pearl_identify.py b/tests/test_algorithm/test_tian_pearl_identify.py index 8514a21d..d3021ef3 100644 --- a/tests/test_algorithm/test_tian_pearl_identify.py +++ b/tests/test_algorithm/test_tian_pearl_identify.py @@ -240,7 +240,7 @@ def test_identify_1(self): graph=soft_interventions_figure_1b_graph, topo=list(soft_interventions_figure_1b_graph.topological_sort()), ) - logger.warning("Result of identify() call for test_identify_1 is " + result.to_latex()) + logger.debug("Result of identify() call for test_identify_1 is " + result.to_latex()) self.assert_expr_equal(result, PP[Pi1](Z | X1)) def test_identify_2(self): @@ -259,7 +259,7 @@ def test_identify_2(self): graph=soft_interventions_figure_2a_graph, topo=list(soft_interventions_figure_2a_graph.topological_sort()), ) - logger.warning("Result of identify() call for test_identify_2 part 1 is " + str(result1)) + logger.debug("Result of identify() call for test_identify_2 part 1 is " + str(result1)) self.assertIsNone(result1) result2 = identify_district_variables( input_variables=frozenset({Z, R}), @@ -285,7 +285,7 @@ def test_identify_3(self): graph=soft_interventions_figure_2d_graph, topo=list(soft_interventions_figure_2d_graph.topological_sort()), ) - logger.warning("Result of identify() call for test_identify_3 is " + str(result1)) + logger.debug("Result of identify() call for test_identify_3 is " + str(result1)) self.assertIsNone(result1) result2 = identify_district_variables( input_variables=frozenset({Z, R}), @@ -294,7 +294,7 @@ def test_identify_3(self): graph=soft_interventions_figure_3_graph, topo=list(soft_interventions_figure_3_graph.topological_sort()), ) - logger.warning("Result of identify() call for test_identify_3 is " + str(result2)) + logger.debug("Result of identify() call for test_identify_3 is " + str(result2)) self.assertIsNone(result2) def test_identify_4(self): @@ -381,8 +381,8 @@ def test_identify_4_with_population_probabilities(self): graph=tian_pearl_figure_9a_graph, topo=list(tian_pearl_figure_9a_graph.topological_sort()), ) - logger.warning("Result from identify_district_variables: " + result_4.to_latex()) - logger.warning(" Expected result: " + expected_result.to_latex()) + logger.debug("Result from identify_district_variables: " + result_4.to_latex()) + logger.debug(" Expected result: " + expected_result.to_latex()) self.assert_expr_equal(result_4, expected_result) @@ -523,11 +523,11 @@ def test_compute_c_factor_5_with_population_probabilities(self): subgraph_probability=subgraph_probability, graph_topo=topo, ) - logger.warning( + logger.debug( "In test_compute_c_factor_5_with_population_probabilities: expected_result = " + expected_result_5.to_latex() ) - logger.warning( + logger.debug( "In test_compute_c_factor_5_with_population_probabilities: result = " + result_5.to_latex() ) @@ -711,7 +711,7 @@ def test_compute_c_factor_marginalizing_over_topological_successors_part_1(self) graph_probability=Sum.safe(self.result_piece, [W3]), topo=list(tian_pearl_figure_9a_graph.subgraph({W1, W2, X, Y}).topological_sort()), ) - logger.warning( + logger.debug( "In first test of Lemma 4(ii): expecting this result: " + str(self.expected_result_1) ) self.assert_expr_equal(result, self.expected_result_1) @@ -721,10 +721,10 @@ def test_compute_c_factor_marginalizing_over_topological_successors_part_2(self) Source: The example on p. 30 of [Tian03a]_, run initially through [tikka20a]_. """ - logger.warning( + logger.debug( "In second test of Lemma 4(ii): expecting this result: " + str(self.expected_result_2) ) - logger.warning("Expected_result_1 = " + str(self.expected_result_1)) + logger.debug("Expected_result_1 = " + str(self.expected_result_1)) result = compute_c_factor_marginalizing_over_topological_successors( district={Y}, graph_probability=Sum.safe(self.expected_result_1, [W1]), @@ -742,7 +742,7 @@ def test_compute_c_factor_marginalizing_over_topological_successors_part_3(self) graph_probability=Sum.safe(self.result_piece_pp, [W3]), topo=list(tian_pearl_figure_9a_graph.subgraph({W1, W2, X, Y}).topological_sort()), ) - logger.warning( + logger.debug( "In first test of Lemma 4(ii): expecting this result: " + self.expected_result_1_pp.to_latex() ) @@ -753,17 +753,17 @@ def test_compute_c_factor_marginalizing_over_topological_successors_part_4(self) Source: The example on p. 30 of [Tian03a]_, run initially through [tikka20a]_. """ - logger.warning( + logger.debug( "In second test of Lemma 4(ii): expecting this result: " + self.expected_result_2.to_latex() ) - logger.warning("Expected_result_1 = " + self.expected_result_1_pp.to_latex()) + logger.debug("Expected_result_1 = " + self.expected_result_1_pp.to_latex()) result = compute_c_factor_marginalizing_over_topological_successors( district={Y}, graph_probability=Sum.safe(self.expected_result_1_pp, [W1]), topo=list(tian_pearl_figure_9a_graph.subgraph({X, Y}).topological_sort()), ) - logger.warning("Expected result = " + self.expected_result_2_pp.to_latex()) + logger.debug("Expected result = " + self.expected_result_2_pp.to_latex()) self.assert_expr_equal(result, self.expected_result_2_pp) diff --git a/tests/test_causaleffect.py b/tests/test_causaleffect.py index 6e58b20e..c70802c8 100644 --- a/tests/test_causaleffect.py +++ b/tests/test_causaleffect.py @@ -9,7 +9,7 @@ try: from y0.causaleffect import r_get_verma_constraints - from y0.r_utils import CAUSALEFFECT, IGRAPH + from y0.r_utils import CAUSALEFFECT, IGRAPH, prepare_renv from y0.struct import VermaConstraint except ImportError: # rpy2 is not installed missing_rpy2 = True @@ -30,13 +30,10 @@ class TestCausalEffect(unittest.TestCase): @classmethod def setUpClass(cls) -> None: """Make imports for the class.""" - from rpy2.robjects.packages import PackageNotInstalledError, importr - try: - importr(CAUSALEFFECT) - importr(IGRAPH) - except PackageNotInstalledError: - raise unittest.SkipTest("R packages not properly installed.") from None + prepare_renv([CAUSALEFFECT, IGRAPH]) + except Exception as e: + raise unittest.SkipTest(f"R packages not properly installed.\n\n{e}") from None def test_verma_constraint(self): """Test getting the single Verma constraint from the Figure 1A graph.""" diff --git a/tests/test_mutate/test_chain.py b/tests/test_mutate/test_chain.py index b4cc294b..7304f7f4 100644 --- a/tests/test_mutate/test_chain.py +++ b/tests/test_mutate/test_chain.py @@ -1,6 +1,7 @@ """Tests for chain mutations.""" import unittest +import warnings from y0.dsl import A, B, P, Sum, W, X, Y, Z from y0.mutate import bayes_expand, chain_expand, fraction_expand @@ -42,7 +43,10 @@ def test_fraction_expand(self): def test_bayes_expand(self): """Test expanding a conditional using extended Bayes' Theorem.""" - self.assertEqual(P(A, X), bayes_expand(P(A, X))) - self.assertEqual(P(A, X) / Sum[A](P(A, X)), bayes_expand(P(A | X))) - self.assertEqual(P(A, X, Y) / Sum[A](P(A, X, Y)), bayes_expand(P(A | (X, Y)))) - self.assertEqual(P(A, B, X, Y) / Sum[A, B](P(A, B, X, Y)), bayes_expand(P(A & B | (X, Y)))) + with warnings.catch_warnings(): + self.assertEqual(P(A, X), bayes_expand(P(A, X))) + self.assertEqual(P(A, X) / Sum[A](P(A, X)), bayes_expand(P(A | X))) + self.assertEqual(P(A, X, Y) / Sum[A](P(A, X, Y)), bayes_expand(P(A | (X, Y)))) + self.assertEqual( + P(A, B, X, Y) / Sum[A, B](P(A, B, X, Y)), bayes_expand(P(A & B | (X, Y))) + )