Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/Fix: WassersteinSolver constructor now throws TypeError when an unrecognized argument is given #579

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/ott/solvers/was_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, I think there's a slightly better solution rather than inspecting the signature of the linear solvers,
I'd rather make linear_ot_solver a required argument and remove the construction of the solver in __init__ altogether - this will require some changes, esp. in tests, in ott/solvers/quadratic/_solve.py, etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then this means we can also remove kwargs right? I'd also prefer this to inspect but didn't want to change the interface in case you had other plans.

from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union

import jax
Expand Down Expand Up @@ -49,10 +50,12 @@ def __init__(
self.epsilon = epsilon if epsilon is not None else default_epsilon
self.rank = rank
self.linear_ot_solver = linear_ot_solver
used_kwargs = {}
if self.linear_ot_solver is None:
# Detect if user requests low-rank solver. In that case the
# default_epsilon makes little sense, since it was designed for GW.
if self.is_low_rank:
used_kwargs = dict(inspect.signature(sinkhorn_lr.LRSinkhorn).parameters)
if epsilon is None:
# Use default entropic regularization in LRSinkhorn if None was passed
self.linear_ot_solver = sinkhorn_lr.LRSinkhorn(
Expand All @@ -64,6 +67,7 @@ def __init__(
rank=self.rank, epsilon=self.epsilon, **kwargs
)
else:
used_kwargs = dict(inspect.signature(sinkhorn.Sinkhorn).parameters)
# When using Entropic GW, epsilon is not handled inside Sinkhorn,
# but rather added back to the Geometry object re-instantiated
# when linearizing the problem. Therefore, no need to pass it to solver.
Expand All @@ -73,6 +77,10 @@ def __init__(
self.max_iterations = max_iterations
self.threshold = threshold
self.store_inner_errors = store_inner_errors
# assert that all kwargs are valid
if not set(kwargs.keys()).issubset(used_kwargs.keys()):
unrecognized_kwargs = set(kwargs.keys()) - set(used_kwargs.keys())
raise TypeError(f"Invalid keyword arguments: {unrecognized_kwargs}.")
self._kwargs = kwargs

@property
Expand Down
25 changes: 25 additions & 0 deletions tests/solvers/quadratic/fgw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,28 @@ def test_fgw_fused_penalty(self, rng: jax.Array, fused_penalty: float):
np.testing.assert_allclose(
out.reg_gw_cost, out_fp.reg_gw_cost, rtol=rtol, atol=atol
)

@pytest.mark.parametrize(("fused", "lr"), [(True, False), (False, True),
(True, True), (False, False)])
def test_solver_unrecognized_args_fails(self, fused: bool, lr: bool):
fused_penalty = 1.0 if fused else 0.0
epsilon = 5.0
geom_x = pointcloud.PointCloud(self.x)
geom_y = pointcloud.PointCloud(self.y)
geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) if fused else None

prob = quadratic_problem.QuadraticProblem(
geom_xx=geom_x,
geom_yy=geom_y,
geom_xy=geom_xy,
fused_penalty=fused_penalty,
)
if lr:
prob = prob.to_low_rank()

solver_cls = (
gromov_wasserstein_lr.LRGromovWasserstein
if lr else gromov_wasserstein.GromovWasserstein
)
with pytest.raises(TypeError):
solver_cls(epsilon=epsilon, dummy=42)(prob)