diff --git a/src/moscot/_types.py b/src/moscot/_types.py index 1c60884a2..2871482fb 100644 --- a/src/moscot/_types.py +++ b/src/moscot/_types.py @@ -2,6 +2,7 @@ from typing import Any, Literal, Mapping, Optional, Sequence, Union import numpy as np +from ott.initializers.quadratic.initializers import BaseQuadraticInitializer # TODO(michalk8): polish @@ -20,8 +21,9 @@ SinkFullRankInit = Literal["default", "gaussian", "sorting"] LRInitializer_t = Literal["random", "rank2", "k-means", "generalized-k-means"] + SinkhornInitializer_t = Optional[Union[SinkFullRankInit, LRInitializer_t]] -QuadInitializer_t = Optional[LRInitializer_t] +QuadInitializer_t = Optional[Union[LRInitializer_t, BaseQuadraticInitializer]] Initializer_t = Union[SinkhornInitializer_t, LRInitializer_t] ProblemStage_t = Literal["prepared", "solved"] diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index dba12b5ac..4784cef25 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -23,6 +23,7 @@ import jax.numpy as jnp import numpy as np from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud +from ott.initializers.quadratic import initializers as quad_initializers from ott.neural.datasets import OTData, OTDataset from ott.neural.methods.flows import dynamics, genot from ott.neural.networks.layers import time_encoder @@ -409,13 +410,18 @@ def __init__( **kwargs, ) else: - linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs) - initializer = None + linear_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs) + if initializer is None: + initializer = quad_initializers.QuadraticInitializer() + if isinstance(initializer, str): + raise ValueError( + "Expected `initializer` to be an instance of `ott.initializers.quadratic.BaseQuadraticInitializer`," + f"found `{initializer}`." + ) + initializer = functools.partial(initializer, **initializer_kwargs) self._solver = gromov_wasserstein.GromovWasserstein( - rank=rank, - linear_ot_solver=linear_ot_solver, - quad_initializer=initializer, - kwargs_init=initializer_kwargs, + linear_solver=linear_solver, + initializer=initializer, **kwargs, ) @@ -435,7 +441,7 @@ def _prepare( cost_matrix_rank: Optional[int] = None, time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None, # problem - alpha: float = 0.5, + alpha: Optional[float] = None, **kwargs: Any, ) -> quadratic_problem.QuadraticProblem: self._a = a @@ -456,6 +462,13 @@ def _prepare( geom_kwargs["cost_matrix_rank"] = cost_matrix_rank geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs) geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs) + if alpha is None: + alpha = 1.0 if xy is None else 0.5 # set defaults according to the data provided + if alpha <= 0.0: + raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.") + if (alpha == 1.0 and xy is not None) or (alpha != 1.0 and xy is None): + raise ValueError(f"Expected `xy` to be `None` if `alpha` is not 1.0, found xy={xy}, alpha={alpha}.") + if alpha == 1.0 or xy is None: # GW # arbitrary fused penalty; must be positive geom_xy, fused_penalty = None, 1.0 diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index 708b35d44..83e1d13d7 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -430,24 +430,8 @@ def solve( solver_class = backends.get_solver( self.problem_kind, solver_name=solver_name, backend=backend, return_class=True ) - init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) - # if linear problem, then alpha is 0.0 by default - # if quadratic problem, then alpha is 1.0 by default - alpha = call_kwargs.get("alpha", 0.0 if self.problem_kind == "linear" else 1.0) - if alpha < 0.0 or alpha > 1.0: - raise ValueError("Expected `alpha` to be in the range `[0, 1]`, found `{alpha}`.") - if self.problem_kind == "linear" and (alpha != 0.0 or not (self.x is None or self.y is None)): - raise ValueError("Unable to solve a linear problem with `alpha != 0` or `x` and `y` supplied.") - if self.problem_kind == "quadratic": - if self.x is None or self.y is None: - raise ValueError("Unable to solve a quadratic problem without `x` and `y` supplied.") - if alpha != 1.0 and self.xy is None: # means FGW case - raise ValueError( - "`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class." - ) - if alpha == 1.0 and self.xy is not None: - raise ValueError("Unable to solve a quadratic problem with `alpha = 1` and `xy` supplied.") + init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) self._solver = solver_class(**init_kwargs) # note that the solver call consists of solver._prepare and solver._solve diff --git a/tests/backends/ott/test_backend.py b/tests/backends/ott/test_backend.py index 3962c53d1..1169e3f10 100644 --- a/tests/backends/ott/test_backend.py +++ b/tests/backends/ott/test_backend.py @@ -99,7 +99,7 @@ def test_matches_ott(self, x: Geom_t, y: Geom_t, eps: Optional[float], jit: bool thresh = 1e-2 pc_x, pc_y = PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps) prob = quadratic_problem.QuadraticProblem(pc_x, pc_y) - sol = GromovWasserstein(epsilon=eps, threshold=thresh) + sol = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn()) solver = jax.jit(sol, static_argnames=["threshold", "epsilon"]) if jit else sol gt = solver(prob) @@ -130,7 +130,7 @@ def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[f problem = QuadraticProblem( geom_xx=Geometry(cost_matrix=x_cost, epsilon=eps), geom_yy=Geometry(cost_matrix=y_cost, epsilon=eps) ) - gt = GromovWasserstein(epsilon=eps, threshold=thresh)(problem) + gt = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn())(problem) solver = GWSolver(epsilon=eps, threshold=thresh) pred = solver( @@ -157,7 +157,7 @@ def test_solver_rank(self, x: Geom_t, y: Geom_t, rank: int) -> None: ) else: - gt = GromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)( + gt = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn())( QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps)) ) @@ -183,7 +183,7 @@ def test_matches_ott(self, x: Geom_t, y: Geom_t, xy: Geom_t, eps: Optional[float thresh = 1e-2 xx, yy = xy - ott_solver = GromovWasserstein(epsilon=eps, threshold=thresh) + ott_solver = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn()) problem = quadratic_problem.QuadraticProblem( geom_xx=PointCloud(x, epsilon=eps), geom_yy=PointCloud(y, epsilon=eps), @@ -218,7 +218,7 @@ def test_alpha(self, x: Geom_t, y: Geom_t, xy: Geom_t, alpha: float) -> None: thresh, eps = 5e-2, 1e-1 xx, yy = xy - ott_solver = GromovWasserstein(epsilon=eps, threshold=thresh) + ott_solver = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn()) problem = quadratic_problem.QuadraticProblem( geom_xx=PointCloud(x, epsilon=eps), geom_yy=PointCloud(y, epsilon=eps), @@ -256,7 +256,7 @@ def test_epsilon( geom_xy=Geometry(cost_matrix=xy_cost, epsilon=eps), fused_penalty=alpha_to_fused_penalty(alpha), ) - gt = GromovWasserstein(epsilon=eps, threshold=thresh)(problem) + gt = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn())(problem) solver = GWSolver(epsilon=eps, threshold=thresh) pred = solver( @@ -398,7 +398,5 @@ def test_plot_errors_sink(self, x: Geom_t, y: Geom_t): out.plot_errors() def test_plot_errors_gw(self, x: Geom_t, y: Geom_t): - out = GWSolver(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), store_inner_errors=True)( - a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y - ) + out = GWSolver(store_inner_errors=True)(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y) out.plot_errors() diff --git a/tests/problems/base/test_general_problem.py b/tests/problems/base/test_general_problem.py index 5e17b6dec..869b0ea7f 100644 --- a/tests/problems/base/test_general_problem.py +++ b/tests/problems/base/test_general_problem.py @@ -1,3 +1,4 @@ +import re from typing import Literal, Optional, Tuple import pytest @@ -29,6 +30,29 @@ def test_simple_run(self, adata_x: AnnData, adata_y: AnnData): assert isinstance(prob.solution, BaseDiscreteSolverOutput) + @pytest.mark.parametrize( + ("kind", "rank"), + [ + ("linear", -1), + ("linear", 5), + ("quadratic", -1), + ("quadratic", 5), + ], + ) + def test_unrecognized_args( + self, adata_x: AnnData, adata_y: AnnData, kind: Literal["linear", "quadratic"], rank: int + ): + prob = OTProblem(adata_x, adata_y) + data = { + "xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"}, + } + if "quadratic" in kind: + data["x"] = {"attr": "X"} + data["y"] = {"attr": "X"} + + with pytest.raises(TypeError): + prob.prepare(**data).solve(epsilon=5e-1, rank=rank, dummy=42) + @pytest.mark.fast def test_output(self, adata_x: AnnData, x: Geom_t): problem = OTProblem(adata_x) @@ -346,3 +370,35 @@ def test_set_graph_xy_test_t(self, adata_x: AnnData, adata_y: AnnData, t: float) assert pushed_0.shape == pushed_1.shape assert np.all(np.abs(pushed_0 - pushed_1).sum() > np.abs(pushed_2 - pushed_1).sum()) assert np.all(np.abs(pushed_0 - pushed_2).sum() > np.abs(pushed_1 - pushed_2).sum()) + + @pytest.mark.parametrize( + ("attrs", "alpha", "raise_msg"), + [ + ({"xy"}, 0.5, "type-error"), + ({"xy", "x", "y"}, 0, re.escape("Expected `alpha` to be in interval `(0, 1]`, found")), + ({"xy", "x", "y"}, 1.1, re.escape("Expected `alpha` to be in interval `(0, 1]`, found")), + ({"xy", "x", "y"}, 0.5, None), + ({"x", "y"}, 1.0, None), + ({"x", "y"}, 0.5, re.escape("Expected `xy` to be `None` if `alpha` is not 1.0, found")), + ], + ) + def test_xy_alpha_raises(self, adata_x: AnnData, adata_y: AnnData, attrs, alpha, raise_msg): + prob = OTProblem(adata_x, adata_y) + data = { + "xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"} if "xy" in attrs else {}, + "x": {"attr": "X"} if "x" in attrs else {}, + "y": {"attr": "X"} if "y" in attrs else {}, + } + prob = prob.prepare( + **data, + ) + if raise_msg is not None: + if raise_msg == "type-error": + with pytest.raises(TypeError): + prob.solve(epsilon=5e-1, alpha=alpha) + else: + with pytest.raises(ValueError, match=raise_msg): + prob.solve(epsilon=5e-1, alpha=alpha) + else: + prob.solve(epsilon=5e-1, alpha=alpha) + assert isinstance(prob.solution, BaseDiscreteSolverOutput) diff --git a/tests/problems/conftest.py b/tests/problems/conftest.py index 9ee32e693..c7f3c4784 100644 --- a/tests/problems/conftest.py +++ b/tests/problems/conftest.py @@ -183,9 +183,8 @@ def marginal_keys(request): "threshold": "threshold", "min_iterations": "min_iterations", "max_iterations": "max_iterations", - "initializer_kwargs": "kwargs_init", - "warm_start": "_warm_start", - "initializer": "quad_initializer", + "warm_start": "warm_start", + "initializer": "initializer", } gw_lr_solver_args = { diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index b048fa9cf..5d444db30 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -1,5 +1,5 @@ from contextlib import nullcontext -from typing import Any, Literal, Mapping, Optional, Tuple +from typing import Any, Callable, Literal, Mapping, Optional, Tuple import pytest @@ -144,12 +144,14 @@ def test_pass_arguments(self, adata_translation_split: Tuple[AnnData, AnnData], tp = tp.solve(**args_to_check) solver = tp[key].solver.solver - args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): - assert getattr(solver, val) == args_to_check[arg], arg + if arg == "initializer" and args_to_check["rank"] == -1: + assert isinstance(getattr(solver, val), Callable) + else: + assert getattr(solver, val) == args_to_check[arg], arg - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index 3ee7b4f30..4595cbbab 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Mapping +from typing import Any, Callable, Literal, Mapping import pytest @@ -112,9 +112,12 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin solver = problem[key].solver.solver args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): - assert getattr(solver, val, object()) == args_to_check[arg], arg + if args_to_check["rank"] == -1 and arg == "initializer": + assert isinstance(getattr(solver, val), Callable) + else: + assert getattr(solver, val, object()) == args_to_check[arg], arg - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): @@ -342,7 +345,7 @@ def test_passing_ott_kwargs_linear(self, adata_space_rotate: AnnData, memory: in }, ) - sinkhorn_solver = problem[("0", "1")].solver.solver.linear_ot_solver + sinkhorn_solver = problem[("0", "1")].solver.solver.linear_solver anderson = sinkhorn_solver.anderson assert isinstance(anderson, acceleration.AndersonAcceleration) diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 7ae0cb7e5..064923cf3 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Mapping +from typing import Any, Callable, Literal, Mapping import pytest @@ -117,9 +117,12 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): assert hasattr(solver, val) - assert getattr(solver, val) == args_to_check[arg] + if arg == "initializer" and args_to_check["rank"] == -1: + assert isinstance(getattr(solver, val), Callable) + else: + assert getattr(solver, val) == args_to_check[arg] - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): @@ -307,7 +310,7 @@ def test_passing_ott_kwargs_linear(self, adata_space_rotate: AnnData, memory: in }, ) - sinkhorn_solver = problem[("0", "1")].solver.solver.linear_ot_solver + sinkhorn_solver = problem[("0", "1")].solver.solver.linear_solver anderson = sinkhorn_solver.anderson assert isinstance(anderson, acceleration.AndersonAcceleration) diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 0d1b21a16..47202ffe0 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -75,10 +75,16 @@ def test_prepare_star(self, adata_space_rotate: AnnData, reference: str): assert ref == reference assert isinstance(ap[prob_key], ap._base_problem_type) - @pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678") @pytest.mark.parametrize( - ("epsilon", "alpha", "rank", "initializer"), - [(1, 0.9, -1, None), (1, 0.5, 10, "random"), (1, 0.5, 10, "rank2"), (0.1, 0.1, -1, None)], + ("epsilon", "alpha", "rank", "initializer", "should_raise"), + [ + (1, 0.9, -1, None, False), + (1, 0.5, 10, "random", False), + (1, 0.5, 10, "rank2", False), + (0.1, 0.1, -1, None, False), + (0.1, -0.1, -1, None, True), # Invalid alpha + (0.1, 1.1, -1, None, True), # Invalid alpha + ], ) def test_solve_balanced( self, @@ -87,6 +93,7 @@ def test_solve_balanced( alpha: float, rank: int, initializer: Optional[Literal["random", "rank2"]], + should_raise: bool, ): kwargs = {} if rank > -1: @@ -95,22 +102,23 @@ def test_solve_balanced( # kwargs["kwargs_init"] = {"key": 0} # kwargs["key"] = 0 return # TODO(@MUCDK) fix after refactoring - ap = ( - AlignmentProblem(adata=adata_space_rotate) - .prepare(batch_key="batch") - .solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) - ) - for prob_key in ap: - assert ap[prob_key].solution.rank == rank - if initializer != "random": # TODO: is this valid? - assert ap[prob_key].solution.converged - - # TODO(michalk8): use np.testing - assert np.allclose(*(sol.cost for sol in ap.solutions.values())) - assert np.all([sol.converged for sol in ap.solutions.values()]) - np.testing.assert_array_equal( - [np.all(np.isfinite(sol.transport_matrix)) for sol in ap.solutions.values()], True - ) + ap = AlignmentProblem(adata=adata_space_rotate).prepare(batch_key="batch") + if should_raise: + with pytest.raises(ValueError, match=r"Expected `alpha`"): + ap.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) + else: + ap = ap.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) + for prob_key in ap: + assert ap[prob_key].solution.rank == rank + if initializer != "random": # TODO: is this valid? + assert ap[prob_key].solution.converged + + # TODO(michalk8): use np.testing + assert np.allclose(*(sol.cost for sol in ap.solutions.values())) + assert np.all([sol.converged for sol in ap.solutions.values()]) + np.testing.assert_array_equal( + [np.all(np.isfinite(sol.transport_matrix)) for sol in ap.solutions.values()], True + ) def test_solve_unbalanced(self, adata_space_rotate: AnnData): tau_a, tau_b = [0.8, 1] @@ -192,7 +200,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin assert hasattr(solver, val) assert getattr(solver, val) == args_to_check[arg] - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index a51bf4ed9..4ac6266ed 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -1,5 +1,6 @@ +import re from pathlib import Path -from typing import Any, List, Literal, Mapping, Optional, Union +from typing import Any, Callable, List, Literal, Mapping, Optional, Union import pytest @@ -232,9 +233,12 @@ def test_pass_arguments(self, adata_mapping: AnnData, args_to_check: Mapping[str args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): assert hasattr(solver, val) - assert getattr(solver, val) == args_to_check[arg] + if arg == "initializer" and args_to_check["rank"] == -1: + assert isinstance(getattr(solver, val), Callable) + else: + assert getattr(solver, val) == args_to_check[arg] - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): @@ -301,14 +305,14 @@ def test_problem_type( assert isinstance(sol._output, solution_kind) @pytest.mark.parametrize( - ("sc_attr", "alpha"), + ("sc_attr", "alpha", "raise_msg"), [ - (None, 0.5), - ({"attr": "X"}, 0), + (None, 0.5, re.escape("Expected `alpha` to be 0 for a `linear problem`.")), + ({"attr": "X"}, 0, re.escape("Expected `alpha` to be in interval `(0, 1]`, found `0`.")), ], ) def test_problem_type_corner_cases( - self, adata_mapping: AnnData, sc_attr: Optional[Mapping[str, str]], alpha: Optional[float] + self, adata_mapping: AnnData, sc_attr: Optional[Mapping[str, str]], alpha: Optional[float], raise_msg: str ): # initialize and prepare the MappingProblem adataref, adatasp = _adata_spatial_split(adata_mapping) @@ -316,5 +320,5 @@ def test_problem_type_corner_cases( mp = mp.prepare(batch_key="batch", sc_attr=sc_attr) # we test two incompatible combinations of `sc_attr` and `alpha` - with pytest.raises(ValueError, match=r"^Expected `alpha`"): + with pytest.raises(ValueError, match=raise_msg): mp.solve(alpha=alpha) diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index 3b68561ba..ac661391d 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -60,7 +60,7 @@ def test_solve_balanced(self, adata_spatio_temporal: AnnData): assert isinstance(subsol, BaseDiscreteSolverOutput) assert key in expected_keys - @pytest.mark.skip(reason="unbalanced does not work yet") + @pytest.mark.skip(reason="unbalanced does not work yet: https://github.com/ott-jax/ott/issues/519") def test_solve_unbalanced(self, adata_spatio_temporal: AnnData): taus = [9e-1, 1e-2] problem1 = SpatioTemporalProblem(adata=adata_spatio_temporal) @@ -200,7 +200,7 @@ def test_pass_arguments(self, adata_spatio_temporal: AnnData, args_to_check: Map assert hasattr(solver, val) assert getattr(solver, val) == args_to_check[arg], arg - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index 5375c621e..efe9e8a4b 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -1,4 +1,4 @@ -from typing import Any, List, Mapping +from typing import Any, Callable, List, Mapping import pytest @@ -233,9 +233,12 @@ def test_pass_arguments(self, adata_time_barcodes: AnnData, args_to_check: Mappi args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): assert hasattr(solver, val) - assert getattr(solver, val) == args_to_check[arg] + if arg == "initializer" and args_to_check["rank"] == -1: + assert isinstance(getattr(solver, val), Callable) + else: + assert getattr(solver, val) == args_to_check[arg] - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items():