From 2665ed343b19713ba5c1c555b2366a93de8b9d2b Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Thu, 10 Oct 2024 10:16:45 -0500 Subject: [PATCH] adds a few common patterns to scalarize shapes pass (#3779) This patch adds two things: 1. support for folding scalar patterns like [1]---squeeze--->[] ---unsqueeze--->[1]. 2. a canonicalizer for aten.view that applies when we can statically or dynamically (through the scalarized view shapes) infer that it is a flatten or unflatten op in the last dim. I'm not sure if this is the right place to be adding such a view canonicalizer. Catastrophically, there is a decomposition from flatten and unflatten into aten.view. Until this gets deleted (and it definitely should be deleted), I felt like this would be an appropriate temporary home. We run scalarize shapes after lowering to the backend contract (i.e., decomposing), and scalarize shapes is required to be able to infer dynamic dims coming from size int ops. --- .../Torch/Transforms/ScalarizeShapes.cpp | 158 ++++++++++++++++-- test/Dialect/Torch/scalarize-shapes.mlir | 88 ++++++++++ 2 files changed, 234 insertions(+), 12 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index a1106217e2af..168518e3d5c0 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -530,11 +530,139 @@ class FoldAtenUnsqueezePattern : public OpRewritePattern { none, none, none, none); return success(); } + auto squeezeOp = op.getSelf().getDefiningOp(); + if (squeezeOp && resultTy.getSizes().size() == 1) { + rewriter.replaceOp(op, squeezeOp.getSelf()); + return success(); + } return failure(); } }; } // namespace + +namespace { +// This is a specific pattern for converting views like [?,...,?,lastDim] -> +// [?,...,?,factor0,factor1] to unflatten, and views like +// [?,...,?,factor0,factor1] -> [?,...,?,lastDim] to flatten, whenever it is +// possible to infer that all but last shared dim match +// TODO: move this to an actual canonicalizer for view after deleting the +// conflicting decompositions for flatten/unflatten -> view. +class CanonicalizeAtenViewPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenViewOp op, + PatternRewriter &rewriter) const override { + SmallVector viewSizes; + if (failed(getListOperands(op.getSize(), viewSizes))) + return rewriter.notifyMatchFailure( + op, "view size must be from a list construct"); + auto selfTy = dyn_cast(op.getSelf().getType()); + if (!selfTy || !selfTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "missing input type or sizes"); + auto resultTy = dyn_cast(op.getType()); + if (!resultTy || !resultTy.hasSizes() || + resultTy.getSizes().size() != viewSizes.size()) + return rewriter.notifyMatchFailure(op, "missing result type or sizes"); + int64_t inRank = selfTy.getSizes().size(); + int64_t outRank = resultTy.getSizes().size(); + + SmallVector sizes(selfTy.getSizes()); + int64_t endMatchingDim = -1; + // input sizes vs. provided view sizes comparison loop + for (int64_t i = 0; i < std::min(outRank, inRank); i++) { + int64_t providedSize; + bool providedStatic = + matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize)); + // if sizes[i] is static, it must match a constant in viewSizes[i] + if (sizes[i] != Torch::kUnknownSize) { + if (!providedStatic) + return rewriter.notifyMatchFailure( + op, "unsupported: found static input dim, but unable to match " + "provided view size on a constant. See position : " + + std::to_string(i)); + if (providedSize != sizes[i]) { + endMatchingDim = i; + break; + } + continue; + } + // the remaining assumes sizes[i] is dynamic + // if provided dim is static, we can't verify it is a flatten/unflatten + // unless -1 + if (i == outRank - 1 && providedStatic && providedSize == -1) { + endMatchingDim = i; + break; + } + if (providedStatic) + return rewriter.notifyMatchFailure( + op, "unexpected static view dim corresponding to dynamic input dim " + "at position : " + + std::to_string(i)); + auto sizeIntOp = viewSizes[i].getDefiningOp(); + // if we don't have a size int op on self, fail + if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf()) + return rewriter.notifyMatchFailure( + op, "expected dynamic view dim to come from a corresponding " + "size.int op. See position : " + + std::to_string(i)); + int64_t dim; + // if the dim of the size int op doesn't match, fail + if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) || + dim != i) + return rewriter.notifyMatchFailure( + op, + "size int op dim cannot be matched to current dim at position : " + + std::to_string(i)); + // passing the previous checks means viewSizes[i] = aten.size.int(self, + // i), so continue + } + // if all dims match and the ranks are equal, fold + if (endMatchingDim == -1 && inRank == outRank) { + rewriter.replaceOp(op, op.getSelf()); + return success(); + } + if (endMatchingDim > -1 && inRank > outRank) { + // only support flattening last dim + if (endMatchingDim != outRank - 1) + return rewriter.notifyMatchFailure( + op, "unimplemented: output has more than back dim mismatching"); + // flatten + Value start = + rewriter.create(op.getLoc(), endMatchingDim); + Value end = + rewriter.create(op.getLoc(), inRank - 1); + rewriter.replaceOpWithNewOp( + op, resultTy, op.getSelf(), start, end); + return success(); + } + if (endMatchingDim > -1 && inRank < outRank) { + // only support unflattening last dim + if (endMatchingDim != inRank - 1) + return rewriter.notifyMatchFailure( + op, "unimplemented: input has more than back dim mismatching"); + // unflatten + Value dim = + rewriter.create(op.getLoc(), endMatchingDim); + Value primList = rewriter.create( + op.getLoc(), op.getSize().getType(), + ArrayRef(viewSizes.begin() + endMatchingDim, viewSizes.end())); + rewriter.replaceOpWithNewOp( + op, resultTy, op.getSelf(), dim, primList); + return success(); + } + // examples that might reach this: + // input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants) + // input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes) + // input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes) + return rewriter.notifyMatchFailure( + op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) + + ", inRank=" + std::to_string(inRank) + + ", outRank=" + std::to_string(outRank)); + } +}; +} // namespace + namespace { template class RemoveUnusedPattern : public OpRewritePattern { public: @@ -561,18 +689,24 @@ class ScalarizeShapesPass : public ScalarizeShapesBase { void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - patterns - .insert, - RemoveUnusedPattern, - RemoveUnusedPattern, - RemoveUnusedPattern, - RemoveUnusedPattern, - RemoveUnusedPattern, - RemoveUnusedPattern>(context); + patterns.insert, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern>(context); context->getLoadedDialect() ->getCanonicalizationPatterns(patterns); diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir index db8d71576ca3..17f786a8215b 100644 --- a/test/Dialect/Torch/scalarize-shapes.mlir +++ b/test/Dialect/Torch/scalarize-shapes.mlir @@ -72,3 +72,91 @@ func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torc %slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32> return %slice : !torch.vtensor<[2],si32> } + + +// ----- + +// CHECK-LABEL: @view_as_flatten_static +func.func @view_as_flatten_static(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,1024],f32> { + // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32> + // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,1024],f32> + %int1024 = torch.constant.int 1024 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int + %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int + %2 = torch.prim.ListConstruct %0, %1, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,16,64],f32>, !torch.list -> !torch.vtensor<[?,?,1024],f32> + return %3 : !torch.vtensor<[?,?,1024],f32> +} + + +// ----- + +// CHECK-LABEL: @view_as_unflatten_static +func.func @view_as_unflatten_static(%arg0: !torch.vtensor<[?,?,1024],f32>) -> !torch.vtensor<[?,?,16,64],f32> { + // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[CST16:.*]] = torch.constant.int 16 + // CHECK-DAG: %[[CST64:.*]] = torch.constant.int 64 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[CST16]], %[[CST64]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FLAT:.*]] = torch.aten.unflatten.int %arg0, %[[TWO]], %[[LIST]] : !torch.vtensor<[?,?,1024],f32>, !torch.int, !torch.list -> !torch.vtensor<[?,?,16,64],f32> + // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,16,64],f32> + %int16 = torch.constant.int 16 + %int64 = torch.constant.int 64 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int + %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int + %2 = torch.prim.ListConstruct %0, %1, %int16, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,1024],f32>, !torch.list -> !torch.vtensor<[?,?,16,64],f32> + return %3 : !torch.vtensor<[?,?,16,64],f32> +} + + +// ----- + +// CHECK-LABEL: @view_as_flatten_dynamic +func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32> + // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?],f32> + %int-1 = torch.constant.int -1 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + %2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?],f32> + return %3 : !torch.vtensor<[?,?,?],f32> +} + + +// ----- + +// CHECK-LABEL: @unsqueeze_squeeze_combo +func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.int { + // CHECK: %int0 = torch.constant.int 0 + // CHECK: %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int + // CHECK: return %0 : !torch.int + %0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %1 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<1024> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64> + %4 = torch.aten.index_select %3, %int0, %1 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %5 = torch.aten.squeeze.dim %4, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64> + %6 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64> + %7 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %8 = torch.aten.squeeze.dim %7, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64> + %9 = torch.aten.unsqueeze %5, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %10 = torch.aten.unsqueeze %8, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %11 = torch.prim.ListConstruct %9, %10, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list + %12 = torch.aten.cat %11, %int0 : !torch.list, !torch.int -> !torch.vtensor<[3],si64> + %13 = torch.aten.slice.Tensor %12, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int + return %14 : !torch.int +}