Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/Fix: WassersteinSolver constructor now throws TypeError when an unrecognized argument is given #579

Conversation

selmanozleyen
Copy link
Contributor

@selmanozleyen selmanozleyen commented Sep 23, 2024

hi,

A user can give an argument by typo or any other misunderstanding and the solver class would work without them noticing. To prevent such cases I made some modifications. I also added tests that asserts that the raises are thrown properly.

Note: I am not sure about why the linting fails, it tox -e lint-code passes locally for me. Note: I also modified the caching in CI's because it didn't work on my pr for some reason

Related: theislab/moscot#748

ping: @MUCDK

Copy link

codecov bot commented Sep 23, 2024

Codecov Report

Attention: Patch coverage is 57.14286% with 3 lines in your changes missing coverage. Please review.

Project coverage is 87.81%. Comparing base (aa33bd9) to head (88bde47).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
src/ott/solvers/was_solver.py 57.14% 2 Missing and 1 partial ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #579      +/-   ##
==========================================
- Coverage   87.83%   87.81%   -0.03%     
==========================================
  Files          73       73              
  Lines        7826     7845      +19     
  Branches     1127     1133       +6     
==========================================
+ Hits         6874     6889      +15     
- Misses        799      801       +2     
- Partials      153      155       +2     
Files with missing lines Coverage Δ
src/ott/solvers/was_solver.py 79.59% <57.14%> (-3.75%) ⬇️

... and 4 files with indirect coverage changes

@marcocuturi
Copy link
Contributor

thanks @selmanozleyen for the PR! i will defer to @michalk8 on this, but it feels that if we implement this for this particular solver, we would need to implement it for all solvers, no? What was the use case that revealed the problem?

@marcocuturi marcocuturi added the enhancement New feature or request label Sep 24, 2024
@selmanozleyen
Copy link
Contributor Author

For linear solvers there is no need as their base class Sinkhorn doesn't take kwargs. Since WassersteinSolver now handles unrecognized kwargs, all it's child classes will also handle it (since all child classes pass remaining kwargs to super()__init__()).

In moscot we don't want to ignore any unrecognized arguments since there are many arguments, and with some typo etc. it can lead to some well hidden bugs.

Here is the PR for it:theislab/moscot#748

We use many (if not all) solvers in our case and from my tests this PR should be enough to cover the constructors for linear and quadratic solvers. I am not sure about other methods such as solve in ottjax though.

Copy link
Collaborator

@michalk8 michalk8 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@selmanozleyen in the comment above, I think we should rather explicitly pass the linear_ot_solver to WassersteinSolver instead.
Lmk if you prefer to do it or not. I can also take a look at it, as there might be many places in tests/ that need a change.

@@ -33,8 +33,11 @@ jobs:
if: ${{ matrix.lint-kind == 'code' }}
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }}

key: pre-commit-${{ runner.os }}-python-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is necessary, as the cache key will be search on the PR's target branch if it's not on the feature branch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did this as a temporary solution as ci's failed. the reformatting was also because of ci for some reason. will undo this and the reformatting

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, I think there's a slightly better solution rather than inspecting the signature of the linear solvers,
I'd rather make linear_ot_solver a required argument and remove the construction of the solver in __init__ altogether - this will require some changes, esp. in tests, in ott/solvers/quadratic/_solve.py, etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then this means we can also remove kwargs right? I'd also prefer this to inspect but didn't want to change the interface in case you had other plans.

@@ -60,7 +59,12 @@ def test_gradient_marginals_fgw_solver(self, jit: bool):

def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool):
prob = quadratic_problem.QuadraticProblem(
geom_x, geom_y, geom_xy, fused_penalty=self.fused_penalty, a=a, b=b
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why it was reformatted, but would prefer to undo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

tests/solvers/quadratic/fgw_test.py Outdated Show resolved Hide resolved
@selmanozleyen
Copy link
Contributor Author

@michalk8 since the interface is going to change I think it would be better if you did it. I already resolved other pre-commit and formatting issues you mentioned

@michalk8
Copy link
Collaborator

@michalk8 since the interface is going to change I think it would be better if you did it. I already resolved other pre-commit and formatting issues you mentioned

Ok, thanks! I will then close this PR and open tomorrow a new one.

@michalk8 michalk8 closed this Sep 25, 2024
@selmanozleyen
Copy link
Contributor Author

hi @michalk8, just wanted to remind you on this. I think many test cases and stuff might have to change since the API also changes. So maybe I can help a bit

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants