diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
index a1106217e2af..168518e3d5c0 100644
--- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
+++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -530,11 +530,139 @@ class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
                                           none, none, none, none);
       return success();
     }
+
+    auto squeezeOp = op.getSelf().getDefiningOp<AtenSqueezeDimOp>();
+    if (squeezeOp && resultTy.getSizes().size() == 1) {
+      rewriter.replaceOp(op, squeezeOp.getSelf());
+      return success();
+    }
     return failure();
   }
 };
 } // namespace
 
+namespace {
+// This is a specific pattern for converting views like [?,...,?,lastDim] ->
+// [?,...,?,factor0,factor1] to unflatten, and views like
+// [?,...,?,factor0,factor1] -> [?,...,?,lastDim] to flatten, whenever it is
+// possible to infer that all but last shared dim match
+// TODO: move this to an actual canonicalizer for view after deleting the
+// conflicting decompositions for flatten/unflatten -> view.
+class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
+public:
+  using OpRewritePattern<AtenViewOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenViewOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> viewSizes;
+    if (failed(getListOperands(op.getSize(), viewSizes)))
+      return rewriter.notifyMatchFailure(
+          op, "view size must be from a list construct");
+    auto selfTy = dyn_cast<Torch::ValueTensorType>(op.getSelf().getType());
+    if (!selfTy || !selfTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "missing input type or sizes");
+    auto resultTy = dyn_cast<Torch::ValueTensorType>(op.getType());
+    if (!resultTy || !resultTy.hasSizes() ||
+        resultTy.getSizes().size() != viewSizes.size())
+      return rewriter.notifyMatchFailure(op, "missing result type or sizes");
+    int64_t inRank = selfTy.getSizes().size();
+    int64_t outRank = resultTy.getSizes().size();
+
+    SmallVector<int64_t> sizes(selfTy.getSizes());
+    int64_t endMatchingDim = -1;
+    // input sizes vs. provided view sizes comparison loop
+    for (int64_t i = 0; i < std::min(outRank, inRank); i++) {
+      int64_t providedSize;
+      bool providedStatic =
+          matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
+      // if sizes[i] is static, it must match a constant in viewSizes[i]
+      if (sizes[i] != Torch::kUnknownSize) {
+        if (!providedStatic)
+          return rewriter.notifyMatchFailure(
+              op, "unsupported: found static input dim, but unable to match "
+                  "provided view size on a constant. See position : " +
+                      std::to_string(i));
+        if (providedSize != sizes[i]) {
+          endMatchingDim = i;
+          break;
+        }
+        continue;
+      }
+      // the remaining assumes sizes[i] is dynamic
+      // if provided dim is static, we can't verify it is a flatten/unflatten
+      // unless -1
+      if (i == outRank - 1 && providedStatic && providedSize == -1) {
+        endMatchingDim = i;
+        break;
+      }
+      if (providedStatic)
+        return rewriter.notifyMatchFailure(
+            op, "unexpected static view dim corresponding to dynamic input dim "
+                "at position : " +
+                    std::to_string(i));
+      auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
+      // if we don't have a size int op on self, fail
+      if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
+        return rewriter.notifyMatchFailure(
+            op, "expected dynamic view dim to come from a corresponding "
+                "size.int op. See position : " +
+                    std::to_string(i));
+      int64_t dim;
+      // if the dim of the size int op doesn't match, fail
+      if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
+          dim != i)
+        return rewriter.notifyMatchFailure(
+            op,
+            "size int op dim cannot be matched to current dim at position : " +
+                std::to_string(i));
+      // passing the previous checks means viewSizes[i] = aten.size.int(self,
+      // i), so continue
+    }
+    // if all dims match and the ranks are equal, fold
+    if (endMatchingDim == -1 && inRank == outRank) {
+      rewriter.replaceOp(op, op.getSelf());
+      return success();
+    }
+    if (endMatchingDim > -1 && inRank > outRank) {
+      // only support flattening last dim
+      if (endMatchingDim != outRank - 1)
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: output has more than back dim mismatching");
+      // flatten
+      Value start =
+          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
+      Value end =
+          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
+      rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
+          op, resultTy, op.getSelf(), start, end);
+      return success();
+    }
+    if (endMatchingDim > -1 && inRank < outRank) {
+      // only support unflattening last dim
+      if (endMatchingDim != inRank - 1)
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: input has more than back dim mismatching");
+      // unflatten
+      Value dim =
+          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
+      Value primList = rewriter.create<Torch::PrimListConstructOp>(
+          op.getLoc(), op.getSize().getType(),
+          ArrayRef<Value>(viewSizes.begin() + endMatchingDim, viewSizes.end()));
+      rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
+          op, resultTy, op.getSelf(), dim, primList);
+      return success();
+    }
+    // examples that might reach this:
+    // input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
+    // input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
+    // input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes)
+    return rewriter.notifyMatchFailure(
+        op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
+                ", inRank=" + std::to_string(inRank) +
+                ", outRank=" + std::to_string(outRank));
+  }
+};
+} // namespace
+
 namespace {
 template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
 public:
   using OpRewritePattern<T>::OpRewritePattern;
@@ -561,18 +689,24 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns
-        .insert<ShapeAsTensorPattern, SliceSelectIntPattern,
-                PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
-                FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
-                FoldAtenWhereSelf,
-                RemoveUnusedPattern<Torch::AtenSizeIntOp>,
-                RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
-                RemoveUnusedPattern<Torch::AtenTensorOp>,
-                RemoveUnusedPattern<Torch::ConstantBoolOp>,
-                RemoveUnusedPattern<Torch::ConstantIntOp>,
-                RemoveUnusedPattern<Torch::ConstantNoneOp>,
-                RemoveUnusedPattern<Torch::PrimListConstructOp>>(context);
+    patterns.insert<ShapeAsTensorPattern, SliceSelectIntPattern,
+                    PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
+                    FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
+                    FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
+                    RemoveUnusedPattern<Torch::AtenIntBoolOp>,
+                    RemoveUnusedPattern<Torch::AtenEqIntOp>,
+                    RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
+                    RemoveUnusedPattern<Torch::AtenFullOp>,
+                    RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
+                    RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
+                    RemoveUnusedPattern<Torch::AtenSizeIntOp>,
+                    RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
+                    RemoveUnusedPattern<Torch::AtenTensorOp>,
+                    RemoveUnusedPattern<Torch::AtenCatOp>,
+                    RemoveUnusedPattern<Torch::ConstantBoolOp>,
+                    RemoveUnusedPattern<Torch::ConstantIntOp>,
+                    RemoveUnusedPattern<Torch::ConstantNoneOp>,
+                    RemoveUnusedPattern<Torch::PrimListConstructOp>>(context);
 
     context->getLoadedDialect<mlir::arith::ArithDialect>()
         ->getCanonicalizationPatterns(patterns);
diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir
index db8d71576ca3..17f786a8215b 100644
--- a/test/Dialect/Torch/scalarize-shapes.mlir
+++ b/test/Dialect/Torch/scalarize-shapes.mlir
@@ -72,3 +72,91 @@ func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torc
   %slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
   return %slice : !torch.vtensor<[2],si32>
 }
