Skip to content

Commit

Permalink
Deps/update ott (#557)
Browse files Browse the repository at this point in the history
* update notebooks

* update ottjax

* adapt to ott-jax=0.4.1

* uniformly allow none for epsilon

* fix regTICost

* incorporate feedback
  • Loading branch information
MUCDK authored Aug 7, 2023
1 parent 1b03127 commit 4ffcf14
Show file tree
Hide file tree
Showing 22 changed files with 109 additions and 37 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ dependencies = [
"anndata>=0.9.1",
"scanpy>=1.9.3",
"wrapt>=1.13.2",
"ott-jax==0.4.0",
"docrep>=0.3.2",
"ott-jax>=0.4.3",
"cloudpickle>=2.2.0",
]

Expand Down
6 changes: 3 additions & 3 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any, Optional
from typing import Any, Optional, Union

import jax
import jax.numpy as jnp
import scipy.sparse as sp
from ott.geometry import geometry, pointcloud
from ott.geometry import epsilon_scheduler, geometry, pointcloud
from ott.tools import sinkhorn_divergence as sdiv

from moscot._logging import logger
Expand All @@ -17,7 +17,7 @@ def sinkhorn_divergence(
point_cloud_2: ArrayLike,
a: Optional[ArrayLike] = None,
b: Optional[ArrayLike] = None,
epsilon: Optional[float] = 1e-1,
epsilon: Union[float, epsilon_scheduler.Epsilon] = 1e-1,
scale_cost: ScaleCost_t = 1.0,
**kwargs: Any,
) -> float:
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, jit: bool = True):
def _create_geometry(
self,
x: TaggedArray,
epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
relative_epsilon: Optional[bool] = None,
scale_cost: Scale_t = 1.0,
batch_size: Optional[int] = None,
Expand Down Expand Up @@ -164,7 +164,7 @@ def _prepare(
x: Optional[TaggedArray] = None,
y: Optional[TaggedArray] = None,
# geometry
epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
relative_epsilon: Optional[bool] = None,
batch_size: Optional[int] = None,
scale_cost: Scale_t = 1.0,
Expand Down Expand Up @@ -265,7 +265,7 @@ def _prepare(
x: Optional[TaggedArray] = None,
y: Optional[TaggedArray] = None,
# geometry
epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
relative_epsilon: Optional[bool] = None,
batch_size: Optional[int] = None,
scale_cost: Scale_t = 1.0,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/problems/cross_modality/_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def prepare(
def solve( # type: ignore[override]
self,
alpha: Optional[float] = 1.0,
epsilon: Optional[float] = 1e-2,
epsilon: float = 1e-2,
tau_a: float = 1.0,
tau_b: float = 1.0,
rank: int = -1,
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def prepare(

def solve(
self,
epsilon: Optional[float] = 1e-3,
epsilon: float = 1e-3,
tau_a: float = 1.0,
tau_b: float = 1.0,
rank: int = -1,
Expand Down Expand Up @@ -374,7 +374,7 @@ def set_quad_defaults(z: Union[str, Mapping[str, Any]]) -> Dict[str, str]:
def solve(
self,
alpha: float = 1.0,
epsilon: Optional[float] = 1e-3,
epsilon: float = 1e-3,
tau_a: float = 1.0,
tau_b: float = 1.0,
rank: int = -1,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/problems/space/_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def prepare(
def solve(
self,
alpha: float = 0.5,
epsilon: Optional[float] = 1e-2,
epsilon: float = 1e-2,
tau_a: float = 1.0,
tau_b: float = 1.0,
rank: int = -1,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/problems/space/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def prepare(
def solve(
self,
alpha: float = 0.5,
epsilon: Optional[float] = 1e-2,
epsilon: float = 1e-2,
tau_a: float = 1.0,
tau_b: float = 1.0,
rank: int = -1,
Expand Down
4 changes: 3 additions & 1 deletion src/moscot/problems/spatiotemporal/_spatio_temporal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import types
from typing import Any, Literal, Mapping, Optional, Tuple, Type, Union

from ott.geometry import epsilon_scheduler

from anndata import AnnData

from moscot import _constants
Expand Down Expand Up @@ -157,7 +159,7 @@ def prepare(
def solve(
self,
alpha: float = 0.5,
epsilon: Optional[float] = 1e-3,
epsilon: Union[float, epsilon_scheduler.Epsilon] = 1e-3,
tau_a: float = 1.0,
tau_b: float = 1.0,
rank: int = -1,
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def prepare(

def solve(
self,
epsilon: Optional[float] = 1e-3,
epsilon: float = 1e-3,
tau_a: float = 1.0,
tau_b: float = 1.0,
rank: int = -1,
Expand Down Expand Up @@ -395,7 +395,7 @@ def prepare(
def solve(
self,
alpha: float = 0.5,
epsilon: Optional[float] = 1e-3,
epsilon: float = 1e-3,
tau_a: float = 1.0,
tau_b: float = 1.0,
rank: int = -1,
Expand Down
2 changes: 0 additions & 2 deletions tests/backends/ott/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ def test_solver_rank(self, y: Geom_t, rank: Optional[int], initializer: str):
np.testing.assert_allclose(solver._problem.geom.cost_matrix, problem.geom.cost_matrix, rtol=RTOL, atol=ATOL)
np.testing.assert_allclose(gt.matrix, pred.transport_matrix, rtol=RTOL, atol=ATOL)

# TODO(michalk8): remove when new ott-jax version comes out
@pytest.mark.xfail(reason="broken on ott-jax==0.4.0")
@pytest.mark.parametrize(
("rank", "cost_fn"), [(2, costs.Euclidean()), (3, costs.SqPNorm(p=1.5)), (5, costs.ElasticL1(0.1))]
)
Expand Down
Binary file modified tests/data/alignment_solutions.pkl
Binary file not shown.
Binary file modified tests/data/mapping_solutions.pkl
Binary file not shown.
7 changes: 3 additions & 4 deletions tests/problems/base/test_general_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ def test_set_xy(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_mat
prob.set_xy(cost_matrix, tag=tag)
assert isinstance(prob.xy.data_src, np.ndarray)
assert prob.xy.data_tgt is None

prob = prob.solve(max_iterations=5) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost
prob = prob.solve(epsilon=1.0, max_iterations=5) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost
np.testing.assert_equal(prob.xy.data_src, cost_matrix.to_numpy())

@pytest.mark.parametrize("tag", ["cost_matrix", "kernel"])
Expand All @@ -81,7 +80,7 @@ def test_set_x(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_matr
assert isinstance(prob.x.data_src, np.ndarray)
assert prob.x.data_tgt is None

prob = prob.solve(max_iterations=5) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost
prob = prob.solve(epsilon=1.0, max_iterations=5) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost
np.testing.assert_equal(prob.x.data_src, cost_matrix.to_numpy())

@pytest.mark.parametrize("tag", ["cost_matrix", "kernel"])
Expand All @@ -100,7 +99,7 @@ def test_set_y(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_matr
assert isinstance(prob.y.data_src, np.ndarray)
assert prob.y.data_tgt is None

prob = prob.solve(max_iterations=5)
prob = prob.solve(epsilon=1.0, max_iterations=5)
np.testing.assert_equal(prob.y.data_src, cost_matrix.to_numpy())

def test_set_xy_change_problem_kind(self, adata_x: AnnData, adata_y: AnnData):
Expand Down
11 changes: 10 additions & 1 deletion tests/problems/cross_modality/test_translation_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import numpy as np
from ott.geometry import epsilon_scheduler

from anndata import AnnData

Expand Down Expand Up @@ -148,7 +149,15 @@ def test_pass_arguments(self, adata_translation_split: Tuple[AnnData, AnnData],

geom = quad_prob.geom_xx
for arg, val in geometry_args.items():
assert getattr(geom, val) == args_to_check[arg], arg
assert hasattr(geom, val)
el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
if arg == "epsilon":
eps_processed = getattr(geom, val)
assert isinstance(eps_processed, epsilon_scheduler.Epsilon)
assert eps_processed.target == args_to_check[arg], arg
else:
assert getattr(geom, val) == args_to_check[arg], arg
assert el == args_to_check[arg]

geom = quad_prob.geom_xy
for arg, val in pointcloud_args.items():
Expand Down
15 changes: 12 additions & 3 deletions tests/problems/generic/test_fgw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
from ott.geometry import epsilon_scheduler
from ott.geometry.costs import (
Cosine,
ElasticL1,
Expand Down Expand Up @@ -117,7 +118,15 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin

geom = quad_prob.geom_xx
for arg, val in geometry_args.items():
assert getattr(geom, val, object()) == args_to_check[arg], arg
assert hasattr(geom, val)
el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
if arg == "epsilon":
eps_processed = getattr(geom, val)
assert isinstance(eps_processed, epsilon_scheduler.Epsilon)
assert eps_processed.target == args_to_check[arg], arg
else:
assert getattr(geom, val) == args_to_check[arg], arg
assert el == args_to_check[arg]

geom = quad_prob.geom_xy
for arg, val in pointcloud_args.items():
Expand Down Expand Up @@ -159,8 +168,8 @@ def test_set_xy(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]
("cosine", Cosine, {}),
("pnorm_p", PNormP, {"p": 3}),
("sq_pnorm", SqPNorm, {"xy": {"p": 5}, "x": {"p": 3}, "y": {"p": 4}}),
("elastic_l1", ElasticL1, {"gamma": 1.1}),
("elastic_stvs", ElasticSTVS, {"gamma": 1.2}),
("elastic_l1", ElasticL1, {"scaling_reg": 1.1}),
("elastic_stvs", ElasticSTVS, {"scaling_reg": 1.2}),
],
)
def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any, cost_kwargs: CostKwargs_t):
Expand Down
14 changes: 11 additions & 3 deletions tests/problems/generic/test_gw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
from ott.geometry import epsilon_scheduler
from ott.geometry.costs import (
Cosine,
ElasticL1,
Expand Down Expand Up @@ -135,7 +136,14 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
geom = quad_prob.geom_xx
for arg, val in geometry_args.items():
assert hasattr(geom, val)
assert getattr(geom, val) == args_to_check[arg]
el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
if arg == "epsilon":
eps_processed = getattr(geom, val)
assert isinstance(eps_processed, epsilon_scheduler.Epsilon)
assert eps_processed.target == args_to_check[arg], arg
else:
assert getattr(geom, val) == args_to_check[arg], arg
assert el == args_to_check[arg]

@pytest.mark.fast()
@pytest.mark.parametrize(
Expand All @@ -146,8 +154,8 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
("cosine", Cosine, {}),
("pnorm_p", PNormP, {"p": 3}),
("sq_pnorm", SqPNorm, {"x": {"p": 3}, "y": {"p": 4}}),
("elastic_l1", ElasticL1, {"x": {"gamma": 3}, "y": {"gamma": 4}}),
("elastic_stvs", ElasticSTVS, {"x": {"gamma": 3}, "y": {"gamma": 4}}),
("elastic_l1", ElasticL1, {"x": {"scaling_reg": 3}, "y": {"scaling_reg": 4}}),
("elastic_stvs", ElasticSTVS, {"x": {"scaling_reg": 3}, "y": {"scaling_reg": 4}}),
],
)
def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any, cost_kwargs: CostKwargs_t):
Expand Down
15 changes: 11 additions & 4 deletions tests/problems/generic/test_sinkhorn_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
from ott.geometry import epsilon_scheduler
from ott.geometry.costs import (
Cosine,
ElasticL1,
Expand Down Expand Up @@ -72,8 +73,8 @@ def test_solve_balanced(self, adata_time: AnnData):
("cosine", Cosine, {}),
("pnorm_p", PNormP, {"p": 3}),
("sq_pnorm", SqPNorm, {"p": 3}),
("elastic_l1", ElasticL1, {"gamma": 1.1}),
("elastic_stvs", ElasticSTVS, {"gamma": 1.2}),
("elastic_l1", ElasticL1, {"scaling_reg": 1.1}),
("elastic_stvs", ElasticSTVS, {"scaling_reg": 1.2}),
],
)
def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any, cost_kwargs: Mapping[str, int]):
Expand Down Expand Up @@ -153,9 +154,15 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A

geom = lin_prob.geom
for arg, val in geometry_args.items():
assert hasattr(geom, val), val
assert hasattr(geom, val)
el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
assert el == args_to_check[arg], arg
if arg == "epsilon":
eps_processed = getattr(geom, val)
assert isinstance(eps_processed, epsilon_scheduler.Epsilon)
assert eps_processed.target == args_to_check[arg], arg
else:
assert getattr(geom, val) == args_to_check[arg], arg
assert el == args_to_check[arg]

args = pointcloud_args if args_to_check["rank"] == -1 else lr_pointcloud_args
for arg, val in args.items():
Expand Down
10 changes: 9 additions & 1 deletion tests/problems/space/test_alignment_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import numpy as np
from ott.geometry import epsilon_scheduler

from anndata import AnnData

Expand Down Expand Up @@ -154,7 +155,14 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
geom = quad_prob.geom_xx
for arg, val in geometry_args.items():
assert hasattr(geom, val)
assert getattr(geom, val) == args_to_check[arg]
el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
if arg == "epsilon":
eps_processed = getattr(geom, val)
assert isinstance(eps_processed, epsilon_scheduler.Epsilon)
assert eps_processed.target == args_to_check[arg], arg
else:
assert getattr(geom, val) == args_to_check[arg], arg
assert el == args_to_check[arg]

geom = quad_prob.geom_xy
for arg, val in pointcloud_args.items():
Expand Down
10 changes: 9 additions & 1 deletion tests/problems/space/test_mapping_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import numpy as np
from ott.geometry import epsilon_scheduler

from anndata import AnnData

Expand Down Expand Up @@ -159,7 +160,14 @@ def test_pass_arguments(self, adata_mapping: AnnData, args_to_check: Mapping[str
geom = quad_prob.geom_xx
for arg, val in geometry_args.items():
assert hasattr(geom, val)
assert getattr(geom, val) == args_to_check[arg]
el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
if arg == "epsilon":
eps_processed = getattr(geom, val)
assert isinstance(eps_processed, epsilon_scheduler.Epsilon)
assert eps_processed.target == args_to_check[arg], arg
else:
assert getattr(geom, val) == args_to_check[arg], arg
assert el == args_to_check[arg]

geom = quad_prob.geom_xy
for arg, val in pointcloud_args.items():
Expand Down
10 changes: 9 additions & 1 deletion tests/problems/spatio_temporal/test_spatio_temporal_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
from ott.geometry import epsilon_scheduler

from anndata import AnnData

Expand Down Expand Up @@ -213,7 +214,14 @@ def test_pass_arguments(self, adata_spatio_temporal: AnnData, args_to_check: Map
geom = quad_prob.geom_xx
for arg, val in geometry_args.items():
assert hasattr(geom, val)
assert getattr(geom, val) == args_to_check[arg]
el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
if arg == "epsilon":
eps_processed = getattr(geom, val)
assert isinstance(eps_processed, epsilon_scheduler.Epsilon)
assert eps_processed.target == args_to_check[arg], arg
else:
assert getattr(geom, val) == args_to_check[arg], arg
assert el == args_to_check[arg]

geom = quad_prob.geom_xy
for arg, val in pointcloud_args.items():
Expand Down
Loading

0 comments on commit 4ffcf14

Please sign in to comment.