Skip to content

Commit

Permalink
Update/lrgw2 (#613)
Browse files Browse the repository at this point in the history

Co-authored-by: Arina Danilina <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 9, 2023
1 parent 62cdf59 commit b233601
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 50 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dependencies = [
"docrep>=0.3.2",
"ott-jax>=0.4.3",
"cloudpickle>=2.2.0",
"rich>=13.5",
]

[project.optional-dependencies]
Expand Down
18 changes: 14 additions & 4 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.numpy as jnp
import numpy as np
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand All @@ -29,7 +29,13 @@ class OTTOutput(BaseSolverOutput):
_NOT_COMPUTED = -1.0 # sentinel value used in `ott`

def __init__(
self, output: Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein.GWOutput]
self,
output: Union[
sinkhorn.SinkhornOutput,
sinkhorn_lr.LRSinkhornOutput,
gromov_wasserstein.GWOutput,
gromov_wasserstein_lr.LRGWOutput,
],
):
super().__init__()
self._output = output
Expand Down Expand Up @@ -218,8 +224,12 @@ def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: # noqa: D102

@property
def rank(self) -> int: # noqa: D102
lin_output = self._output if self.is_linear else self._output.linear_state
return len(lin_output.g) if isinstance(lin_output, sinkhorn_lr.LRSinkhornOutput) else -1
output = self._output.linear_state if isinstance(self._output, gromov_wasserstein.GWOutput) else self._output
return (
len(output.g)
if isinstance(output, (sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein_lr.LRGWOutput))
else -1
)

def _ones(self, n: int) -> ArrayLike: # noqa: D102
return jnp.ones((n,))
35 changes: 22 additions & 13 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

from moscot._types import ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t
from moscot.backends.ott._utils import alpha_to_fused_penalty, check_shapes, ensure_2d
Expand All @@ -19,7 +19,12 @@

__all__ = ["SinkhornSolver", "GWSolver"]

OTTSolver_t = Union[sinkhorn.Sinkhorn, sinkhorn_lr.LRSinkhorn, gromov_wasserstein.GromovWasserstein]
OTTSolver_t = Union[
sinkhorn.Sinkhorn,
sinkhorn_lr.LRSinkhorn,
gromov_wasserstein.GromovWasserstein,
gromov_wasserstein_lr.LRGromovWasserstein,
]
OTTProblem_t = Union[linear_problem.LinearProblem, quadratic_problem.QuadraticProblem]
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]

Expand Down Expand Up @@ -243,21 +248,25 @@ def __init__(
):
super().__init__(jit=jit)
if rank > -1:
linear_solver_kwargs = dict(linear_solver_kwargs)
linear_solver_kwargs.setdefault("gamma", 10)
linear_solver_kwargs.setdefault("gamma_rescale", True)
linear_ot_solver = sinkhorn_lr.LRSinkhorn(rank=rank, **linear_solver_kwargs)
kwargs.setdefault("gamma", 10)
kwargs.setdefault("gamma_rescale", True)
initializer = "rank2" if initializer is None else initializer
self._solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank,
initializer=initializer,
kwargs_init=initializer_kwargs,
**kwargs,
)
else:
linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
initializer = None
self._solver = gromov_wasserstein.GromovWasserstein(
rank=rank,
linear_ot_solver=linear_ot_solver,
quad_initializer=initializer,
kwargs_init=initializer_kwargs,
**kwargs,
)
self._solver = gromov_wasserstein.GromovWasserstein(
rank=rank,
linear_ot_solver=linear_ot_solver,
quad_initializer=initializer,
kwargs_init=initializer_kwargs,
**kwargs,
)