+
+
+// -----
+
+// CHECK-LABEL: @view_as_flatten_static
+func.func @view_as_flatten_static(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,1024],f32> {
+  // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
+  // CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
+  // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
+  // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,1024],f32>
+  %int1024 = torch.constant.int 1024
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
+  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
+  %2 = torch.prim.ListConstruct %0, %1, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,16,64],f32>, !torch.list<int> -> !torch.vtensor<[?,?,1024],f32>
+  return %3 : !torch.vtensor<[?,?,1024],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @view_as_unflatten_static
+func.func @view_as_unflatten_static(%arg0: !torch.vtensor<[?,?,1024],f32>) -> !torch.vtensor<[?,?,16,64],f32> {
+  // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
+  // CHECK-DAG: %[[CST16:.*]] = torch.constant.int 16
+  // CHECK-DAG: %[[CST64:.*]] = torch.constant.int 64
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[CST16]], %[[CST64]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[FLAT:.*]] = torch.aten.unflatten.int %arg0, %[[TWO]], %[[LIST]] : !torch.vtensor<[?,?,1024],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
+  // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,16,64],f32>
+  %int16 = torch.constant.int 16
+  %int64 = torch.constant.int 64
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
+  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
+  %2 = torch.prim.ListConstruct %0, %1, %int16, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,1024],f32>, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
+  return %3 : !torch.vtensor<[?,?,16,64],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @view_as_flatten_dynamic
+func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
+  // CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
+  // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32>
+  // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?],f32>
+  %int-1 = torch.constant.int -1
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  %2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
+  return %3 : !torch.vtensor<[?,?,?],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @unsqueeze_squeeze_combo
+func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.int {
+  // CHECK: %int0 = torch.constant.int 0
+  // CHECK: %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
+  // CHECK: return %0 : !torch.int
+  %0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %1 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %2 = torch.vtensor.literal(dense<1024> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
+  %4 = torch.aten.index_select %3, %int0, %1 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  %5 = torch.aten.squeeze.dim %4, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
+  %6 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
+  %7 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  %8 = torch.aten.squeeze.dim %7, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
+  %9 = torch.aten.unsqueeze %5, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
+  %10 = torch.aten.unsqueeze %8, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
+  %11 = torch.prim.ListConstruct %9, %10, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
+  %12 = torch.aten.cat %11, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
+  %13 = torch.aten.slice.Tensor %12, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
+  return %14 : !torch.int
+}