Refactor/Fix: WassersteinSolver constructor now throws TypeError when an unrecognized argument is given #579

Closed
Changes from 3 commits
7 changes: 5 additions & 2 deletions .github/workflows/lint.yml
@@ -33,8 +33,11 @@ jobs:
if: ${{ matrix.lint-kind == 'code' }}
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }}

key: pre-commit-${{ runner.os }}-python-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }}
Collaborator: I don't think this is necessary, as the cache key will be searched for on the PR's target branch if it is not found on the feature branch.

Contributor Author: I did this as a temporary fix because CI failed; the reformatting also happened because of CI for some reason. I will undo both this change and the reformatting.

restore-keys: |
pre-commit-${{ runner.os }}-python-${{ env.pythonLocation }}-
pre-commit-${{ runner.os }}-
pre-commit-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
8 changes: 8 additions & 0 deletions src/ott/solvers/was_solver.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
Collaborator: In general, I think there's a slightly better solution than inspecting the signature of the linear solvers: I'd rather make linear_ot_solver a required argument and remove the construction of the solver in __init__ altogether. This will require some changes, especially in the tests, in ott/solvers/quadratic/_solve.py, etc.

Contributor Author: Then this means we can also remove kwargs, right? I'd also prefer that over inspect, but I didn't want to change the interface in case you had other plans.
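For context, a minimal sketch of what the suggested alternative could look like (hypothetical class name; it assumes the ott.solvers.linear module layout referenced elsewhere in this diff and is not the PR's actual implementation). Callers would construct the linear solver themselves, so **kwargs and the signature inspection disappear from the constructor:

```python
from typing import Optional, Union

from ott.solvers.linear import sinkhorn, sinkhorn_lr


class WassersteinSolverSketch:
  """Hypothetical variant with a required, pre-built linear solver."""

  def __init__(
      self,
      linear_ot_solver: Union[sinkhorn.Sinkhorn, sinkhorn_lr.LRSinkhorn],
      epsilon: Optional[float] = None,
      rank: int = -1,
      threshold: float = 1e-3,
  ):
    # No **kwargs and no solver construction here: the caller passes a fully
    # configured Sinkhorn/LRSinkhorn, so there is nothing left to validate.
    self.linear_ot_solver = linear_ot_solver
    self.epsilon = epsilon
    self.rank = rank
    self.threshold = threshold
```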

from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union

import jax
@@ -49,10 +50,12 @@ def __init__(
self.epsilon = epsilon if epsilon is not None else default_epsilon
self.rank = rank
self.linear_ot_solver = linear_ot_solver
used_kwargs = {}
if self.linear_ot_solver is None:
# Detect if user requests low-rank solver. In that case the
# default_epsilon makes little sense, since it was designed for GW.
if self.is_low_rank:
used_kwargs = dict(inspect.signature(sinkhorn_lr.LRSinkhorn).parameters)
if epsilon is None:
# Use default entropic regularization in LRSinkhorn if None was passed
self.linear_ot_solver = sinkhorn_lr.LRSinkhorn(
@@ -64,6 +67,7 @@
rank=self.rank, epsilon=self.epsilon, **kwargs
)
else:
used_kwargs = dict(inspect.signature(sinkhorn.Sinkhorn).parameters)
# When using Entropic GW, epsilon is not handled inside Sinkhorn,
# but rather added back to the Geometry object re-instantiated
# when linearizing the problem. Therefore, no need to pass it to solver.
@@ -73,6 +77,10 @@
self.max_iterations = max_iterations
self.threshold = threshold
self.store_inner_errors = store_inner_errors
# assert that all kwargs are valid
if not set(kwargs.keys()).issubset(used_kwargs.keys()):
unrecognized_kwargs = set(kwargs.keys()) - set(used_kwargs.keys())
raise TypeError(f"Invalid keyword arguments: {unrecognized_kwargs}.")
self._kwargs = kwargs

@property
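To summarize the change in this file: the constructor records the signature of whichever default linear solver it would build (sinkhorn.Sinkhorn or sinkhorn_lr.LRSinkhorn) and raises TypeError if any leftover keyword argument is not accepted by that solver. A self-contained sketch of the pattern, using a hypothetical stand-in class rather than the real solvers:

```python
import inspect
from typing import Any


class _StandInSolver:
  """Hypothetical stand-in for sinkhorn.Sinkhorn / sinkhorn_lr.LRSinkhorn."""

  def __init__(self, lse_mode: bool = True, max_iterations: int = 2000):
    self.lse_mode = lse_mode
    self.max_iterations = max_iterations


def check_kwargs(solver_cls: type, **kwargs: Any) -> None:
  """Raise TypeError if ``solver_cls.__init__`` does not accept some kwarg."""
  # ``inspect.signature`` on a class yields the constructor's parameters.
  valid = set(inspect.signature(solver_cls).parameters)
  unrecognized = set(kwargs) - valid
  if unrecognized:
    raise TypeError(f"Invalid keyword arguments: {unrecognized}.")


check_kwargs(_StandInSolver, max_iterations=10)  # accepted silently
try:
  check_kwargs(_StandInSolver, dummy=42)
except TypeError as err:
  print(err)  # Invalid keyword arguments: {'dummy'}.
```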
107 changes: 89 additions & 18 deletions tests/solvers/quadratic/fgw_test.py
@@ -28,7 +28,6 @@


class TestFusedGromovWasserstein:

# TODO(michalk8): refactor me in the future
@pytest.fixture(autouse=True)
def initialize(self, rng: jax.Array):
@@ -60,7 +59,12 @@ def test_gradient_marginals_fgw_solver(self, jit: bool):

def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool):
prob = quadratic_problem.QuadraticProblem(
geom_x, geom_y, geom_xy, fused_penalty=self.fused_penalty, a=a, b=b
Collaborator: Not sure why it was reformatted, but I would prefer to undo it.

Contributor Author: Done!

geom_x,
geom_y,
geom_xy,
fused_penalty=self.fused_penalty,
a=a,
b=b,
)

implicit_diff = implicit_lib.ImplicitDiff() if implicit else None
@@ -96,16 +100,22 @@ def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool):
np.testing.assert_allclose(g_a, gi_a, rtol=1e-2, atol=1e-2)
np.testing.assert_allclose(g_b, gi_b, rtol=1e-2, atol=1e-2)

@pytest.mark.parametrize(("lse_mode", "is_cost"), [(True, False),
(False, True)],
ids=["lse-pc", "kernel-cost-mat"])
@pytest.mark.parametrize(
("lse_mode", "is_cost"),
[(True, False), (False, True)],
ids=["lse-pc", "kernel-cost-mat"],
)
def test_gradient_fgw_solver_geometry(self, lse_mode: bool, is_cost: bool):
"""Test gradient w.r.t. the geometries."""

def reg_gw(
x: jnp.ndarray, y: jnp.ndarray,
x: jnp.ndarray,
y: jnp.ndarray,
xy: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool
fused_penalty: float,
a: jnp.ndarray,
b: jnp.ndarray,
implicit: bool,
):
if is_cost:
geom_x = geometry.Geometry(cost_matrix=x)
@@ -121,7 +131,9 @@ def reg_gw(

implicit_diff = implicit_lib.ImplicitDiff() if implicit else None
linear_solver = sinkhorn.Sinkhorn(
lse_mode=lse_mode, implicit_diff=implicit_diff, max_iterations=1000
lse_mode=lse_mode,
implicit_diff=implicit_diff,
max_iterations=1000,
)
solver = gromov_wasserstein.GromovWasserstein(
linear_ot_solver=linear_solver, epsilon=1.0, max_iterations=10
@@ -168,7 +180,7 @@ def loss_thre(threshold: float) -> float:
geom_xy,
a=self.a,
b=self.b,
fused_penalty=self.fused_penalty_2
fused_penalty=self.fused_penalty_2,
)
solver = gromov_wasserstein.GromovWasserstein(
threshold=threshold, epsilon=1e-1
@@ -184,8 +196,13 @@ def test_gradient_fgw_solver_penalty(self):
lse_mode = True

def reg_gw(
cx: jnp.ndarray, cy: jnp.ndarray, cxy: jnp.ndarray,
fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool
cx: jnp.ndarray,
cy: jnp.ndarray,
cxy: jnp.ndarray,
fused_penalty: float,
a: jnp.ndarray,
b: jnp.ndarray,
implicit: bool,
) -> float:
geom_x = geometry.Geometry(cost_matrix=cx)
geom_y = geometry.Geometry(cost_matrix=cy)
@@ -196,7 +213,9 @@ def reg_gw(

implicit_diff = implicit_lib.ImplicitDiff() if implicit else None
linear_solver = sinkhorn.Sinkhorn(
lse_mode=lse_mode, implicit_diff=implicit_diff, max_iterations=200
lse_mode=lse_mode,
implicit_diff=implicit_diff,
max_iterations=200,
)
solver = gromov_wasserstein.GromovWasserstein(
epsilon=1.0, max_iterations=10, linear_ot_solver=linear_solver
@@ -207,8 +226,13 @@ def reg_gw(
for i, implicit in enumerate([True, False]):
reg_fgw_grad = jax.grad(reg_gw, argnums=(3,))
grad_matrices[i] = reg_fgw_grad(
self.cx, self.cy, self.cxy, self.fused_penalty, self.a, self.b,
implicit
self.cx,
self.cy,
self.cxy,
self.fused_penalty,
self.a,
self.b,
implicit,
)
assert not jnp.any(jnp.isnan(grad_matrices[i][0]))

@@ -272,7 +296,7 @@ def test_fgw_lr_generic_cost_matrix(
epsilon=10.0,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
max_iterations=2000,
)
out = solver(prob)

@@ -314,7 +338,7 @@ def test_fgw_scale_cost(self, scale_cost: Literal["mean", "max_cost"]):
geom_y,
geom_xy,
fused_penalty=fused_penalty,
scale_cost=scale_cost
scale_cost=scale_cost,
)
solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon)

@@ -344,14 +368,14 @@ def test_fgw_fused_penalty(self, rng: jax.Array, fused_penalty: float):
geom_yy,
geom_xy=geom_xy,
fused_penalty=fused_penalty,
store_inner_errors=True
store_inner_errors=True,
)
out_fp = quadratic.solve(
geom_xx,
geom_yy,
geom_xy=geom_xy_fp,
fused_penalty=1.0,
store_inner_errors=True
store_inner_errors=True,
)

np.testing.assert_allclose(out.costs, out_fp.costs, rtol=rtol, atol=atol)
@@ -362,3 +386,50 @@ def test_fgw_fused_penalty(self, rng: jax.Array, fused_penalty: float):
np.testing.assert_allclose(
out.reg_gw_cost, out_fp.reg_gw_cost, rtol=rtol, atol=atol
)

@pytest.mark.parametrize(
(
"fused",
"lr",
),
[
(
True,
False,
),
(
False,
True,
),
(
True,
True,
),
(
False,
False,
),
],
)
def test_solver_unrecognized_args_fails(self, fused: bool, lr: bool):
fused_penalty = 1.0 if fused else 0.0
epsilon = 5.0
geom_x = pointcloud.PointCloud(self.x)
geom_y = pointcloud.PointCloud(self.y)
geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) if fused else None

prob = quadratic_problem.QuadraticProblem(
geom_xx=geom_x,
geom_yy=geom_y,
geom_xy=geom_xy,
fused_penalty=fused_penalty,
)
if lr:
prob = prob.to_low_rank()

solver_cls = (
gromov_wasserstein_lr.LRGromovWasserstein
if lr else gromov_wasserstein.GromovWasserstein
)
with pytest.raises(TypeError):
solver_cls(epsilon=epsilon, dummy=42)(prob)
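From a user's perspective, the behavior exercised by the new test looks roughly like the sketch below; it assumes the standard ott import paths used in the test above and that GromovWasserstein forwards unknown keyword arguments to WassersteinSolver.__init__:

```python
import pytest

from ott.solvers.quadratic import gromov_wasserstein

# Before this PR, an unrecognized keyword argument was silently stored in
# ``_kwargs`` and ignored; now it is rejected when the solver is built.
with pytest.raises(TypeError, match="Invalid keyword arguments"):
  gromov_wasserstein.GromovWasserstein(epsilon=5.0, dummy=42)
```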