diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp
index cb931ecf44d4..1ac6dcb04bee 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

 namespace mlir::iree_compiler::IREE::LinalgExt {
@@ -19,29 +20,50 @@ namespace mlir::iree_compiler::IREE::LinalgExt {
 #define GEN_PASS_DEF_DECOMPOSEIM2COLPASS
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

-namespace {
-
-/// Pattern to decompose the tiled im2col op.
-struct DecomposeIm2col : public OpRewritePattern<Im2colOp> {
-  using OpRewritePattern<Im2colOp>::OpRewritePattern;
+static LogicalResult decomposeIm2col(Im2colOp im2colOp, RewriterBase &rewriter,
+                                     bool unroll) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(im2colOp);
+  FailureOr<SmallVector<Value>> decomposedIm2col =
+      im2colOp.decomposeOperation(rewriter);
+  if (failed(decomposedIm2col)) {
+    return failure();
+  }
+  rewriter.replaceOp(im2colOp, decomposedIm2col.value().front());
+  if (!unroll) {
+    return success();
+  }

-  LogicalResult matchAndRewrite(Im2colOp im2colOp,
-                                PatternRewriter &rewriter) const override {
-    FailureOr<SmallVector<Value>> decomposedIm2col =
-        im2colOp.decomposeOperation(rewriter);
-    if (failed(decomposedIm2col)) {
+  // Unroll the loop nest created by the im2col op decomposition.
+  auto outerLoop = decomposedIm2col.value().front().getDefiningOp<scf::ForOp>();
+  assert(outerLoop &&
+         "expected im2col op decomposition to produce scf.for loop nest.");
+  SmallVector<scf::ForOp> loopNest({outerLoop});
+  while (auto innerLoop =
+             outerLoop.getYieldedValues()[0].getDefiningOp<scf::ForOp>()) {
+    loopNest.push_back(innerLoop);
+    outerLoop = innerLoop;
+  }
+  for (auto loop : llvm::reverse(loopNest)) {
+    std::optional<int64_t> ub = getConstantIntValue(loop.getUpperBound());
+    if (!ub.has_value() || ub.value() == 1) {
+      continue;
+    }
+    rewriter.setInsertionPoint(loop);
+    if (failed(mlir::loopUnrollByFactor(loop, ub.value()))) {
+      loop.emitOpError("failed to unroll loop");
       return failure();
     }
-    rewriter.replaceOp(im2colOp, decomposedIm2col.value().front());
-    return success();
   }
-};
-
-} // namespace
+  return success();
+}

 namespace {
 struct DecomposeIm2colPass final
     : impl::DecomposeIm2colPassBase<DecomposeIm2colPass> {
+  using impl::DecomposeIm2colPassBase<
+      DecomposeIm2colPass>::DecomposeIm2colPassBase;
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<
         affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
@@ -54,8 +76,18 @@ struct DecomposeIm2colPass final

 void DecomposeIm2colPass::runOnOperation() {
   MLIRContext *context = &getContext();
+  auto funcOp = getOperation();
+
+  SmallVector<Im2colOp> candidates;
+  funcOp->walk([&](Im2colOp op) { candidates.push_back(op); });
+  IRRewriter rewriter(context);
+  for (auto im2colOp : candidates) {
+    if (failed(decomposeIm2col(im2colOp, rewriter, unroll))) {
+      return signalPassFailure();
+    }
+  }
+
   RewritePatternSet patterns(context);
-  patterns.add<DecomposeIm2col>(context);
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
   if (failed(
           applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index 57b484fd1e5a..ca06c4ffbf80 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -54,6 +54,10 @@ def DecomposeIm2colPass :
     InterfacePass<"iree-linalg-ext-decompose-im2col",
                   "mlir::FunctionOpInterface"> {
   let summary = "Decomposes im2col ops into insert and extract slice ops";
+  let options = [
+    Option<"unroll", "unroll", "bool", /*default=*/"true",
+           "Unroll the resulting loop nest after decomposition.">,
+  ];
 }

 def DecomposeWinogradTransformPass :
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
index ab627b6455cb..796ece5c3cd5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
@@ -1,4 +1,5 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-im2col))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-im2col{unroll=false}))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-im2col{unroll=true}))" --split-input-file %s | FileCheck %s --check-prefix=CHECK-UNROLL

 #map = affine_map<(d0) -> (d0 * 4)>
 module {
@@ -71,3 +72,65 @@ module {
 // CHECK: scf.yield %[[mLOOP]] : tensor<2x?x?xf32>
 // CHECK: }
 // CHECK: return %[[bLOOP]] : tensor<2x?x?xf32>
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 4)>
+module {
+  func.func @im2col_unrolled(%arg0: tensor<2x34x34x640xf32>, %m_off: index, %k: index) -> tensor<2x2x4xf32> {
+    %0 = tensor.empty() : tensor<2x2x4xf32>
+    %k_off = affine.apply #map(%k)
+    %7 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] m_offset = [%m_off] k_offset = [%k_off] batch_pos = [0] m_pos = [1, 2] k_pos = [3] ins(%arg0 : tensor<2x34x34x640xf32>) outs(%0 : tensor<2x2x4xf32>) -> tensor<2x2x4xf32>
+    return %7 : tensor<2x2x4xf32>
+  }
+}
+// CHECK-UNROLL-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 160) * 640)>
+// CHECK-UNROLL-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) floordiv 32 + s1 floordiv 480)>
+// CHECK-UNROLL-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) mod 32 + s1 floordiv 160 - (s1 floordiv 480) * 3)>
+// CHECK-UNROLL: func.func @im2col_unrolled(%[[ARG0:.+]]: tensor<2x34x34x640xf32>
+// CHECK-UNROLL-SAME: %[[mOFF:.+]]: index, %[[K:.+]]: index)
+// CHECK-UNROLL-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-UNROLL-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-UNROLL: %[[OUT_TILE:.+]] = tensor.empty() : tensor<2x2x4xf32>
+
+// First iteration
+//
+// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
+// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C0]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C0]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<4xf32>
+// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT_TILE]][%[[C0]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<4xf32>
+// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<4xf32>) outs(%[[OUT_SLICE]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK-UNROLL: %[[INSERT0:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT_TILE]][%[[C0]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<4xf32> into tensor<2x2x4xf32>
+
+// Second iteration
+//
+// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
+// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C1]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C1]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C0]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<4xf32>
+// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT0]][%[[C0]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<4xf32>
+// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<4xf32>) outs(%[[OUT_SLICE]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK-UNROLL: %[[INSERT1:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT0]][%[[C0]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<4xf32> into tensor<2x2x4xf32>
+
+// Third iteration
+//
+// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
+// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C0]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C1]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<4xf32>
+// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT1]][%[[C1]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<4xf32>
+// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<4xf32>) outs(%[[OUT_SLICE]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK-UNROLL: %[[INSERT2:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT1]][%[[C1]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<4xf32> into tensor<2x2x4xf32>
+
+// Fourth iteration
+//
+// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
+// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C1]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C1]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C1]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<4xf32>
+// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT2]][%[[C1]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<4xf32>
+// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<4xf32>) outs(%[[OUT_SLICE]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK-UNROLL: %[[INSERT3:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT2]][%[[C1]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<4xf32> into tensor<2x2x4xf32>
+
+// CHECK-UNROLL: return %[[INSERT3]] : tensor<2x2x4xf32>