def _prepare(
self,
Expand Down
6 changes: 4 additions & 2 deletions src/moscot/base/problems/compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def _(
):
problem = self.problems[src, tgt]
fun = problem.push if forward else problem.pull
res[src] = fun(data=data, scale_by_marginals=scale_by_marginals)
res[src] = fun(data=data, scale_by_marginals=scale_by_marginals, **kwargs)
return res if return_all else res[src]

@_apply.register(ExplicitPolicy)
Expand Down Expand Up @@ -382,7 +382,9 @@ def _(
for _src, _tgt in [(src, tgt)] + rest:
problem = self.problems[_src, _tgt]
fun = problem.push if forward else problem.pull
res[_tgt if forward else _src] = current_mass = fun(current_mass, scale_by_marginals=scale_by_marginals)
res[_tgt if forward else _src] = current_mass = fun(
current_mass, scale_by_marginals=scale_by_marginals, **kwargs
)

return res if return_all else current_mass

Expand Down
14 changes: 11 additions & 3 deletions tests/backends/ott/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ott.solvers.linear.sinkhorn_lr import LRSinkhorn
from ott.solvers.quadratic.gromov_wasserstein import GromovWasserstein
from ott.solvers.quadratic.gromov_wasserstein import solve as gromov_wasserstein
from ott.solvers.quadratic.gromov_wasserstein_lr import LRGromovWasserstein

from moscot._types import ArrayLike, Device_t
from moscot.backends.ott import GWSolver, SinkhornSolver
Expand Down Expand Up @@ -128,12 +129,19 @@ def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[f
assert isinstance(solver.y, Geometry)
np.testing.assert_allclose(gt.matrix, pred.transport_matrix, rtol=RTOL, atol=ATOL)

@pytest.mark.skip(reason="TODO")
@pytest.mark.parametrize("rank", [-1, 7])
def test_solver_rank(self, x: Geom_t, y: Geom_t, rank: int) -> None:
thresh, eps = 1e-2, 1e-2
gt = GromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)(
QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps))
)
if rank > -1:
gt = LRGromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)(
QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps))
)

else:
gt = GromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)(
QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps))
)

solver = GWSolver(rank=rank, epsilon=eps, threshold=thresh)
pred = solver(x=x, y=y, tags={"x": "point_cloud", "y": "point_cloud"})
Expand Down
15 changes: 13 additions & 2 deletions tests/problems/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,11 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
"gw_unbalanced_correction": False,
"ranks": 3,
"tolerances": 3e-2,
"warm_start": True,
"linear_solver_kwargs": linear_solver_kwargs2,
# "linear_solver_kwargs": linear_solver_kwargs2,
}

gw_args_2 = {**gw_args_2, **linear_solver_kwargs2}

fgw_args_1 = gw_args_1.copy()
fgw_args_1["alpha"] = 0.6

Expand All @@ -175,6 +176,16 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
"initializer": "quad_initializer",
}

gw_lr_solver_args = {
"epsilon": "epsilon",
"rank": "rank",
"threshold": "threshold",
"min_iterations": "min_iterations",
"max_iterations": "max_iterations",
"initializer_kwargs": "kwargs_init",
"initializer": "initializer",
}

