[WIP] Rework Scalarize Shapes Pass #3799

Draft · wants to merge 9 commits into main
Conversation

zjgarvey (Collaborator) commented on Oct 16, 2024

Putting this up here for comments and visibility.

Purpose:

  1. Restrict the scope of the pass so it only applies to op sequences that are used to compute shapes.
  2. Make the pass more efficient by applying patterns in an order suited to scalarization propagation.
  3. Report failed scalarization patterns for easier debugging (not yet implemented).

With these changes, some reworking of the conversions themselves will be necessary.

  1. Removing the SqueezeDim fold pattern was an appropriate fix: it avoided folding away an op that may still be needed for further propagation (the reversal of pattern application order uncovered this bug). Rank-0 item logic was added to replace the functionality previously provided by the squeeze-dim pattern; a sketch of that idea follows after this list.
  2. All ValueTensorLiteralOps in the worklist that are of reasonable size should be scalarized. The lack of scalarization for this op is a significant obstacle to the success of this pass: it forces most other scalarization conversions to carry cumbersome logic for deciding whether to get elements from a propagated list or to materialize constants from a value tensor literal. Now that the worklist is restricted to shape-calculation-related ops, scalarizing these dense resources won't be problematic.
  3. Rework getListFromTensor to populate a SmallVector<OpFoldResult>, allowing value tensor literals to be processed without materializing the ints.
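
For item 1, here is a minimal sketch of the rank-0 item idea, assuming torch-mlir's Torch dialect C++ API; the pattern name and exact structure are hypothetical, not taken from this PR:

```cpp
#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;

// Illustrative only: rewrite item(squeeze_dim(x, d)) to item(x) when x is
// statically known to hold exactly one element, so the removed squeeze-dim
// fold is no longer needed for propagation.
class FoldItemOfSqueezeSketch : public OpRewritePattern<Torch::AtenItemOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(Torch::AtenItemOp op,
                                PatternRewriter &rewriter) const override {
    auto squeeze = op.getSelf().getDefiningOp<Torch::AtenSqueezeDimOp>();
    if (!squeeze)
      return failure();
    auto inTy = dyn_cast<Torch::ValueTensorType>(squeeze.getSelf().getType());
    if (!inTy || !inTy.hasSizes())
      return failure();
    // item() is only valid on single-element tensors.
    int64_t numElements = 1;
    for (int64_t size : inTy.getSizes()) {
      if (size == Torch::kUnknownSize)
        return failure();
      numElements *= size;
    }
    if (numElements != 1)
      return failure();
    // Read the scalar directly from the pre-squeeze tensor.
    rewriter.replaceOpWithNewOp<Torch::AtenItemOp>(op, op.getType(),
                                                   squeeze.getSelf());
    return success();
  }
};
```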

RFC 1:

Currently, the plan is to add all prim list-of-int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? That is, does there exist a Torch op satisfying all of the following conditions:

  1. It accepts a list of constant ints, LIST, as an input.
  2. The role of LIST is not shape related. (All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc.)
  3. LIST is not already constructed entirely from scalars.

If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints.
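
For concreteness, uniform anchoring could be seeded roughly like this (a hypothetical sketch: the helper name and the choice of SetVector for the worklist are assumptions, not code from this PR):

```cpp
#include "llvm/ADT/SetVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;

// Hypothetical helper: treat every prim.ListConstruct that produces a
// !torch.list<int> as a scalarization anchor.
static void collectIntListAnchors(func::FuncOp funcOp,
                                  llvm::SetVector<Operation *> &worklist) {
  funcOp.walk([&](Torch::PrimListConstructOp list) {
    auto listTy = dyn_cast<Torch::ListType>(list.getType());
    if (listTy && isa<Torch::IntType>(listTy.getContainedType()))
      worklist.insert(list.getOperation());
  });
}
```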

RFC 2:

What should I use to report failed scalarization?

My naive idea was to walk back through the func op after applying the patterns and check whether anything in the worklist is still a tensor; if so, emit/log a warning (rough sketch below). That works, since you can look at the warnings and start debugging upwards from the last one printed, but there must be a better way to handle this without re-walking the func.func op.
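
For reference, the naive post-walk report might look roughly like this (hypothetical sketch; assumes the pass keeps the worklist of ops it intended to scalarize):

```cpp
// Includes/usings as in the anchor-collection sketch above.
// Naive reporting: after the rewrites, warn on any worklist op that
// still produces a tensor, i.e. anything that failed to scalarize.
static void reportUnscalarized(func::FuncOp funcOp,
                               const llvm::SetVector<Operation *> &worklist) {
  funcOp.walk([&](Operation *op) {
    if (!worklist.count(op))
      return;
    for (Value result : op->getResults())
      if (isa<Torch::BaseTensorType>(result.getType()))
        op->emitWarning("failed to scalarize shape computation");
  });
}
```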

zjgarvey (Collaborator, Author) commented on Oct 17, 2024

Okay, scalarizing value tensor literals does not work, since they get folded back to value tensor literals.

Going to move towards having getListFromTensor modify a SmallVector<OpFoldResult> instead of SmallVector<Value>.
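
A rough sketch of what the reworked helper could look like (the signature is inferred from the description above; the exact cases handled are assumptions, not the PR's implementation):

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;

// Hypothetical rework: elements come back as OpFoldResults, so a dense
// literal contributes IntegerAttrs while a tensor built from a scalar
// list contributes the original Values; neither path creates new ops.
static LogicalResult getListFromTensor(Value value,
                                       SmallVector<OpFoldResult> &vals) {
  // Tensor constructed from a list of scalars: reuse the scalar Values.
  if (auto tensor = value.getDefiningOp<Torch::AtenTensorOp>()) {
    auto list = tensor.getData().getDefiningOp<Torch::PrimListConstructOp>();
    if (!list)
      return failure();
    for (Value element : list.getElements())
      vals.push_back(element);
    return success();
  }
  // Dense literal: keep the elements as attributes instead of
  // materializing torch.constant.int ops that would just fold back.
  if (auto literal = value.getDefiningOp<Torch::ValueTensorLiteralOp>()) {
    auto dense = dyn_cast<DenseIntElementsAttr>(literal.getValue());
    if (!dense)
      return failure();
    Type i64Ty = IntegerType::get(value.getContext(), 64);
    for (APInt elem : dense.getValues<APInt>())
      vals.push_back(IntegerAttr::get(i64Ty, elem));
    return success();
  }
  return failure();
}
```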

A collaborator left a review comment on this diff context:

```cpp
  });

  GreedyRewriteConfig config;
  // Restrict the greedy driver to the seeded worklist ops plus any new
  // ops the patterns create, rather than rewriting everything in the
  // function.
  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
```
Ok, so this is actually the key thing to this algorithm and needs some explanation: You are relying on patterns that, given some anchor, always rewrite to produce some new ops that (if there are further ops above) introduce an op that is a new anchor that goes on the worklist. So this is still a bottom-up rewrite where each step produces a locally correct graph transformation (at the cost of each step needing to create a number of new ops that will just be thrown away on the next pattern match).

This is a bit different from what I had described: do a first traversal to identify all dominators that end with an anchor, then transform top-down, where each pattern produces an intermediate-invalid graph transformation but does not need to materialize throw-away ops at each step. In this approach, you would make one top-down pass through the IR, doing transformations and noting errors for anything that could not be converted.

Given the code that already exists for this and that the top-down way is definitely harder to write, doing it this way is fine. The cost will be more garbage/allocations that scale with the number of changes you have to make. And it is harder to do precise error handling. The up-side, though, is that if a pattern fails to fully scalarize, the graph is in a consistent but suboptimal state and can still be used.

So to answer the two questions you raised in the PR description:

  1. PrimListOfInts may be an ok anchor. In my mind, that was an anchor but only if it was consumed as an index-oriented operand of an op. It may be a distinction without a difference in the end. I'm just paranoid because I know that in these kinds of graphs, there are definitely weird bottlenecks of such things that need more discrimination.
  2. For error handling, if you don't want to do another walk, create your "Propagate*" pattern classes to take a reference to some PropagationState&. For things your patterns encounter that indicate a failure to convert that you want to report, add them to a set in that state. Then at the end of your pass, emit one warning with multiple notes indicating the ops that could not be converted. The reason you want the patterns themselves to be silent is that patterns can run multiple times during iteration, which makes precise control of your diagnostic behavior impossible. (A sketch of this approach follows below.)
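
A sketch of the suggested reporting scheme (all names hypothetical):

```cpp
#include "llvm/ADT/SetVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Shared state threaded through the patterns: failures are recorded
// silently and reported once after the greedy rewrite finishes.
struct PropagationState {
  llvm::SetVector<Operation *> failedOps;
};

// Patterns take a reference to the state and call recordFailure instead
// of emitting diagnostics, since a pattern may run many times during
// greedy iteration.
template <typename OpTy>
class ScalarizePatternBase : public OpRewritePattern<OpTy> {
public:
  ScalarizePatternBase(MLIRContext *ctx, PropagationState &state)
      : OpRewritePattern<OpTy>(ctx), state(state) {}

protected:
  LogicalResult recordFailure(Operation *op) const {
    state.failedOps.insert(op);
    return failure();
  }
  PropagationState &state;
};

// At the end of the pass: one warning, with a note per unconverted op.
static void emitPropagationReport(func::FuncOp funcOp,
                                  const PropagationState &state) {
  if (state.failedOps.empty())
    return;
  InFlightDiagnostic diag =
      funcOp.emitWarning("some shape computations were not scalarized");
  for (Operation *op : state.failedOps)
    diag.attachNote(op->getLoc()) << "could not scalarize this op";
}
```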
