No convergence with new commit of ottjax in LRGW solvers unless inner_iterations=10 is set #678

Open
Tracked by #677
selmanozleyen opened this issue Mar 20, 2024 · 3 comments · May be fixed by #748

@selmanozleyen
Collaborator

selmanozleyen commented Mar 20, 2024

I am working on this and fixing the problems that come with the new ott-jax version. You can see the failing tests in this CI run: https://github.com/theislab/moscot/actions/runs/8361086332/job/22888306575#step:5:1934.

The code that fails looks something like this:

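```python
from moscot.problems.space import AlignmentProblem

# adata_space_rotate, epsilon, alpha, rank and initializer are assumed to be
# provided by the test setup (fixtures / parametrization)
ap = (
    AlignmentProblem(adata=adata_space_rotate)
    .prepare(batch_key="batch")
    .solve(epsilon=epsilon, alpha=alpha, rank=rank, initializer=initializer)
)

for prob_key in ap:
    assert ap[prob_key].solution.rank == rank
    assert ap[prob_key].solution.converged  # fails with the new ott-jax version
```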

But I noticed that if I set `inner_iterations=10`, it converges:

```python
# Same call as above, but with inner_iterations=10 passed to the solver
ap = (
    AlignmentProblem(adata=adata_space_rotate)
    .prepare(batch_key="batch")
    .solve(epsilon=epsilon, alpha=alpha, rank=rank, initializer=initializer, inner_iterations=10)
)

for prob_key in ap:
    assert ap[prob_key].solution.rank == rank
    assert ap[prob_key].solution.converged  # passes
```

The new versions of the LRGW (low-rank Gromov-Wasserstein) solvers in ott-jax handle `inner_iterations` differently. The first snippet above is one example of a failing test.

selmanozleyen self-assigned this Mar 20, 2024
@giovp
Member

giovp commented Mar 20, 2024

Hi @selmanozleyen, yes, you probably noticed ott-jax/ott@036add2. We might have to update our default to mirror the ott-jax default.
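One way to "mirror the ott-jax default" is to not hardcode a value on the wrapper side at all and forward only the options the user explicitly set. The sketch below is hypothetical: `fake_backend_solver` stands in for an ott-jax solver constructor, and its defaults (10 and 2000) are placeholders, not necessarily ott-jax's real defaults.

```python
from typing import Dict, Optional


def fake_backend_solver(inner_iterations: int = 10, max_iterations: int = 2000) -> Dict[str, int]:
    """Hypothetical stand-in for a backend solver; returns its effective settings."""
    return {"inner_iterations": inner_iterations, "max_iterations": max_iterations}


def solve(inner_iterations: Optional[int] = None, max_iterations: Optional[int] = None) -> Dict[str, int]:
    """Hypothetical wrapper: unset (None) options are dropped, so backend defaults apply."""
    opts = {"inner_iterations": inner_iterations, "max_iterations": max_iterations}
    return fake_backend_solver(**{k: v for k, v in opts.items() if v is not None})


print(solve())                    # backend defaults are used unchanged
print(solve(inner_iterations=1))  # an explicit user value still wins
```

With this pattern, a change of default upstream is picked up automatically instead of being silently overridden by a stale value in the wrapper.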

@selmanozleyen
Collaborator Author

I think our default already mirrors ott-jax's, but the problem is that it doesn't converge when `inner_iterations` is a higher number like 2000, while it does when it is 10. I don't know whether this is expected.
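A possible (unconfirmed) explanation for the 2000 vs 10 behaviour: assuming the LR solvers, like ott-jax's Sinkhorn solver, only re-evaluate convergence every `inner_iterations` iterations, a value close to the iteration budget leaves too few checks to ever compare two recorded values. The toy model below is not ott-jax code; the function name, the `cos` iteration, the threshold and the 2000-iteration budget are all assumptions chosen for illustration.

```python
import math
from typing import List, Tuple


def toy_solve(max_iterations: int, inner_iterations: int, threshold: float = 1e-3) -> Tuple[bool, int]:
    """Toy iteration x <- cos(x); returns (converged, number_of_checks).

    NOT ott-jax code: it only mimics the assumed pattern of recording progress
    every `inner_iterations` steps and declaring convergence when two
    consecutive recorded values are within `threshold`.
    """
    x = 1.0
    recorded: List[float] = []  # one entry per convergence check
    for it in range(1, max_iterations + 1):
        x = math.cos(x)  # contraction towards the fixed point ~0.739
        if it % inner_iterations == 0:
            recorded.append(x)
            # Convergence needs at least two recorded values to compare.
            if len(recorded) >= 2 and abs(recorded[-1] - recorded[-2]) < threshold:
                return True, len(recorded)
    return False, len(recorded)


# Assumed budget of 2000 iterations (not taken from the issue):
print(toy_solve(2000, 10))    # converges after a few checks
print(toy_solve(2000, 2000))  # (False, 1): only one check, so nothing to compare
```

In this toy model the iterate itself has converged long before iteration 2000; the run is only reported as not converged because it was checked once.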

@selmanozleyen
Collaborator Author

OK, I will set `inner_iterations=10`, as in the ott-jax tests. The tests that use the defaults already pass.
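A minimal sketch of how the tests could pin `inner_iterations=10` only for low-rank runs while leaving full-rank runs on the defaults. The helper name `solve_kwargs` is hypothetical, and it assumes that a positive `rank` selects the low-rank solver (with `rank=-1` meaning full rank).

```python
from typing import Any, Dict


def solve_kwargs(epsilon: float, alpha: float, rank: int, **extra: Any) -> Dict[str, Any]:
    """Hypothetical test helper building the keyword arguments for `.solve(...)`."""
    kwargs: Dict[str, Any] = {"epsilon": epsilon, "alpha": alpha, "rank": rank, **extra}
    if rank > 0:  # assumption: positive rank -> low-rank (LRGW) solver
        # setdefault keeps an explicitly passed inner_iterations untouched
        kwargs.setdefault("inner_iterations", 10)
    return kwargs


print(solve_kwargs(0.1, 0.5, rank=5))   # low-rank: inner_iterations pinned to 10
print(solve_kwargs(0.1, 0.5, rank=-1))  # full rank: defaults left alone
```

In a test this would be used as `.solve(**solve_kwargs(epsilon, alpha, rank, initializer=initializer))`.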

selmanozleyen added a commit that referenced this issue Apr 30, 2024
selmanozleyen added a commit that referenced this issue Apr 30, 2024
giovp pushed a commit that referenced this issue Apr 30, 2024
* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* recreate solution files with new ottjax version so it doesn't fail

* see #678

* skip failing test

* skip other tests #678

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
selmanozleyen reopened this May 14, 2024
selmanozleyen linked a pull request Sep 22, 2024 that will close this issue

2 participants