
[Unity][Analysis] Improve handling of symbolic variables #15627

Conversation

Lunderberg
Contributor

Previously, the struct info analysis returned `kFailL2` (compatible, depending on runtime-inferred symbolic variables) based on each match expression considered in isolation. This ignored cases such as matching a square tensor `[n,n]` against a static shape `[16,32]`. While no single dimension is incompatible on its own, the match requires that `(n==16) && (n==32)`. Since this can be statically proven to be false, the `StructInfoBaseChecker` should return `kFailL0` (statically proven to be incompatible) instead.

This commit updates the `StructInfoBaseChecker` to track implied requirements for symbolic variables across multiple matched dimensions.
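
For illustration, a minimal sketch of the case above, assuming the current C++ struct-info API (`relax::TensorStructInfo`, `relax::ShapeExpr`, and the `relax::StructInfoBaseCheck` entry point keep their present signatures); this is not a test from the PR:

```cpp
#include <tvm/arith/analyzer.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>

void Example() {
  using namespace tvm;
  tir::Var n("n", DataType::Int(64));

  // Pattern: square tensor [n, n].
  relax::TensorStructInfo pattern(relax::ShapeExpr({n, n}), DataType::Float(32));
  // Value: static shape [16, 32].
  relax::TensorStructInfo value(
      relax::ShapeExpr({IntImm(DataType::Int(64), 16), IntImm(DataType::Int(64), 32)}),
      DataType::Float(32));

  arith::Analyzer analyzer;
  // Each dimension is individually compatible, so the check previously
  // returned kFailL2.  The combined requirement (n==16) && (n==32) is
  // provably false, so with this change it returns kFailL0.
  relax::BaseCheckResult result = relax::StructInfoBaseCheck(pattern, value, &analyzer);
  (void)result;
}
```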

The expression `(x==y) && (x==z)` requires that `y==z`.  When `y` and
`z` are constants, this can allow better constant folding by
rewriting `(x==c1) && (x==c2)` into `(x==c1) && (c1==c2)`.

This commit adds the above rewrite, and the corresponding rewrite of
the negative expression.
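
As an informal illustration of the intended effect (not a test from this PR), using the existing `arith::Analyzer` interface:

```cpp
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>

void Example() {
  using namespace tvm;
  arith::Analyzer analyzer;
  tir::Var x("x", DataType::Int(64));

  // (x == 16) && (x == 32)
  PrimExpr cond = tir::And(tir::EQ(x, IntImm(DataType::Int(64), 16)),
                           tir::EQ(x, IntImm(DataType::Int(64), 32)));

  // With the added rewrite, the condition becomes (x == 16) && (16 == 32),
  // which constant-folds to false, so its negation becomes provable.
  PrimExpr simplified = analyzer.Simplify(cond);
  bool provably_false = analyzer.CanProve(tir::Not(cond));
  (void)simplified;
  (void)provably_false;
}
```
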
@tvm-bot
Collaborator

tvm-bot commented Aug 27, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@Lunderberg Lunderberg changed the title [Relax][Analysis] Improve handling of symbolic variables [Unity][Analysis] Improve handling of symbolic variables Aug 27, 2023
@tqchen
Member

tqchen commented Aug 29, 2023

Thanks @Lunderberg . One thing that we should consider is the overall overhead of the simplifier and code simplicity. Ideally the deduction should be a fast path, and we can afford to fall back to a runtime check in this case.

This is also more of a corner case that happens less often.

@Lunderberg
Contributor Author

Thanks @Lunderberg . One thing that we should consider is the overall overhead of the simplifier and code simplicity. Ideally the deduction should be a fast path

Thank you for the feedback, @tqchen. If I remember correctly from some analyzer performance tests I did last year, the simplifications tend to be low overhead when the expression is already in a simplified form. The performance overhead primarily comes from cases where an expression requires multiple rounds of recursive rewriting, and from the resulting allocations. As a result, I think this would generally improve code simplicity, as consistently using the analyzer for comparisons would avoid needing the `if (value->IsInstance<IntImmNode>())` checks in many locations.

We could probably further improve code simplicity by adding a helper method `Optional<Bool> Analyzer::TryProve(...)`, as that would avoid the repeated simplification in `CanProve(expr)` followed by `CanProve(!expr)`, but that would be best left to a later PR.
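
Purely as an illustration, a minimal sketch of what such a helper could look like (a hypothetical method, not an existing `Analyzer` API; this naive version still calls `CanProve` twice, so it only demonstrates the interface rather than the single-pass simplification described above):

```cpp
#include <tvm/arith/analyzer.h>
#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>

// Hypothetical free-function version of the proposed Analyzer::TryProve.
// Returns Bool(true) if expr is provably true, Bool(false) if provably
// false, and NullOpt if neither direction can be proven.
tvm::Optional<tvm::Bool> TryProve(tvm::arith::Analyzer* analyzer, const tvm::PrimExpr& expr) {
  if (analyzer->CanProve(expr)) {
    return tvm::Bool(true);
  }
  if (analyzer->CanProve(tvm::tir::Not(expr))) {
    return tvm::Bool(false);
  }
  return tvm::NullOpt;
}
```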

and we can afford to fall back to a runtime check in this case.

Performance-wise, I agree. For debugging purposes, though, having the error message delayed until runtime makes it difficult to determine which lowering/optimization pass introduced the incompatible shapes.

This is also more of a corner case that happens less often.

In the case of a single tensor, I agree. In the long term, this is intended to improve the symbolic variable handling of `FuncStructInfo` within `IsBaseOf` checks. For `FuncStructInfo`, where many model parameters may use the same symbolic variables (e.g. `seq_len`, `hidden_size`, etc.), this would ensure that the model parameters are compatible with each other.

@tqchen
Member

tqchen commented Sep 6, 2023

Thanks @Lunderberg , let me elaborate a bit, since there is indeed a design tradeoff here. There are a few factors that go into this:

  • (a) overall readability and predictability of the code (when we send in a condition, we know that it will work)
  • (b) the overall power of the proof.

If we consider (b), indeed we would tend to bring more into the arith simplifier. However, that can bring extra complexity in terms of how we reason about the behavior (since arith is really only best effort). While the code can be more concise, it is also more complicated in terms of understanding the code's behavior (mainly due to the complexity of the simplifier itself).

As a result, in core cases like struct info derivation, we would favor predictability over fewer lines of code. When we use the simplifier, there are also several levels of complexity:

  • L0: simple mapping tasks, e.g. map `n+2 => 3` when `n=1`
  • L1: equivalence-relation propagation, using an equivalence relation in one place to propagate to another
  • L2: more complicated proofs.

In this case we would try to limit ourselves to L0 for predictability reasons. So here are some rationales for the base function struct info deriver:

  • Use `if (value->IsInstance<IntImmNode>())` when possible, so that static shape cases are always handled without going through the simplifier. This ensures static shape cases are always handled predictably.
  • Only use L0 of the simplifier when possible.

Coming back to the example (trying to detect the `[2, 32]` to `[n, n]` matching error), I think we can have an alternative approach that specifically handles this case predictably.

In this case, we should be able to do a CanProve-not-equal check at https://github.com/apache/tvm/blob/unity/src/relax/analysis/struct_info_analysis.cc#L663 and report the error.

Note that this check intentionally happens in a subclass, CallRetDeriver. This is because we would like to keep the base class simple and predictable (even if it might be slightly weaker).
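
For concreteness, a rough sketch of the kind of localized check being suggested (hypothetical helper name; not code from this PR), where the per-dimension requirements are accumulated and the analyzer is only asked whether the conjunction is provably false:

```cpp
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>

// Returns true only when matching `pattern` (possibly symbolic) against
// `shape` can be statically proven impossible; returns false otherwise,
// leaving any remaining uncertainty to a runtime check.
bool ShapeMatchProvablyFails(const tvm::Array<tvm::PrimExpr>& pattern,
                             const tvm::Array<tvm::PrimExpr>& shape,
                             tvm::arith::Analyzer* analyzer) {
  if (pattern.size() != shape.size()) {
    return true;
  }
  tvm::PrimExpr requirement = tvm::Bool(true);
  for (size_t i = 0; i < pattern.size(); ++i) {
    requirement = tvm::tir::And(requirement, tvm::tir::EQ(pattern[i], shape[i]));
  }
  // For [n, n] vs [2, 32] this accumulates (n == 2) && (n == 32), whose
  // negation is provable, so the mismatch can be reported at compile time.
  return analyzer->CanProve(tvm::tir::Not(requirement));
}
```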

@Lunderberg
Contributor Author

Thank you for the detailed response, @tqchen.

I think my main concern is that there are several types of predictability, depending on the specific audience.

  1. As a developer writing a lowering pass, checks that are written immediately within the lowering pass do not require checking an external utility. Predictability is improved by writing the checks inline using `IsInstance<IntImmNode>()`, because it does not require looking outside the lowering pass.

  2. As a developer debugging a lowering pass, checks that are written immediately within the lowering pass require determining the purpose of the check. Mixing layers of abstraction between what must be proven (e.g. two shapes must be equal) and how that proof may be performed (e.g. check if they are both a static integer, then check if they are the same) makes this more difficult. Predictability is improved by separating the expression to be proven from the utility that performs the proof, because it avoids mixing across different layers of abstraction.

  3. As a developer calling into a lowering pass, I would be aware of the general utilities provided by TVM, but not the specific implementation of each lowering pass. If each pass uses the same general utility, I can predict that cases handled by that utility will be handled correctly by the lowering pass. If each pass implements its own in-line implementation of a portion of the general utility, then I cannot predict which cases are handled by each separate implementation. Predictability is improved by using a general utility, because it provides the same level of support across each usage.

I agree that for core cases we should favor predictability, but would argue that we should use the general utility in order to provide predictability. Having less accurate checks implemented in many places improves predictability (1), but significantly worsens predictability (2) and (3).

For context, this PR originated from an unexpected assert failure while debugging. My expectation was that the `StructInfoBaseCheck` would track symbolic var usage within the `StructInfo`, and my prediction was incorrect. To improve predictability, this PR aimed to close the gap between expected and actual results.

Note that this check intentionally happens in a subclass, CallRetDeriver. This is because we would like to keep the base class simple and predictable (even if it might be slightly weaker).

Would this still provide predictability over long-lived codebases? If some core functionality (e.g. equality checking) exists in two flavors, one which performs a stronger check and one which performs a weaker check, then any given point of use may pick either flavor. A user may know that a specific functionality must be part of an implementation, but wouldn't know which flavor of that functionality was selected. Predictability would be improved by only having the stronger check available, because it avoids the question of which check had been selected.

@tqchen
Member

tqchen commented Sep 7, 2023

The main thing here is the boundary of the overall interface. It is hard to set up expectations by saying "here is the implementation we use." Instead we would like to describe the expectation of what each pass can do. There is indeed a tradeoff in such a description. When looking for such an expectation, we need simple language that describes what we can handle.

Our rationale as of now is to set up a clear expectation about what we can handle robustly, while also leaving room for what is best effort:

Expectation (E0): the expectation here is that we would like to handle static mismatches well, because that is what we know we are capable of and we can handle all of them, while also being honest that anything beyond a static mismatch is best effort (aka we can handle some, but cannot ensure that we handle all possible cases).

We pick E0 because it is very clear and predictable. We know stronger checks are available (as best effort), but it is harder to describe another set that goes beyond E0 in simple language. So we say that stronger checks are available, but are best effort.

After we set the expectations, we can then come back and ask which implementation to pick. There are multiple approaches to simplification: (a) integer remapping; (b) rewrite simplification; (c) canonical simplification; (d) affine rewrites. Each comes with its own readability, simplicity, cost of execution, etc. None of them is perfect. We know that there is no perfect way to run proofs efficiently for all cases. This is also why we would need to let each pass/location choose their strategy with E0 in mind.

The base checker's goal is to provide a simple and efficient implementation that meets E0, while opening the door for best-effort derivations in a mindful way. Its behavior might be weaker, but we know what it can handle reliably. Because of the tradeoffs in simplification, in the end we would need to think more about decoupling through things like E0.

This being said, it is still valuable to handle more cases. The good news is that we don't need to worry about handling them all because of E0, and can focus on important ones that arise from real-world needs. For example, the [n, n] <=> [2, 3] match is a great real-world scenario that benefits from related checks; a quick update on https://github.com/apache/tvm/blob/unity/src/relax/analysis/struct_info_analysis.cc#L663 should be able to resolve that case well.

@Lunderberg
Contributor Author

This is also why we would need to let each pass/location choose their strategy with E0 in mind.

This might be my source of confusion here. By having each pass/location choose their own strategy, it means that each pass/location is implicitly defining which inputs a caller is allowed to pass in. By re-using a common utility, we minimize the extent of these restrictions. By using common utilities widely across a codebase, and by taking each failure of expectations as an opportunity to improve the robustness of those common utilities, we make the codebase as a whole more predictable.

For example, the [n, n] <=> [2, 3] match is a great real-world scenario that benefits from related checks; a quick update on https://github.com/apache/tvm/blob/unity/src/relax/analysis/struct_info_analysis.cc#L663 should be able to resolve that case well.

I can update the current PR to track the known values of symbolic variables, to avoid being blocked on it, but I'd like to avoid a proliferation of duplicated functionality across the codebase.
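
As a minimal sketch of what "tracking known values" could mean (hypothetical helper, not necessarily how the PR implements it):

```cpp
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/container/map.h>
#include <tvm/tir/expr.h>

// Records the value each symbolic variable is required to take, and reports
// a definite conflict when a new requirement is provably unequal to an
// earlier one.
class SymbolicVarTracker {
 public:
  // Returns false if `var` was already required to equal a value that is
  // provably different from `value`.
  bool Bind(const tvm::tir::Var& var, const tvm::PrimExpr& value) {
    auto it = known_.find(var);
    if (it == known_.end()) {
      known_.Set(var, value);
      return true;
    }
    tvm::PrimExpr previous = (*it).second;
    return !analyzer_.CanProve(tvm::tir::NE(previous, value));
  }

 private:
  tvm::Map<tvm::tir::Var, tvm::PrimExpr> known_;
  tvm::arith::Analyzer analyzer_;
};
```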

@tqchen
Member

tqchen commented Sep 8, 2023

It really depends on whether a given utility comes with limitations. E.g. quick int remapping is faster and works well in many cases, while in some other cases arith simplification is necessary.

Unfortunately, when it comes to symbolic proofs, there is no silver-bullet approach that fits all cases. In many cases even arith itself is not sufficient, as we need to leverage things like affine maps or other forms of simplification.

Acknowledging that there is a tradeoff here, the goal of E0 is to decouple the convention from the implementation. While arith can be more powerful, we know that it is not perfect and is subject to change, so trying to enforce expectations around the implementation is a bit of a moving target. Instead, we would first need to clarify the goal (E0) independently from the implementation.

implicitly defining which inputs a caller is allowed

The goal is to avoid the behavior being defined by the implementation (e.g. the choice of arith). Instead, E0 states that the pass is required to handle the static case well, and enables best effort otherwise (since it is hard to find another form of wording due to the hard nature of arbitrary symbolic proofs).

From the users' and implementations' point of view, they only know that E0 is being guaranteed. We know that for dynamic shapes we can sometimes do more, but those cases are "best effort" at the moment, since we cannot come up with a better consistent term.

Any implementation that meets E0 should suffice, and we build optimizations on that basis. Defining E0 clearly would help us decouple the implementations. Of course, in some cases we might desire stronger guarantees; in that case, it would be useful to think about how we can update E0 to articulate them clearly.

For example, we can say that the function return value deriver should be able to handle the symbolic-equivalence case when a symbol gets assigned to multiple static shapes, which is a clear way of articulating what it can handle confidently.

@Lunderberg
Contributor Author

Hmm, good point. Perhaps that points to the arith::Analyzer having too much functionality that cannot be enabled/disabled independently. If an analyzer had explicit enabling for each type of simplification, then a pass could explicitly declare the functionality required for the minimal E0 handling, while inheriting any additional simplifications that the user implements.

Noodling about a possible interface:

```cpp
HypotheticalSimplifier simplifier;

// Does nothing, no functionality enabled
simplifier.simplify(expr);

// Simplifies with constant folding, but nothing else.  Sufficient for
// handling static shapes.
auto enable_context = simplifier.enable(Functionality::ConstantFolding);
simplifier.simplify(expr);

{
    // Temporarily enable additional functionality, on top of previous.
    auto enable_context = simplifier.enable(Functionality::IntegerRangeAnalysis);
    simplifier.simplify(expr);
}
// Out of the scope of `IntegerRangeAnalysis`, but
// `ConstantFolding` is still enabled.
```

That way, we could avoid the duplication of simplification rules across the codebase, but without making it an all-or-nothing choice between hand-writing duplicate rules and bringing in everything that is available.
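
One hypothetical way such a scoped `enable` could be backed is a plain RAII guard over a bitmask of enabled rule categories (sketch only; everything here is made up and independent of the existing `arith::Analyzer` internals):

```cpp
#include <cstdint>

enum class Functionality : uint32_t {
  kConstantFolding = 1u << 0,
  kIntegerRangeAnalysis = 1u << 1,
};

class HypotheticalSimplifier {
 public:
  // RAII guard that restores the previously enabled set on destruction.
  class EnableContext {
   public:
    EnableContext(HypotheticalSimplifier* self, uint32_t previous)
        : self_(self), previous_(previous) {}
    ~EnableContext() { self_->enabled_ = previous_; }

   private:
    HypotheticalSimplifier* self_;
    uint32_t previous_;
  };

  // Enables additional functionality until the returned guard leaves scope.
  EnableContext enable(Functionality f) {
    uint32_t previous = enabled_;
    enabled_ |= static_cast<uint32_t>(f);
    return EnableContext(this, previous);
  }

 private:
  uint32_t enabled_ = 0;
};
```

The guard's destructor restores the previous mask, so nested enables compose the way the pseudocode above assumes.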

@slyubomirsky
Contributor

I like the idea of having control over the amount of simplification, for what it's worth. That said, I would be surprised if the arithmetic analyzer turns out to be a performance bottleneck (do we have evidence of this being the case?).

@Lunderberg
Contributor Author

The only time I've seen the analyzer become a bottleneck is when I wrote some accidentally-exponential simplifications of binary operations. There have been a few cases where the additional allocations performed as a result of recursive simplifications were noticeably slower, but not a bottleneck.

@tqchen tqchen deleted the branch apache:unity March 29, 2024 12:18
@tqchen tqchen closed this Mar 29, 2024