gw_linear_solver_args = {
"lse_mode": "lse_mode",
"inner_iterations": "inner_iterations",
Expand Down
10 changes: 7 additions & 3 deletions tests/problems/cross_modality/test_translation_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
geometry_args,
gw_linear_solver_args,
gw_lr_linear_solver_args,
gw_lr_solver_args,
gw_solver_args,
pointcloud_args,
quad_prob_args,
Expand Down Expand Up @@ -129,18 +130,21 @@ def test_pass_arguments(self, adata_translation_split: Tuple[AnnData, AnnData],
tp = tp.solve(**args_to_check)

solver = tp[key].solver.solver
for arg, val in gw_solver_args.items():

args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
assert getattr(solver, val) == args_to_check[arg], arg

sinkhorn_solver = solver.linear_ot_solver
sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check
for arg, val in lin_solver_args.items():
el = (
getattr(sinkhorn_solver, val)[0]
if isinstance(getattr(sinkhorn_solver, val), tuple)
else getattr(sinkhorn_solver, val)
)
assert el == args_to_check["linear_solver_kwargs"][arg], arg
assert el == tmp_dict[arg], arg

quad_prob = tp[key]._solver._problem
for arg, val in quad_prob_args.items():
Expand Down
9 changes: 6 additions & 3 deletions tests/problems/generic/test_fgw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
geometry_args,
gw_linear_solver_args,
gw_lr_linear_solver_args,
gw_lr_solver_args,
gw_solver_args,
pointcloud_args,
quad_prob_args,
Expand Down Expand Up @@ -98,18 +99,20 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
key = ("0", "1")

solver = problem[key].solver.solver
for arg, val in gw_solver_args.items():
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
assert getattr(solver, val, object()) == args_to_check[arg], arg

sinkhorn_solver = solver.linear_ot_solver
sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check
for arg, val in lin_solver_args.items():
el = (
getattr(sinkhorn_solver, val)[0]
if isinstance(getattr(sinkhorn_solver, val), tuple)
else getattr(sinkhorn_solver, val)
)
assert el == args_to_check["linear_solver_kwargs"][arg], arg
assert el == tmp_dict[arg], arg

quad_prob = problem[key].solver._problem
for arg, val in quad_prob_args.items():
Expand Down
10 changes: 6 additions & 4 deletions tests/problems/generic/test_gw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
gw_args_2,
gw_linear_solver_args,
gw_lr_linear_solver_args,
gw_lr_solver_args,
gw_solver_args,
quad_prob_args,
)
Expand Down Expand Up @@ -113,20 +114,21 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
problem = problem.solve(**args_to_check)
key = ("0", "1")
solver = problem[key].solver.solver
for arg, val in gw_solver_args.items():
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
assert hasattr(solver, val)
assert getattr(solver, val) == args_to_check[arg]

sinkhorn_solver = solver.linear_ot_solver
sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check
for arg, val in lin_solver_args.items():
assert hasattr(sinkhorn_solver, val)
el = (
getattr(sinkhorn_solver, val)[0]
if isinstance(getattr(sinkhorn_solver, val), tuple)
else getattr(sinkhorn_solver, val)
)
assert el == args_to_check["linear_solver_kwargs"][arg], arg
assert el == tmp_dict[arg], arg

quad_prob = problem[key]._solver._problem
for arg, val in quad_prob_args.items():
Expand Down
10 changes: 6 additions & 4 deletions tests/problems/space/test_alignment_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
geometry_args,
gw_linear_solver_args,
gw_lr_linear_solver_args,
gw_lr_solver_args,
gw_solver_args,
pointcloud_args,
quad_prob_args,
Expand Down Expand Up @@ -130,20 +131,21 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
problem = problem.solve(**args_to_check)

solver = problem[key].solver.solver
for arg, val in gw_solver_args.items():
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
assert hasattr(solver, val)
assert getattr(solver, val) == args_to_check[arg]

sinkhorn_solver = solver.linear_ot_solver
sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check
for arg, val in lin_solver_args.items():
assert hasattr(sinkhorn_solver, val)
el = (
getattr(sinkhorn_solver, val)[0]
if isinstance(getattr(sinkhorn_solver, val), tuple)
else getattr(sinkhorn_solver, val)
)
assert el == args_to_check["linear_solver_kwargs"][arg], arg
assert el == tmp_dict[arg], arg

quad_prob = problem[key]._solver._problem
for arg, val in quad_prob_args.items():
Expand Down
10 changes: 6 additions & 4 deletions tests/problems/space/test_mapping_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
geometry_args,
gw_linear_solver_args,
gw_lr_linear_solver_args,
gw_lr_solver_args,
gw_solver_args,
pointcloud_args,
quad_prob_args,
Expand Down Expand Up @@ -135,20 +136,21 @@ def test_pass_arguments(self, adata_mapping: AnnData, args_to_check: Mapping[str
problem = problem.solve(**args_to_check)

solver = problem[key].solver.solver
for arg, val in gw_solver_args.items():
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
assert hasattr(solver, val)
assert getattr(solver, val) == args_to_check[arg]

sinkhorn_solver = solver.linear_ot_solver
sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check
for arg, val in lin_solver_args.items():
assert hasattr(sinkhorn_solver, val)
el = (
getattr(sinkhorn_solver, val)[0]
if isinstance(getattr(sinkhorn_solver, val), tuple)
else getattr(sinkhorn_solver, val)
)
assert el == args_to_check["linear_solver_kwargs"][arg], arg
assert el == tmp_dict[arg], arg

quad_prob = problem[key]._solver._problem
for arg, val in quad_prob_args.items():
Expand Down
10 changes: 6 additions & 4 deletions tests/problems/spatio_temporal/test_spatio_temporal_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
geometry_args,
gw_linear_solver_args,
gw_lr_linear_solver_args,
gw_lr_solver_args,
gw_solver_args,
pointcloud_args,
quad_prob_args,
Expand Down Expand Up @@ -189,20 +190,21 @@ def test_pass_arguments(self, adata_spatio_temporal: AnnData, args_to_check: Map

key = (0, 1)
solver = problem[key].solver.solver
for arg, val in gw_solver_args.items():
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
assert hasattr(solver, val)
assert getattr(solver, val) == args_to_check[arg], arg

sinkhorn_solver = solver.linear_ot_solver
sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check
for arg, val in lin_solver_args.items():
assert hasattr(sinkhorn_solver, val)
el = (
getattr(sinkhorn_solver, val)[0]
if isinstance(getattr(sinkhorn_solver, val), tuple)
else getattr(sinkhorn_solver, val)
)
assert el == args_to_check["linear_solver_kwargs"][arg], arg
assert el == tmp_dict[arg], arg

quad_prob = problem[key]._solver._problem
for arg, val in quad_prob_args.items():
Expand Down
Loading

0 comments on commit b233601

Please sign in to comment.