Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Always use special .solve for Kronecker linear operators #50

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

saitcakmak
Copy link
Collaborator

As titled. These linear operators are generally much larger than their components. If fast_computations (in particular _fast_solves) is turned off, then we try to compute Cholesky over huge matrices, which leads to OOMs.

Comment on lines +20 to +31
if isinstance(
linear_op,
(
CholLinearOperator,
TriangularLinearOperator,
KroneckerProductAddedDiagLinearOperator,
KroneckerProductLinearOperator,
KroneckerProductDiagLinearOperator,
KroneckerProductTriangularLinearOperator,
SumKroneckerLinearOperator,
),
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm will this always apply the special solve method? There may be situations in which we want to use Linear CG solves even for some operators with a special solve method.

Aside: The name "fast_computations" is a bit weird; whether it's fast or not will depend on the operator structure and the data size...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gpleiss, @jacobrgardner, curious about your thoughts here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside: The name "fast_computations" is a bit weird; whether it's fast or not will depend on the operator structure and the data size...

Agreed. I regret it.

There may be situations in which we want to use Linear CG solves even for some operators with a special solve method.

@JonathanWenger and I brainstormed this a bit. One thought that we had was that a user could specify (via context manager, inline argument, etc.) when they want to go into iterative solving mode. All other solves would be performed using direct methods otherwise.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that could be useful. Another option would be to attach default rules for the decision which solves to use to the respective operators but then allow to override them (either way so a default exact solve may use an iterative instead and vice versa).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @Balandat on the default rules. That would nicely integrate with using a Kronecker, or banded specific solver. The interface Geoff and I were discussing was either via a context manager or with an optional argument that could specify the default per linear operator:

def solve(self, right_tensor: torch.Tensor, left_tensor: Optional[torch.Tensor] = None, linear_solver: LinearSolver = CG()) -> torch.Tensor:

However, there were some potential issues with this interface and the interplay with implementing torch.linalg.solve, if I remember correctly. In a perfect world torch.linalg.solve would dispatch on the specific kind of LinearOperator I suppose.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants