-
Notifications
You must be signed in to change notification settings - Fork 79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor/Fix: WassersteinSolver constructor now throws TypeError when an unrecognized argument is given #579
Changes from 3 commits
201a806
73e6f54
bd6672c
62b7582
a3b30c8
88bde47
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import inspect | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general, I think there's a slightly better solution rather than inspecting the signature of the linear solvers, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. then this means we can also remove |
||
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union | ||
|
||
import jax | ||
|
@@ -49,10 +50,12 @@ def __init__( | |
self.epsilon = epsilon if epsilon is not None else default_epsilon | ||
self.rank = rank | ||
self.linear_ot_solver = linear_ot_solver | ||
used_kwargs = {} | ||
if self.linear_ot_solver is None: | ||
# Detect if user requests low-rank solver. In that case the | ||
# default_epsilon makes little sense, since it was designed for GW. | ||
if self.is_low_rank: | ||
used_kwargs = dict(inspect.signature(sinkhorn_lr.LRSinkhorn).parameters) | ||
if epsilon is None: | ||
# Use default entropic regularization in LRSinkhorn if None was passed | ||
self.linear_ot_solver = sinkhorn_lr.LRSinkhorn( | ||
|
@@ -64,6 +67,7 @@ def __init__( | |
rank=self.rank, epsilon=self.epsilon, **kwargs | ||
) | ||
else: | ||
used_kwargs = dict(inspect.signature(sinkhorn.Sinkhorn).parameters) | ||
# When using Entropic GW, epsilon is not handled inside Sinkhorn, | ||
# but rather added back to the Geometry object re-instantiated | ||
# when linearizing the problem. Therefore, no need to pass it to solver. | ||
|
@@ -73,6 +77,10 @@ def __init__( | |
self.max_iterations = max_iterations | ||
self.threshold = threshold | ||
self.store_inner_errors = store_inner_errors | ||
# assert that all kwargs are valid | ||
if not set(kwargs.keys()).issubset(used_kwargs.keys()): | ||
unrecognized_kwargs = set(kwargs.keys()) - set(used_kwargs.keys()) | ||
raise TypeError(f"Invalid keyword arguments: {unrecognized_kwargs}.") | ||
self._kwargs = kwargs | ||
|
||
@property | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,6 @@ | |
|
||
|
||
class TestFusedGromovWasserstein: | ||
|
||
# TODO(michalk8): refactor me in the future | ||
@pytest.fixture(autouse=True) | ||
def initialize(self, rng: jax.Array): | ||
|
@@ -60,7 +59,12 @@ def test_gradient_marginals_fgw_solver(self, jit: bool): | |
|
||
def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): | ||
prob = quadratic_problem.QuadraticProblem( | ||
geom_x, geom_y, geom_xy, fused_penalty=self.fused_penalty, a=a, b=b | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure why it was reformatted, but would prefer to undo. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done! |
||
geom_x, | ||
geom_y, | ||
geom_xy, | ||
fused_penalty=self.fused_penalty, | ||
a=a, | ||
b=b, | ||
) | ||
|
||
implicit_diff = implicit_lib.ImplicitDiff() if implicit else None | ||
|
@@ -96,16 +100,22 @@ def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): | |
np.testing.assert_allclose(g_a, gi_a, rtol=1e-2, atol=1e-2) | ||
np.testing.assert_allclose(g_b, gi_b, rtol=1e-2, atol=1e-2) | ||
|
||
@pytest.mark.parametrize(("lse_mode", "is_cost"), [(True, False), | ||
(False, True)], | ||
ids=["lse-pc", "kernel-cost-mat"]) | ||
@pytest.mark.parametrize( | ||
("lse_mode", "is_cost"), | ||
[(True, False), (False, True)], | ||
ids=["lse-pc", "kernel-cost-mat"], | ||
) | ||
def test_gradient_fgw_solver_geometry(self, lse_mode: bool, is_cost: bool): | ||
"""Test gradient w.r.t. the geometries.""" | ||
|
||
def reg_gw( | ||
x: jnp.ndarray, y: jnp.ndarray, | ||
x: jnp.ndarray, | ||
y: jnp.ndarray, | ||
xy: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], | ||
fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool | ||
fused_penalty: float, | ||
a: jnp.ndarray, | ||
b: jnp.ndarray, | ||
implicit: bool, | ||
): | ||
if is_cost: | ||
geom_x = geometry.Geometry(cost_matrix=x) | ||
|
@@ -121,7 +131,9 @@ def reg_gw( | |
|
||
implicit_diff = implicit_lib.ImplicitDiff() if implicit else None | ||
linear_solver = sinkhorn.Sinkhorn( | ||
lse_mode=lse_mode, implicit_diff=implicit_diff, max_iterations=1000 | ||
lse_mode=lse_mode, | ||
implicit_diff=implicit_diff, | ||
max_iterations=1000, | ||
) | ||
solver = gromov_wasserstein.GromovWasserstein( | ||
linear_ot_solver=linear_solver, epsilon=1.0, max_iterations=10 | ||
|
@@ -168,7 +180,7 @@ def loss_thre(threshold: float) -> float: | |
geom_xy, | ||
a=self.a, | ||
b=self.b, | ||
fused_penalty=self.fused_penalty_2 | ||
fused_penalty=self.fused_penalty_2, | ||
) | ||
solver = gromov_wasserstein.GromovWasserstein( | ||
threshold=threshold, epsilon=1e-1 | ||
|
@@ -184,8 +196,13 @@ def test_gradient_fgw_solver_penalty(self): | |
lse_mode = True | ||
|
||
def reg_gw( | ||
cx: jnp.ndarray, cy: jnp.ndarray, cxy: jnp.ndarray, | ||
fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool | ||
cx: jnp.ndarray, | ||
cy: jnp.ndarray, | ||
cxy: jnp.ndarray, | ||
fused_penalty: float, | ||
a: jnp.ndarray, | ||
b: jnp.ndarray, | ||
implicit: bool, | ||
) -> float: | ||
geom_x = geometry.Geometry(cost_matrix=cx) | ||
geom_y = geometry.Geometry(cost_matrix=cy) | ||
|
@@ -196,7 +213,9 @@ def reg_gw( | |
|
||
implicit_diff = implicit_lib.ImplicitDiff() if implicit else None | ||
linear_solver = sinkhorn.Sinkhorn( | ||
lse_mode=lse_mode, implicit_diff=implicit_diff, max_iterations=200 | ||
lse_mode=lse_mode, | ||
implicit_diff=implicit_diff, | ||
max_iterations=200, | ||
) | ||
solver = gromov_wasserstein.GromovWasserstein( | ||
epsilon=1.0, max_iterations=10, linear_ot_solver=linear_solver | ||
|
@@ -207,8 +226,13 @@ def reg_gw( | |
for i, implicit in enumerate([True, False]): | ||
reg_fgw_grad = jax.grad(reg_gw, argnums=(3,)) | ||
grad_matrices[i] = reg_fgw_grad( | ||
self.cx, self.cy, self.cxy, self.fused_penalty, self.a, self.b, | ||
implicit | ||
self.cx, | ||
self.cy, | ||
self.cxy, | ||
self.fused_penalty, | ||
self.a, | ||
self.b, | ||
implicit, | ||
) | ||
assert not jnp.any(jnp.isnan(grad_matrices[i][0])) | ||
|
||
|
@@ -272,7 +296,7 @@ def test_fgw_lr_generic_cost_matrix( | |
epsilon=10.0, | ||
min_iterations=0, | ||
inner_iterations=10, | ||
max_iterations=2000 | ||
max_iterations=2000, | ||
) | ||
out = solver(prob) | ||
|
||
|
@@ -314,7 +338,7 @@ def test_fgw_scale_cost(self, scale_cost: Literal["mean", "max_cost"]): | |
geom_y, | ||
geom_xy, | ||
fused_penalty=fused_penalty, | ||
scale_cost=scale_cost | ||
scale_cost=scale_cost, | ||
) | ||
solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon) | ||
|
||
|
@@ -344,14 +368,14 @@ def test_fgw_fused_penalty(self, rng: jax.Array, fused_penalty: float): | |
geom_yy, | ||
geom_xy=geom_xy, | ||
fused_penalty=fused_penalty, | ||
store_inner_errors=True | ||
store_inner_errors=True, | ||
) | ||
out_fp = quadratic.solve( | ||
geom_xx, | ||
geom_yy, | ||
geom_xy=geom_xy_fp, | ||
fused_penalty=1.0, | ||
store_inner_errors=True | ||
store_inner_errors=True, | ||
) | ||
|
||
np.testing.assert_allclose(out.costs, out_fp.costs, rtol=rtol, atol=atol) | ||
|
@@ -362,3 +386,50 @@ def test_fgw_fused_penalty(self, rng: jax.Array, fused_penalty: float): | |
np.testing.assert_allclose( | ||
out.reg_gw_cost, out_fp.reg_gw_cost, rtol=rtol, atol=atol | ||
) | ||
|
||
@pytest.mark.parametrize( | ||
selmanozleyen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
( | ||
"fused", | ||
"lr", | ||
), | ||
[ | ||
( | ||
True, | ||
False, | ||
), | ||
( | ||
False, | ||
True, | ||
), | ||
( | ||
True, | ||
True, | ||
), | ||
( | ||
False, | ||
False, | ||
), | ||
], | ||
) | ||
def test_solver_unrecognized_args_fails(self, fused: bool, lr: bool): | ||
fused_penalty = 1.0 if fused else 0.0 | ||
epsilon = 5.0 | ||
geom_x = pointcloud.PointCloud(self.x) | ||
geom_y = pointcloud.PointCloud(self.y) | ||
geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) if fused else None | ||
|
||
prob = quadratic_problem.QuadraticProblem( | ||
geom_xx=geom_x, | ||
geom_yy=geom_y, | ||
geom_xy=geom_xy, | ||
fused_penalty=fused_penalty, | ||
) | ||
if lr: | ||
prob = prob.to_low_rank() | ||
|
||
solver_cls = ( | ||
gromov_wasserstein_lr.LRGromovWasserstein | ||
if lr else gromov_wasserstein.GromovWasserstein | ||
) | ||
with pytest.raises(TypeError): | ||
solver_cls(epsilon=epsilon, dummy=42)(prob) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is necessary, as the cache key will be search on the PR's target branch if it's not on the feature branch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did this as a temporary solution as ci's failed. the reformatting was also because of ci for some reason. will undo this and the reformatting