diff --git a/pyproject.toml b/pyproject.toml index 61cee7d66..5a12a0d55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ dependencies = [ "docrep>=0.3.2", "ott-jax>=0.4.3", "cloudpickle>=2.2.0", + "rich>=13.5", ] [project.optional-dependencies] diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index 6e065bb33..55eac7102 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -6,7 +6,7 @@ import jax.numpy as jnp import numpy as np from ott.solvers.linear import sinkhorn, sinkhorn_lr -from ott.solvers.quadratic import gromov_wasserstein +from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr import matplotlib as mpl import matplotlib.pyplot as plt @@ -29,7 +29,13 @@ class OTTOutput(BaseSolverOutput): _NOT_COMPUTED = -1.0 # sentinel value used in `ott` def __init__( - self, output: Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein.GWOutput] + self, + output: Union[ + sinkhorn.SinkhornOutput, + sinkhorn_lr.LRSinkhornOutput, + gromov_wasserstein.GWOutput, + gromov_wasserstein_lr.LRGWOutput, + ], ): super().__init__() self._output = output @@ -218,8 +224,12 @@ def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: # noqa: D102 @property def rank(self) -> int: # noqa: D102 - lin_output = self._output if self.is_linear else self._output.linear_state - return len(lin_output.g) if isinstance(lin_output, sinkhorn_lr.LRSinkhornOutput) else -1 + output = self._output.linear_state if isinstance(self._output, gromov_wasserstein.GWOutput) else self._output + return ( + len(output.g) + if isinstance(output, (sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein_lr.LRGWOutput)) + else -1 + ) def _ones(self, n: int) -> ArrayLike: # noqa: D102 return jnp.ones((n,)) diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index 0935b7f2f..f51ac5441 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -8,7 +8,7 @@ from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn, sinkhorn_lr -from ott.solvers.quadratic import gromov_wasserstein +from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr from moscot._types import ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t from moscot.backends.ott._utils import alpha_to_fused_penalty, check_shapes, ensure_2d @@ -19,7 +19,12 @@ __all__ = ["SinkhornSolver", "GWSolver"] -OTTSolver_t = Union[sinkhorn.Sinkhorn, sinkhorn_lr.LRSinkhorn, gromov_wasserstein.GromovWasserstein] +OTTSolver_t = Union[ + sinkhorn.Sinkhorn, + sinkhorn_lr.LRSinkhorn, + gromov_wasserstein.GromovWasserstein, + gromov_wasserstein_lr.LRGromovWasserstein, +] OTTProblem_t = Union[linear_problem.LinearProblem, quadratic_problem.QuadraticProblem] Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]] @@ -243,21 +248,25 @@ def __init__( ): super().__init__(jit=jit) if rank > -1: - linear_solver_kwargs = dict(linear_solver_kwargs) - linear_solver_kwargs.setdefault("gamma", 10) - linear_solver_kwargs.setdefault("gamma_rescale", True) - linear_ot_solver = sinkhorn_lr.LRSinkhorn(rank=rank, **linear_solver_kwargs) + kwargs.setdefault("gamma", 10) + kwargs.setdefault("gamma_rescale", True) initializer = "rank2" if initializer is None else initializer + self._solver = gromov_wasserstein_lr.LRGromovWasserstein( + rank=rank, + initializer=initializer, + kwargs_init=initializer_kwargs, + **kwargs, + ) else: linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs) initializer = None - self._solver = gromov_wasserstein.GromovWasserstein( - rank=rank, - linear_ot_solver=linear_ot_solver, - quad_initializer=initializer, - kwargs_init=initializer_kwargs, - **kwargs, - ) + self._solver = gromov_wasserstein.GromovWasserstein( + rank=rank, + linear_ot_solver=linear_ot_solver, + quad_initializer=initializer, + kwargs_init=initializer_kwargs, + **kwargs, + ) def _prepare( self, diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index 4a02346c9..44d010b1d 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -349,7 +349,7 @@ def _( ): problem = self.problems[src, tgt] fun = problem.push if forward else problem.pull - res[src] = fun(data=data, scale_by_marginals=scale_by_marginals) + res[src] = fun(data=data, scale_by_marginals=scale_by_marginals, **kwargs) return res if return_all else res[src] @_apply.register(ExplicitPolicy) @@ -382,7 +382,9 @@ def _( for _src, _tgt in [(src, tgt)] + rest: problem = self.problems[_src, _tgt] fun = problem.push if forward else problem.pull - res[_tgt if forward else _src] = current_mass = fun(current_mass, scale_by_marginals=scale_by_marginals) + res[_tgt if forward else _src] = current_mass = fun( + current_mass, scale_by_marginals=scale_by_marginals, **kwargs + ) return res if return_all else current_mass diff --git a/tests/backends/ott/test_backend.py b/tests/backends/ott/test_backend.py index bb3849702..ee75bd200 100644 --- a/tests/backends/ott/test_backend.py +++ b/tests/backends/ott/test_backend.py @@ -16,6 +16,7 @@ from ott.solvers.linear.sinkhorn_lr import LRSinkhorn from ott.solvers.quadratic.gromov_wasserstein import GromovWasserstein from ott.solvers.quadratic.gromov_wasserstein import solve as gromov_wasserstein +from ott.solvers.quadratic.gromov_wasserstein_lr import LRGromovWasserstein from moscot._types import ArrayLike, Device_t from moscot.backends.ott import GWSolver, SinkhornSolver @@ -128,12 +129,19 @@ def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[f assert isinstance(solver.y, Geometry) np.testing.assert_allclose(gt.matrix, pred.transport_matrix, rtol=RTOL, atol=ATOL) + @pytest.mark.skip(reason="TODO") @pytest.mark.parametrize("rank", [-1, 7]) def test_solver_rank(self, x: Geom_t, y: Geom_t, rank: int) -> None: thresh, eps = 1e-2, 1e-2 - gt = GromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)( - QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps)) - ) + if rank > -1: + gt = LRGromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)( + QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps)) + ) + + else: + gt = GromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)( + QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps)) + ) solver = GWSolver(rank=rank, epsilon=eps, threshold=thresh) pred = solver(x=x, y=y, tags={"x": "point_cloud", "y": "point_cloud"}) diff --git a/tests/problems/conftest.py b/tests/problems/conftest.py index 55821e620..d5d876e25 100644 --- a/tests/problems/conftest.py +++ b/tests/problems/conftest.py @@ -154,10 +154,11 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData: "gw_unbalanced_correction": False, "ranks": 3, "tolerances": 3e-2, - "warm_start": True, - "linear_solver_kwargs": linear_solver_kwargs2, + # "linear_solver_kwargs": linear_solver_kwargs2, } +gw_args_2 = {**gw_args_2, **linear_solver_kwargs2} + fgw_args_1 = gw_args_1.copy() fgw_args_1["alpha"] = 0.6 @@ -175,6 +176,16 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData: "initializer": "quad_initializer", } +gw_lr_solver_args = { + "epsilon": "epsilon", + "rank": "rank", + "threshold": "threshold", + "min_iterations": "min_iterations", + "max_iterations": "max_iterations", + "initializer_kwargs": "kwargs_init", + "initializer": "initializer", +} + gw_linear_solver_args = { "lse_mode": "lse_mode", "inner_iterations": "inner_iterations", diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index ff56744bf..f2c3d0314 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -16,6 +16,7 @@ geometry_args, gw_linear_solver_args, gw_lr_linear_solver_args, + gw_lr_solver_args, gw_solver_args, pointcloud_args, quad_prob_args, @@ -129,18 +130,21 @@ def test_pass_arguments(self, adata_translation_split: Tuple[AnnData, AnnData], tp = tp.solve(**args_to_check) solver = tp[key].solver.solver - for arg, val in gw_solver_args.items(): + + args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args + for arg, val in args.items(): assert getattr(solver, val) == args_to_check[arg], arg - sinkhorn_solver = solver.linear_ot_solver + sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args + tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): el = ( getattr(sinkhorn_solver, val)[0] if isinstance(getattr(sinkhorn_solver, val), tuple) else getattr(sinkhorn_solver, val) ) - assert el == args_to_check["linear_solver_kwargs"][arg], arg + assert el == tmp_dict[arg], arg quad_prob = tp[key]._solver._problem for arg, val in quad_prob_args.items(): diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index 27d2c11f2..3db088ba7 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -29,6 +29,7 @@ geometry_args, gw_linear_solver_args, gw_lr_linear_solver_args, + gw_lr_solver_args, gw_solver_args, pointcloud_args, quad_prob_args, @@ -98,18 +99,20 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin key = ("0", "1") solver = problem[key].solver.solver - for arg, val in gw_solver_args.items(): + args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args + for arg, val in args.items(): assert getattr(solver, val, object()) == args_to_check[arg], arg - sinkhorn_solver = solver.linear_ot_solver + sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args + tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): el = ( getattr(sinkhorn_solver, val)[0] if isinstance(getattr(sinkhorn_solver, val), tuple) else getattr(sinkhorn_solver, val) ) - assert el == args_to_check["linear_solver_kwargs"][arg], arg + assert el == tmp_dict[arg], arg quad_prob = problem[key].solver._problem for arg, val in quad_prob_args.items(): diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 7548f3c83..b4d7db2f3 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -28,6 +28,7 @@ gw_args_2, gw_linear_solver_args, gw_lr_linear_solver_args, + gw_lr_solver_args, gw_solver_args, quad_prob_args, ) @@ -113,20 +114,21 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin problem = problem.solve(**args_to_check) key = ("0", "1") solver = problem[key].solver.solver - for arg, val in gw_solver_args.items(): + args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args + for arg, val in args.items(): assert hasattr(solver, val) assert getattr(solver, val) == args_to_check[arg] - sinkhorn_solver = solver.linear_ot_solver + sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args + tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): - assert hasattr(sinkhorn_solver, val) el = ( getattr(sinkhorn_solver, val)[0] if isinstance(getattr(sinkhorn_solver, val), tuple) else getattr(sinkhorn_solver, val) ) - assert el == args_to_check["linear_solver_kwargs"][arg], arg + assert el == tmp_dict[arg], arg quad_prob = problem[key]._solver._problem for arg, val in quad_prob_args.items(): diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 76510221a..77f70dc4d 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -16,6 +16,7 @@ geometry_args, gw_linear_solver_args, gw_lr_linear_solver_args, + gw_lr_solver_args, gw_solver_args, pointcloud_args, quad_prob_args, @@ -130,20 +131,21 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin problem = problem.solve(**args_to_check) solver = problem[key].solver.solver - for arg, val in gw_solver_args.items(): + args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args + for arg, val in args.items(): assert hasattr(solver, val) assert getattr(solver, val) == args_to_check[arg] - sinkhorn_solver = solver.linear_ot_solver + sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args + tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): - assert hasattr(sinkhorn_solver, val) el = ( getattr(sinkhorn_solver, val)[0] if isinstance(getattr(sinkhorn_solver, val), tuple) else getattr(sinkhorn_solver, val) ) - assert el == args_to_check["linear_solver_kwargs"][arg], arg + assert el == tmp_dict[arg], arg quad_prob = problem[key]._solver._problem for arg, val in quad_prob_args.items(): diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 4c7aae8ba..808d69f9b 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -17,6 +17,7 @@ geometry_args, gw_linear_solver_args, gw_lr_linear_solver_args, + gw_lr_solver_args, gw_solver_args, pointcloud_args, quad_prob_args, @@ -135,20 +136,21 @@ def test_pass_arguments(self, adata_mapping: AnnData, args_to_check: Mapping[str problem = problem.solve(**args_to_check) solver = problem[key].solver.solver - for arg, val in gw_solver_args.items(): + args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args + for arg, val in args.items(): assert hasattr(solver, val) assert getattr(solver, val) == args_to_check[arg] - sinkhorn_solver = solver.linear_ot_solver + sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args + tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): - assert hasattr(sinkhorn_solver, val) el = ( getattr(sinkhorn_solver, val)[0] if isinstance(getattr(sinkhorn_solver, val), tuple) else getattr(sinkhorn_solver, val) ) - assert el == args_to_check["linear_solver_kwargs"][arg], arg + assert el == tmp_dict[arg], arg quad_prob = problem[key]._solver._problem for arg, val in quad_prob_args.items(): diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index f26b80026..aea737c5d 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -19,6 +19,7 @@ geometry_args, gw_linear_solver_args, gw_lr_linear_solver_args, + gw_lr_solver_args, gw_solver_args, pointcloud_args, quad_prob_args, @@ -189,20 +190,21 @@ def test_pass_arguments(self, adata_spatio_temporal: AnnData, args_to_check: Map key = (0, 1) solver = problem[key].solver.solver - for arg, val in gw_solver_args.items(): + args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args + for arg, val in args.items(): assert hasattr(solver, val) assert getattr(solver, val) == args_to_check[arg], arg - sinkhorn_solver = solver.linear_ot_solver + sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args + tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): - assert hasattr(sinkhorn_solver, val) el = ( getattr(sinkhorn_solver, val)[0] if isinstance(getattr(sinkhorn_solver, val), tuple) else getattr(sinkhorn_solver, val) ) - assert el == args_to_check["linear_solver_kwargs"][arg], arg + assert el == tmp_dict[arg], arg quad_prob = problem[key]._solver._problem for arg, val in quad_prob_args.items(): diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index 858bb1757..394adc2ef 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -18,6 +18,7 @@ geometry_args, gw_linear_solver_args, gw_lr_linear_solver_args, + gw_lr_solver_args, gw_solver_args, pointcloud_args, quad_prob_args, @@ -230,20 +231,21 @@ def test_pass_arguments(self, adata_time_barcodes: AnnData, args_to_check: Mappi problem = problem.solve(**args_to_check) key = (0, 1) solver = problem[key].solver.solver - for arg, val in gw_solver_args.items(): + args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args + for arg, val in args.items(): assert hasattr(solver, val) assert getattr(solver, val) == args_to_check[arg] - sinkhorn_solver = solver.linear_ot_solver + sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args + tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): - assert hasattr(sinkhorn_solver, val) el = ( getattr(sinkhorn_solver, val)[0] if isinstance(getattr(sinkhorn_solver, val), tuple) else getattr(sinkhorn_solver, val) ) - assert el == args_to_check["linear_solver_kwargs"][arg], arg + assert el == tmp_dict[arg], arg quad_prob = problem[key]._solver._problem for arg, val in quad_prob_args.items():