diff --git a/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir b/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir new file mode 100644 index 000000000..1b910e4a3 --- /dev/null +++ b/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir @@ -0,0 +1,67 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + func.func private @printMemrefF32(memref<*xf32>) + + // Definition for the batch matrix multiplication function + func.func @buddy_batchmatmul_f32(%A: memref, %B: memref, %C: memref) { + linalg.batch_matmul + ins(%A, %B: memref, memref) + outs(%C: memref) + return + } + + func.func @main(){ + // Set up dims. + %cBatch = arith.constant 10:index + %cM = arith.constant 2 : index + %cN = arith.constant 5 : index + %cK = arith.constant 4 : index + + // Set Init Value. + %cf1 = arith.constant 1.0 : f32 + %cf2 = arith.constant 2.0 : f32 + %c0 = arith.constant 0.0 : f32 + + %A = memref.alloc(%cBatch,%cM, %cK) : memref + %B = memref.alloc(%cBatch,%cK, %cN) : memref + %C = memref.alloc(%cBatch,%cM, %cN) : memref + + linalg.fill + ins(%cf1 : f32) + outs(%A:memref) + + linalg.fill + ins(%cf2 : f32) + outs(%B:memref) + + linalg.fill + ins(%c0 : f32) + outs(%C:memref) + + call @buddy_batchmatmul_f32(%A, %B, %C) : (memref, memref, memref) -> () + + // Print output. + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5] + // CHECK-SAME: ] + %print_C = memref.cast %C : memref to memref<*xf32> + call @printMemrefF32(%print_C) : (memref<*xf32>) -> () + + memref.dealloc %C : memref + memref.dealloc %B : memref + memref.dealloc %A : memref + return + } +} diff --git a/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir new file mode 100644 index 000000000..2c8cc171e --- /dev/null +++ b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir @@ -0,0 +1,96 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + func.func private @printMemrefF32(memref<*xf32>) + func.func @alloc_2d_filled_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref + scf.for %arg5 = %c0 to %arg0 step %c1 { + scf.for %arg6 = %c0 to %arg1 step %c1 { + scf.for %arg7 = %c0 to %arg2 step %c1 { + scf.for %arg8 = %c0 to %arg3 step %c1 { + %iarg8=arith.index_cast %arg8 : index to i32 + %loopf= arith.sitofp %iarg8 : i32 to f32 + memref.store %loopf, %0[%arg5, %arg6, %arg7, %arg8] : memref + } + } + } + } + 
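+    // Every element has been initialized to the value of its innermost
+    // (last-dimension) index.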
return %0 : memref + } + func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.conv_2d_nhwc_fhwc ins(%arg0, %arg1 : memref, memref) outs(%arg2 : memref) + return + } + func.func @main() { + // Intput(image, filter) and output value. + %cst = arith.constant 0.500000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + + %current_image_n = arith.constant 2 : index + %current_image_c = arith.constant 18 : index + %current_image_h = arith.constant 8 : index + %current_image_w = arith.constant 8 : index + + %current_filter_f = arith.constant 2 : index + %current_filter_c = arith.constant 18 : index + %current_filter_h = arith.constant 4 : index + %current_filter_w = arith.constant 4 : index + + %current_output_n = arith.constant 2 : index + %current_output_c = arith.constant 2 : index + %current_output_h = arith.constant 5 : index + %current_output_w = arith.constant 5 : index + + // Image. + %image = call @alloc_2d_filled_f32(%current_image_n,%current_image_h, %current_image_w, %current_image_c, %cst) : (index, index, index, index, f32) -> memref + // Filter. + %filter = call @alloc_2d_filled_f32(%current_filter_f, %current_filter_h, %current_filter_w,%current_filter_c, %cst) : (index, index, index, index, f32) -> memref + // Output. + %output = call @alloc_2d_filled_f32(%current_output_n, %current_output_h, %current_output_w,%current_output_c, %cst_0) : (index, index, index, index, f32) -> memref + + call @conv_2d_nhwc_fhwc(%image, %filter, %output) : (memref, memref, memref) -> () + + %3 = memref.cast %output : memref to memref<*xf32> + + // Print output. + // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [2, 2, 4, 4] strides = [32, 16, 4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [ + // CHECK-SAME: [ + // CHECK-COUNT-3: [32, 32, 32, 32], + // CHECK-NEXT: [32, 32, 32, 32] + // CHECK-SAME: ], + // CHECK-NEXT: [ + // CHECK-COUNT-3: [32, 32, 32, 32], + // CHECK-NEXT: [32, 32, 32, 32] + // CHECK-SAME: ] + // CHECK-SAME: ], + // CHECK-NEXT: [ + // CHECK-SAME: [ + // CHECK-COUNT-3: [32, 32, 32, 32], + // CHECK-NEXT: [32, 32, 32, 32] + // CHECK-SAME: ], + // CHECK-NEXT: [ + // CHECK-COUNT-3: [32, 32, 32, 32], + // CHECK-NEXT: [32, 32, 32, 32] + // CHECK-SAME: ] + // CHECK-SAME: ] + // CHECK-SAME: ] + call @printMemrefF32(%3) : (memref<*xf32>) -> () + + memref.dealloc %output : memref + memref.dealloc %image : memref + memref.dealloc %filter : memref + return + } +} + diff --git a/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir b/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir new file mode 100644 index 000000000..510835a27 --- /dev/null +++ b/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir @@ -0,0 +1,71 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + func.func private @printMemrefF32(memref<*xf32>) + + func.func @depthwise_conv_2d_nhwc_hwc(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.depthwise_conv_2d_nhwc_hwc + {dilations = dense<[1,1]> : tensor<2xi64>, strides = dense<[1,1]> : tensor<2xi64>} + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) + 
return + } + + func.func @main() { + // Constants for input image, filter, and output sizes. + %cst = arith.constant 0.500000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %cf1 = arith.constant 1.0 : f32 + + %image_n = arith.constant 2 : index + %image_h = arith.constant 8 : index + %image_w = arith.constant 8 : index + %image_c = arith.constant 18 : index + + %filter_h = arith.constant 4 : index + %filter_w = arith.constant 4 : index + %filter_c = arith.constant 18 : index + + %output_n = arith.constant 2 : index + %output_h = arith.constant 5 : index + %output_w = arith.constant 5 : index + %output_c = arith.constant 18 : index + + %image = memref.alloc(%image_n,%image_h,%image_w,%image_c) : memref + %filter = memref.alloc(%filter_h,%filter_w,%filter_c) : memref + %output = memref.alloc(%output_n,%output_h,%output_w,%output_c) : memref + + // Allocate and fill image, filter, and output. + linalg.fill + ins(%cf1 : f32) + outs(%image:memref) + + linalg.fill + ins(%cf1 : f32) + outs(%filter:memref) + linalg.fill + ins(%cf1 : f32) + outs(%output:memref) + + // Call depthwise convolution. + call @depthwise_conv_2d_nhwc_hwc(%image, %filter, %output) : (memref, memref, memref) -> () + + %output_cast = memref.cast %output : memref to memref<*xf32> + + // Print the output. + call @printMemrefF32(%output_cast) : (memref<*xf32>) -> () + + // Deallocate memory. + memref.dealloc %output : memref + memref.dealloc %image : memref + memref.dealloc %filter : memref + return + } +} diff --git a/examples/MLIRLinalg/makefile b/examples/MLIRLinalg/makefile index ffd6888cc..e25702201 100644 --- a/examples/MLIRLinalg/makefile +++ b/examples/MLIRLinalg/makefile @@ -60,6 +60,45 @@ linalg-conv2d-tiling-run: -convert-func-to-llvm -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} +linalg-conv2d_nhwc_fhwc-optimize-lower: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir \ + -conv-nhwc-fhwc-optimize="vec-size=16" \ + -o ./log.mlir + +linalg-conv2d_nhwc_fhwc-tile-optimize-lower: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir \ + -conv-nhwc-fhwc-tile-optimize="vec-size=16 tiling-height=2 tiling-width=3" \ + -o ./log.mlir + +linalg-conv2d_nhwc_fhwc-optimize-run: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir ${MLIR_OPT_OPTIONS} \ + -conv-nhwc-fhwc-optimize="vec-size=16" \ + -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-conv2d_nhwc_fhwc-tile-optimize-run: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir ${MLIR_OPT_OPTIONS} \ + -conv-nhwc-fhwc-tile-optimize="vec-size=16 tiling-height=2 tiling-width=3" \ + -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-depthwise_conv_2d_nhwc_hwc-optimize-lower: + @${BUDDY_OPT} linalg-depthwise_conv_2d_nhwc_hwc.mlir \ + -depthwise-conv-nhwc-hwc-optimize="vec-size=16" \ + -o ./log.mlir + +linalg-depthwise_conv_2d_nhwc_hwc-optimize-run: + @${BUDDY_OPT} linalg-depthwise_conv_2d_nhwc_hwc.mlir \ + -depthwise-conv-nhwc-hwc-optimize="vec-size=16" \ + 
-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + linalg-generic-lower: @${MLIR_OPT} ./linalg-generic.mlir \ -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ @@ -177,6 +216,16 @@ linalg-batch-matmul-optimize-lower: -batchmatmul-optimize="vector-size=64" \ -o ./log.mlir +linalg-batch-matmul-tile-optimize-lower: + @${BUDDY_OPT} linalg-batch-matmul-dync.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-tile-optimize="vec-size=64 kernel-m=4 kernel-n=2" \ + -o ./log.mlir + +linalg-batch-matmul-scf-optimize-lower: + @${BUDDY_OPT} linalg-batch-matmul-dync.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-scf-optimize="vector-size=64" \ + -o ./log.mlir + linalg-batch-matmul-optimize-translate: @${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \ -batchmatmul-optimize="vector-size=64" \ diff --git a/midend/lib/Conversion/CMakeLists.txt b/midend/lib/Conversion/CMakeLists.txt index 99254e410..cfe12a8d6 100644 --- a/midend/lib/Conversion/CMakeLists.txt +++ b/midend/lib/Conversion/CMakeLists.txt @@ -14,3 +14,4 @@ add_subdirectory(LowerLinalgToGemmini) add_subdirectory(SchedulingOnDevices) add_subdirectory(LowerSche) add_subdirectory(FuncBufferize) +add_subdirectory(DepthwiseConvOptimization) diff --git a/midend/lib/Conversion/ConvOptimization/CMakeLists.txt b/midend/lib/Conversion/ConvOptimization/CMakeLists.txt index fc88a92ef..336c95a30 100644 --- a/midend/lib/Conversion/ConvOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/ConvOptimization/CMakeLists.txt @@ -1,3 +1,5 @@ add_mlir_library(ConvOptimization ConvOptimize.cpp + ConvNhwcFhwcOptimize.cpp + ConvNhwcFhwcOptimizeTile.cpp ) diff --git a/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimize.cpp b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimize.cpp new file mode 100644 index 000000000..e4bc67e36 --- /dev/null +++ b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimize.cpp @@ -0,0 +1,276 @@ +//====- ConvNhwcFhwcOptimize.cpp----------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Conv2DNhwcFhwcOp optimize. 
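+// The pattern keeps the output loops (N, OH, OW, OC) as an scf.forall and
+// vectorizes the reduction over the input-channel dimension: every step loads
+// vec-size input and filter elements, accumulates them with FMA (mul + add for
+// integer types), and a masked vector reduction keeps the final partial step
+// correct when IC is not a multiple of vec-size.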
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class ConvNhwcFhwcOptimizePattern : public ConversionPattern { +public: + explicit ConvNhwcFhwcOptimizePattern(MLIRContext *context, + int64_t vecSizeParam) + : ConversionPattern(linalg::Conv2DNhwcFhwcOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto convOp = dyn_cast_or_null(op); + auto loc = op->getLoc(); + + // Some constant we need. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const Value vecSizeValue = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + + Value input = op->getOperand(0); + Value filter = op->getOperand(1); + Value output = op->getOperand(2); + + int strHeight, strWidth, dilHeight, dilWidth; + + // Strides. + if (!convOp.getStrides()) { + strHeight = 1; + strWidth = 1; + } else { + strHeight = convOp.getStrides().getValues()[0]; + strWidth = convOp.getStrides().getValues() + [convOp.getStrides().getValues().size() - 1]; + } + + // Dilations. + if (!convOp.getDilations()) { + dilHeight = 1; + dilWidth = 1; + } else { + dilHeight = convOp.getDilations().getValues()[0]; + dilWidth = convOp.getDilations().getValues() + [convOp.getDilations().getValues().size() - 1]; + } + + ShapedType inputTy = input.getType().cast(); + Type elemTy = inputTy.getElementType(); + VectorType vecTy = VectorType::get(vecSize, elemTy); + + const Value zeroElementType = + rewriter.create(loc, rewriter.getZeroAttr(elemTy)); + + // Dims + Value N = rewriter.create(loc, output, 0); // N + Value OH = rewriter.create(loc, output, 1); // OH + Value OW = rewriter.create(loc, output, 2); // OW + Value OC = rewriter.create(loc, output, 3); // OC + Value IC = rewriter.create(loc, input, 3); // IC + Value FH = rewriter.create(loc, filter, 1); // FH + Value FW = rewriter.create(loc, filter, 2); // FW + + // clang format off + // Step 1: Create outer most loops. 
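+    // Loop structure (one output element per forall iteration):
+    //   scf.forall (n, oh, ow, oc)
+    //     acc = output[n, oh, ow, oc]
+    //     scf.for ic = 0 to IC step vec-size
+    //       vec = 0
+    //       scf.for fh, scf.for fw:
+    //         vec += input[n, ih, iw, ic:ic+vs] * filter[oc, fh, fw, ic:ic+vs]
+    //       acc += masked_reduce_add(vec)
+    //     output[n, oh, ow, oc] = acc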
+ // Create the scf::ForallOp operation For N,OH,OW,OC + auto outputForAllOp = rewriter.create( + loc, SmallVector({N, OH, OW, OC}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivN = loopIndices[0]; // Index for the first dimension N + Value ivOH = loopIndices[1]; // Index for the second dimension OH + Value ivOW = loopIndices[2]; // Index for the third dimension OW + Value ivOC = loopIndices[3]; // Index for the third dimension OC + + Value addRes = nestedBuilder.create( + loc, output, ValueRange{ivN, ivOH, ivOW, ivOC}); + // IC + auto forOp = nestedBuilder.create( + nestedLoc, c0, IC, vecSizeValue, ValueRange{addRes}, + [&](OpBuilder &builder, Location loc, Value ivIC, + ValueRange iargs) { + Value tVec; + if (isa(elemTy)) { + tVec = builder.create(loc, vecTy, + zeroElementType); + } else { + tVec = builder.create(loc, vecTy, + zeroElementType); + } + + Value remainLen = builder.create( + loc, + AffineMap::get(2, 1, {-d0 + s0, d1}, builder.getContext()), + ValueRange{ivIC, vecSizeValue, IC}); + Value remainMask = builder.create( + loc, VectorType::get({vecSize}, rewriter.getI1Type()), + ValueRange{remainLen}); + + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = builder.create( + loc, + AffineMap::get(2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, Value ivFW, + ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strWidth + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = ivFW; + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, ivIC}); + Value fVec = builder.create( + loc, vecTy, filter, + ValueRange{ivOC, rowFilter, columnFilter, + ivIC}); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = builder.create( + loc, iVec, fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create(loc, + ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); + auto reduceVecOp = builder.create( + loc, vector::CombiningKind::ADD, forOp.getResult(0)); + auto maskedOp = + cast(mlir::vector::maskOperation( + builder, reduceVecOp, remainMask)); + Value reduceVec = maskedOp->getResult(0); + Value addNext; + if (isa(elemTy)) { + addNext = + builder.create(loc, iargs[0], reduceVec); + } else { + addNext = + builder.create(loc, iargs[0], reduceVec); + } + builder.create(loc, ValueRange{addNext}); + }); + + nestedBuilder.create( + loc, forOp.getResult(0), output, + ValueRange{ivN, ivOH, ivOW, ivOC}); + nestedBuilder.create(nestedLoc); + }); + // clang format on + + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t vecSize; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ConvNhwcFhwcOptimizePass +//===----------------------------------------------------------------------===// + +namespace { +class ConvNhwcFhwcOptimizePass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvNhwcFhwcOptimizePass) + StringRef getArgument() const final { return 
"conv-nhwc-fhwc-optimize"; } + StringRef getDescription() const final { + return "Conv2d NHWC FHWC optimize."; + } + ConvNhwcFhwcOptimizePass() = default; + ConvNhwcFhwcOptimizePass(const ConvNhwcFhwcOptimizePass &) {} + explicit ConvNhwcFhwcOptimizePass(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vec-size", llvm::cl::desc("Vector size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. + +void ConvNhwcFhwcOptimizePass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerConvNhwcFhwcOptimizePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp new file mode 100644 index 000000000..db812aceb --- /dev/null +++ b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp @@ -0,0 +1,342 @@ +//====- ConvNhwcFhwcOptimizeTile.cpp------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Conv2DNhwcFhwcOp tile optimize. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class ConvNhwcFhwcTileOptimizePattern : public ConversionPattern { +public: + explicit ConvNhwcFhwcTileOptimizePattern(MLIRContext *context, + int64_t vecSizeParam, + int64_t tilingOHParam, + int64_t tilingOWParam, + int64_t tilingOCParam) + : ConversionPattern(linalg::Conv2DNhwcFhwcOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + tilingOH = tilingOHParam; + tilingOW = tilingOWParam; + tilingOC = tilingOCParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto convOp = dyn_cast_or_null(op); + auto loc = op->getLoc(); + + // Some constant we need. 
+ const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const Value vecSizeValue = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + + Value input = op->getOperand(0); + Value filter = op->getOperand(1); + Value output = op->getOperand(2); + + int strHeight, strWidth, dilHeight, dilWidth; + + // Strides. + if (!convOp.getStrides()) { + strHeight = 1; + strWidth = 1; + } else { + strHeight = convOp.getStrides().getValues()[0]; + strWidth = convOp.getStrides().getValues() + [convOp.getStrides().getValues().size() - 1]; + } + + // Dilations. + if (!convOp.getDilations()) { + dilHeight = 1; + dilWidth = 1; + } else { + dilHeight = convOp.getDilations().getValues()[0]; + dilWidth = convOp.getDilations().getValues() + [convOp.getDilations().getValues().size() - 1]; + } + + ShapedType inputTy = input.getType().cast(); + Type elemTy = inputTy.getElementType(); + VectorType vecTy = VectorType::get(vecSize, elemTy); + + const Value zeroElementType = + rewriter.create(loc, rewriter.getZeroAttr(elemTy)); + + // Dims + Value N = rewriter.create(loc, output, 0); // N + Value OH = rewriter.create(loc, output, 1); // OH + Value OW = rewriter.create(loc, output, 2); // OW + Value OC = rewriter.create(loc, output, 3); // OC + Value IC = rewriter.create(loc, input, 3); // IC + Value FH = rewriter.create(loc, filter, 1); // FH + Value FW = rewriter.create(loc, filter, 2); // FW + + auto tilingUpperBound = + AffineMap::get(2, 1, {d0 + d1, s0}, rewriter.getContext()); + + Value stepOH = rewriter.create( + loc, AffineMap::get(1, 0, d0.ceilDiv(tilingOH)), OH); + Value stepOW = rewriter.create( + loc, AffineMap::get(1, 0, d0.ceilDiv(tilingOW)), OW); + Value stepOC = rewriter.create( + loc, AffineMap::get(1, 0, d0.ceilDiv(tilingOC)), OC); + + // clang format off + // Step 1: Create outer most loops. 
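+    // Two-level structure: the outer scf.forall visits tile origins with steps
+    // (1, ceil(OH / tiling-height), ceil(OW / tiling-width), ceil(OC / tiling-channel)),
+    // an inner scf.forall covers the points of each tile (clamped to the output
+    // bounds with affine.min), and the per-point computation is the same masked
+    // input-channel vectorization used by conv-nhwc-fhwc-optimize.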
+ // Create the scf::ForallOp operation For N,OH,OW,OC + rewriter.create( + loc, SmallVector{c0, c0, c0, c0}, + SmallVector({N, OH, OW, OC}), + SmallVector({c1, stepOH, stepOW, stepOC}), + ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivN = loopIndices[0]; // Index for the first dimension N + + Value ubOH = nestedBuilder.create( + loc, tilingUpperBound, + ValueRange{loopIndices[1], stepOH, + OH}); // ub for the second dimension OH + Value ubOW = nestedBuilder.create( + loc, tilingUpperBound, + ValueRange{loopIndices[2], stepOW, + OW}); // ub for the second dimension OW + Value ubOC = nestedBuilder.create( + loc, tilingUpperBound, + ValueRange{loopIndices[3], stepOC, + OC}); // ub for the second dimension OC + + rewriter.create( + loc, + SmallVector{loopIndices[1], loopIndices[2], + loopIndices[3]}, + SmallVector({ubOH, ubOW, ubOC}), + SmallVector({c1, c1, c1}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivOH = loopIndices[0]; // Index for the first dimension OH + Value ivOW = loopIndices[1]; // Index for the first dimension OW + Value ivOC = loopIndices[2]; // Index for the first dimension OC + + Value addRes = nestedBuilder.create( + loc, output, ValueRange{ivN, ivOH, ivOW, ivOC}); + // IC + auto forOp = nestedBuilder.create( + nestedLoc, c0, IC, vecSizeValue, ValueRange{addRes}, + [&](OpBuilder &builder, Location loc, Value ivIC, + ValueRange iargs) { + Value tVec; + if (isa(elemTy)) { + tVec = builder.create( + loc, vecTy, zeroElementType); + } else { + tVec = builder.create(loc, vecTy, + zeroElementType); + } + + Value remainLen = builder.create( + loc, + AffineMap::get(2, 1, {-d0 + s0, d1}, + builder.getContext()), + ValueRange{ivIC, vecSizeValue, IC}); + Value remainMask = builder.create( + loc, VectorType::get({vecSize}, rewriter.getI1Type()), + ValueRange{remainLen}); + + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, + Value ivFW, ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get(2, 0, + d0 * strWidth + + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = + builder.create( + loc, AffineMap::get(1, 0, d0), ivFW); + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, + ivIC}); + Value fVec = builder.create( + loc, vecTy, filter, + ValueRange{ivOC, rowFilter, columnFilter, + ivIC}); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = + builder.create(loc, iVec, + fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create( + loc, ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); + auto reduceVecOp = builder.create( + loc, vector::CombiningKind::ADD, forOp.getResult(0)); + auto maskedOp = + cast(mlir::vector::maskOperation( + builder, reduceVecOp, remainMask)); + Value reduceVec = maskedOp->getResult(0); + Value addNext; + if (isa(elemTy)) { + addNext = 
builder.create(loc, iargs[0], + reduceVec); + } else { + addNext = builder.create(loc, iargs[0], + reduceVec); + } + builder.create(loc, ValueRange{addNext}); + }); + + nestedBuilder.create( + loc, forOp.getResult(0), output, + ValueRange{ivN, ivOH, ivOW, ivOC}); + nestedBuilder.create(nestedLoc); + }); + nestedBuilder.create(nestedLoc); + }); + // clang format on + + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t vecSize; + int64_t tilingOH; + int64_t tilingOW; + int64_t tilingOC; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ConvNhwcFhwcTileOptimizePass +//===----------------------------------------------------------------------===// + +namespace { +class ConvNhwcFhwcTileOptimizePass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvNhwcFhwcTileOptimizePass) + StringRef getArgument() const final { return "conv-nhwc-fhwc-tile-optimize"; } + StringRef getDescription() const final { + return "Conv2d NHWC FHWC optimize with Tile."; + } + ConvNhwcFhwcTileOptimizePass() = default; + ConvNhwcFhwcTileOptimizePass(const ConvNhwcFhwcTileOptimizePass &) {} + explicit ConvNhwcFhwcTileOptimizePass(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vec-size", llvm::cl::desc("Vector size."), + llvm::cl::init(16)}; + Option tilingOH{*this, "tiling-height", + llvm::cl::desc("tiling the output height."), + llvm::cl::init(1)}; + Option tilingOW{*this, "tiling-width", + llvm::cl::desc("tiling the output width."), + llvm::cl::init(1)}; + Option tilingOC{*this, "tiling-channel", + llvm::cl::desc("tiling the output channel."), + llvm::cl::init(1)}; +}; +} // end anonymous namespace. 
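+// A typical invocation, mirroring the makefile target added in this patch:
+//   buddy-opt linalg-conv2d_nhwc_fhwc.mlir \
+//     -conv-nhwc-fhwc-tile-optimize="vec-size=16 tiling-height=2 tiling-width=3"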
+ +void ConvNhwcFhwcTileOptimizePass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize, tilingOH, + tilingOW, tilingOC); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerConvNhwcFhwcTileOptimizePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp b/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp index 55c876dd6..918a1388d 100644 --- a/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp +++ b/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp @@ -122,8 +122,7 @@ class GEMMPointwiseConvPattern : public ConversionPattern { namespace { class PointwiseConvToGemmPass - : public PassWrapper> { + : public PassWrapper> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PointwiseConvToGemmPass) StringRef getArgument() const final { return "pointwise-conv-to-gemm"; } @@ -144,14 +143,20 @@ class PointwiseConvToGemmPass void PointwiseConvToGemmPass::runOnOperation() { MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); ConversionTarget target(*context); - target.addLegalDialect(); + target + .addLegalDialect(); target.addLegalOp(); target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); } namespace mlir { diff --git a/midend/lib/Conversion/DepthwiseConvOptimization/CMakeLists.txt b/midend/lib/Conversion/DepthwiseConvOptimization/CMakeLists.txt new file mode 100644 index 000000000..8493e2a60 --- /dev/null +++ b/midend/lib/Conversion/DepthwiseConvOptimization/CMakeLists.txt @@ -0,0 +1,3 @@ +add_mlir_library(DepthwiseConvOptimization + DepthwiseConvNhwcHwc.cpp + ) diff --git a/midend/lib/Conversion/DepthwiseConvOptimization/DepthwiseConvNhwcHwc.cpp b/midend/lib/Conversion/DepthwiseConvOptimization/DepthwiseConvNhwcHwc.cpp new file mode 100644 index 000000000..04bf76f76 --- /dev/null +++ b/midend/lib/Conversion/DepthwiseConvOptimization/DepthwiseConvNhwcHwc.cpp @@ -0,0 +1,331 @@ +//====- DepthwiseConvNhwcHwc.cpp +//--------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the DepthwiseConvNhwcHwc optimize. 
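+// In a depthwise convolution the input, filter, and output share the channel
+// dimension, so the pattern vectorizes along it: an scf.forall covers
+// (N, OH, OW), whole vec-size channel groups are processed with vector loads
+// and FMAs, and the remaining OC % vec-size channels are handled by a masked
+// scf.if tail.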
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class DepthwiseConv2DNhwcHwcOptimizePattern : public ConversionPattern { +public: + explicit DepthwiseConv2DNhwcHwcOptimizePattern(MLIRContext *context, + int64_t vecSizeParam) + : ConversionPattern(linalg::DepthwiseConv2DNhwcHwcOp::getOperationName(), + 1, context) { + vecSize = vecSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto convOp = dyn_cast_or_null(op); + auto loc = op->getLoc(); + + // Some constant we need. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const Value vecSizeValue = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + + Value input = op->getOperand(0); + Value filter = op->getOperand(1); + Value output = op->getOperand(2); + + int strHeight, strWidth, dilHeight, dilWidth; + + // Strides. + if (!convOp.getStrides()) { + strHeight = 1; + strWidth = 1; + } else { + strHeight = convOp.getStrides().getValues()[0]; + strWidth = convOp.getStrides().getValues() + [convOp.getStrides().getValues().size() - 1]; + } + + // Dilations. + if (!convOp.getDilations()) { + dilHeight = 1; + dilWidth = 1; + } else { + dilHeight = convOp.getDilations().getValues()[0]; + dilWidth = convOp.getDilations().getValues() + [convOp.getDilations().getValues().size() - 1]; + } + + ShapedType inputTy = input.getType().cast(); + Type elemTy = inputTy.getElementType(); + VectorType vecTy = VectorType::get(vecSize, elemTy); + + const Value zeroElementType = + rewriter.create(loc, rewriter.getZeroAttr(elemTy)); + + Value zeroElementTypeVec; + if (isa(elemTy)) { + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + } else { + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + } + // Dims + Value N = rewriter.create(loc, output, 0); // N + Value OH = rewriter.create(loc, output, 1); // OH + Value OW = rewriter.create(loc, output, 2); // OW + Value OC = rewriter.create(loc, output, 3); // OC/FC/IC + + Value applyOC = rewriter.create( + loc, AffineMap::get(1, 0, d0.floorDiv(vecSize) * vecSize), OC); + Value tailLength = rewriter.create( + loc, AffineMap::get(1, 0, d0 % vecSize), ValueRange{OC}); + Value maskVector = rewriter.create( + loc, VectorType::get({vecSize}, rewriter.getI1Type()), + ValueRange{tailLength}); + + Value FH = rewriter.create(loc, filter, 0); // FH + Value FW = rewriter.create(loc, filter, 1); // FW + + // clang format off + // Step 1: Create outer most loops. 
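+    // Loop structure:
+    //   scf.forall (n, oh, ow)
+    //     scf.for c = 0 to applyOC step vec-size            // full channel groups
+    //       acc = output[n, oh, ow, c:c+vs]
+    //       for fh, fw: acc += input[n, ih, iw, c:c+vs] * filter[fh, fw, c:c+vs]
+    //       output[n, oh, ow, c:c+vs] = acc
+    //     scf.if (OC % vec-size != 0)                       // masked channel tail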
+ // Create the scf::ForallOp operation For N,OH,OW + auto outputForAllOp = rewriter.create( + loc, SmallVector({N, OH, OW}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivN = loopIndices[0]; // Index for the first dimension N + Value ivOH = loopIndices[1]; // Index for the second dimension OH + Value ivOW = loopIndices[2]; // Index for the third dimension OW + // OC + nestedBuilder.create( + nestedLoc, c0, applyOC, vecSizeValue, ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, Value ivOC, + ValueRange iargs) { + Value tVec = builder.create( + loc, vecTy, output, ValueRange{ivN, ivOH, ivOW, ivOC}); + + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = builder.create( + loc, + AffineMap::get(2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, Value ivFW, + ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strWidth + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = + builder.create( + loc, AffineMap::get(1, 0, d0), ivFW); + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, ivOC}); + Value fVec = builder.create( + loc, vecTy, filter, + ValueRange{rowFilter, columnFilter, ivOC}); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = builder.create( + loc, iVec, fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create(loc, + ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); + builder.create( + loc, forOp.getResult(0), output, + ValueRange{ivN, ivOH, ivOW, ivOC}); + + builder.create(loc, ValueRange{std::nullopt}); + }); + + // applyOC + Value condition = nestedBuilder.create( + loc, arith::CmpIPredicate::sgt, tailLength, c0); + nestedBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value tVec = builder.create( + loc, vecTy, output, ValueRange{ivN, ivOH, ivOW, applyOC}, + maskVector, zeroElementTypeVec); + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = builder.create( + loc, + AffineMap::get(2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, Value ivFW, + ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strWidth + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = + builder.create( + loc, AffineMap::get(1, 0, d0), ivFW); + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, applyOC}, + maskVector, zeroElementTypeVec); + Value fVec = builder.create( + loc, vecTy, filter, + ValueRange{rowFilter, columnFilter, applyOC}, + maskVector, zeroElementTypeVec); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = builder.create( + loc, iVec, fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + 
loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create(loc, + ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); + builder.create( + loc, output, ValueRange{ivN, ivOH, ivOW, applyOC}, + maskVector, forOp.getResult(0)); + builder.create(loc, ValueRange{std::nullopt}); + }); + + nestedBuilder.create(nestedLoc); + }); + // clang format on + + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t vecSize; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// DepthwiseConv2DNhwcHwcOptimizePass +//===----------------------------------------------------------------------===// + +namespace { +class DepthwiseConv2DNhwcHwcOptimizePass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + DepthwiseConv2DNhwcHwcOptimizePass) + StringRef getArgument() const final { + return "depthwise-conv-nhwc-hwc-optimize"; + } + StringRef getDescription() const final { + return "Depthwise Conv2d NHWC HWC optimize."; + } + DepthwiseConv2DNhwcHwcOptimizePass() = default; + DepthwiseConv2DNhwcHwcOptimizePass( + const DepthwiseConv2DNhwcHwcOptimizePass &) {} + explicit DepthwiseConv2DNhwcHwcOptimizePass(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vec-size", llvm::cl::desc("Vector size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. + +void DepthwiseConv2DNhwcHwcOptimizePass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerDepthwiseConv2DNhwcHwcOptimizePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulSCFOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulSCFOptimize.cpp new file mode 100644 index 000000000..a3d079be2 --- /dev/null +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulSCFOptimize.cpp @@ -0,0 +1,281 @@ +//===- BatchMatMulOptimize.cpp --------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the batchmatmul scf vectorization optimization. 
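+// linalg.batch_matmul is rewritten into an scf loop nest: a forall over the
+// batch dimension, scalar loops over the rows of A and the reduction dimension,
+// and a vectorized loop over the columns that broadcasts A[b, m, k], multiplies
+// it with vector-size wide slices of B, and accumulates into C; a masked scf.if
+// covers the column tail.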
+// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; +using namespace affine; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class BatchMatMuSCFOptimizePattern : public ConversionPattern { +private: + int64_t vecSize; + +public: + explicit BatchMatMuSCFOptimizePattern(MLIRContext *context, + int64_t vecSizeParam) + : ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // Retrieve input tensors A, B, and C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + + // Acquire the element type of input tensors. + Type elementType = A.getType().cast().getElementType(); + + // Define constants. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + const Value cVecSize = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr d2 = rewriter.getAffineDimExpr(2); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + const AffineExpr zeroAffine = rewriter.getAffineConstantExpr(0); + + const Value zeroElementType = rewriter.create( + loc, rewriter.getZeroAttr(elementType)); + + // Get dimensions of input tensors. + Value batch = rewriter.create(loc, A, 0); + Value aRow = rewriter.create(loc, A, 1); + Value bCol = rewriter.create(loc, B, 2); + Value bRow = rewriter.create(loc, B, 1); + + VectorType vecTy = VectorType::get({vecSize}, elementType); + Value zeroElementTypeVec; + if (isa(elementType)) + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + else + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + // Calculate the length of the tail, which might not fit in a + // vector. + Value tailLength = rewriter.create( + loc, AffineMap::get(1, 0, d0 % vecSize), ValueRange{bCol}); + + // Generate a mask vector based on the tail length. 
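+    // Only the first (bCol mod vector-size) lanes are enabled, so the masked
+    // loads and stores in the tail block stay within bounds.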
+ Value maskVector = rewriter.create( + loc, VectorType::get({vecSize}, rewriter.getI1Type()), + ValueRange{tailLength}); + + Value ApplyBCol = rewriter.create( + loc, AffineMap::get(1, 0, d0.floorDiv(vecSize) * vecSize), bCol); + + rewriter.create( + loc, SmallVector({c0}), + SmallVector({batch}), + SmallVector({c1}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &builder, Location loc, ValueRange loopIndices) { + Value loopVarBatchIdx = loopIndices[0]; + builder.create( + loc, c0, aRow, c1, ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, Value loopVarRowOfA, + ValueRange iargs) { + builder.create( + loc, c0, bRow, c1, ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, Value loopVarRowOfB, + ValueRange iargs) { + Value aElement = builder.create( + loc, A, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + loopVarRowOfB}); + Value aVec = builder.create( + loc, vecTy, aElement); + builder.create( + loc, c0, ApplyBCol, cVecSize, + ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, + Value loopVarColOfB, ValueRange iargs) { + Value bVec = builder.create( + loc, vecTy, B, + ValueRange{loopVarBatchIdx, loopVarRowOfB, + loopVarColOfB}); + + Value cVec = builder.create( + loc, vecTy, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + loopVarColOfB}); + Value computedVec; + + if (isa(elementType)) { + Value mulVec = builder.create( + loc, aVec, bVec); + computedVec = builder.create( + loc, mulVec, cVec); + } else { + computedVec = builder.create( + loc, aVec, bVec, cVec); + } + builder.create( + loc, computedVec, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + loopVarColOfB}); + builder.create( + loc, ValueRange{std::nullopt}); + }); + Value condition = builder.create( + loc, arith::CmpIPredicate::sgt, tailLength, c0); + builder.create( + loc, condition, + [&](OpBuilder &builder, Location loc) { + Value bVec = builder.create( + loc, vecTy, B, + ValueRange{loopVarBatchIdx, loopVarRowOfB, + ApplyBCol}, + maskVector, zeroElementTypeVec); + + Value cVec = builder.create( + loc, vecTy, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + ApplyBCol}, + maskVector, zeroElementTypeVec); + + Value computedVec; + + if (isa(elementType)) { + Value mulVec = builder.create( + loc, aVec, bVec); + computedVec = builder.create( + loc, mulVec, cVec); + } else { + computedVec = builder.create( + loc, aVec, bVec, cVec); + } + + builder.create( + loc, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + ApplyBCol}, + maskVector, computedVec); + builder.create(loc); + }); + builder.create(loc, + ValueRange{std::nullopt}); + }); + builder.create(loc, ValueRange{std::nullopt}); + }); + + builder.create(loc); + }); + + rewriter.eraseOp(op); + return success(); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// BatchMatMuSCFOptimize +//===----------------------------------------------------------------------===// + +/// This is a partial lowering linalg pooling operations to mixture of +/// Affine + Vector operations. 
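+// A typical invocation, mirroring the makefile target added in this patch:
+//   buddy-opt linalg-batch-matmul-dync.mlir -batchmatmul-scf-optimize="vector-size=64"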
+namespace { +class BatchMatMuSCFOptimize + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BatchMatMuSCFOptimize) + StringRef getArgument() const final { return "batchmatmul-scf-optimize"; } + StringRef getDescription() const final { + return "BatchMatMul SCF Optimization."; + } + BatchMatMuSCFOptimize() = default; + BatchMatMuSCFOptimize(const BatchMatMuSCFOptimize &) {} + explicit BatchMatMuSCFOptimize(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vector-size", + llvm::cl::desc("Strip mining size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. + +void BatchMatMuSCFOptimize::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} +// add to buddy-opt.cpp +namespace mlir { +namespace buddy { +void registerBatchMatMuSCFOptimize() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulTileOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulTileOptimize.cpp new file mode 100644 index 000000000..91d10c645 --- /dev/null +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulTileOptimize.cpp @@ -0,0 +1,353 @@ +//===- BatchMatMulOptimize.cpp --------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the batchmatmul tile optimization. 
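+// The pattern tiles the output into kernel-m x (kernel-n * vec-size) register
+// blocks, parallelizes the batch dimension with affine.parallel, prefetches A,
+// and accumulates with broadcast + FMA over vector slices of B and C; affine.if
+// guards and masked stores handle the M and N boundaries.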
+// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; +using namespace affine; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class BatchMatMulTileOptimizePattern : public ConversionPattern { +private: + int64_t vecSize, kernelM, kernelN; + +public: + explicit BatchMatMulTileOptimizePattern(MLIRContext *context, + int64_t vecSizeParam, + int64_t kernelMParam, + int64_t kernelNParam) + : ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + kernelM = kernelMParam; + kernelN = kernelNParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // Retrieve input tensors A, B, and C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + + // Acquire the element type of input tensors. + Type elementType = A.getType().cast().getElementType(); + ShapedType ATy = A.getType().cast(); + + // Define constants. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr d2 = rewriter.getAffineDimExpr(2); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + const AffineExpr s1 = rewriter.getAffineSymbolExpr(1); + const AffineExpr s2 = rewriter.getAffineSymbolExpr(2); + + const AffineExpr zeroAffine = rewriter.getAffineConstantExpr(0); + + // Get dimensions of input tensors. + Value batch = rewriter.create(loc, A, 0); + Value M = rewriter.create(loc, A, 1); // aRow + Value K = rewriter.create(loc, B, 1); // bRow + Value N = rewriter.create(loc, B, 2); // bCol + + SmallVector reducedValues = llvm::to_vector<4>( + llvm::map_range(ArrayRef{}, + [](const LoopReduction &red) { return red.value; })); + + // Configs + int64_t kNLen = vecSize * kernelN; + + // Create the primary parallel batch level loop. + AffineParallelOp parallelBatchLoop = + rewriter.create( + loc, ValueRange(reducedValues).getTypes(), ValueRange{batch}, + ArrayRef{ + rewriter.getNamedAttr("lowerBoundsGroups", + rewriter.getI32TensorAttr({1})), + rewriter.getNamedAttr("upperBoundsGroups", + rewriter.getI32TensorAttr({1})), + rewriter.getNamedAttr( + "lowerBoundsMap", + AffineMapAttr::get(AffineMap::get(0, 0, {zeroAffine}, + rewriter.getContext()))), + rewriter.getNamedAttr("upperBoundsMap", + AffineMapAttr::get(AffineMap::get( + 1, 0, {d0}, rewriter.getContext()))), + rewriter.getNamedAttr("reductions", rewriter.getArrayAttr({})), + rewriter.getNamedAttr("steps", rewriter.getI64ArrayAttr({1}))}); + + // Create the loop body for the parallel loop. 
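+    // Body layout: prefetch A, then an affine loop nest over N (step
+    // vec-size * kernel-n) and M (step kernel-m); the innermost K loop broadcasts
+    // one element of A, multiplies it with kernel-n vectors of B, and accumulates
+    // the results into the matching C vectors, guarding the M and N edges with
+    // affine.if.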
+ Block *loopBody = new Block(); + rewriter.setInsertionPointToStart(loopBody); + loopBody->addArgument(rewriter.getIndexType(), loc); + Value loopVarBatchIdx = loopBody->getArguments()[0]; + + // Prefetching data from tensor 'A' for better cache utilization. + rewriter.create( + loc, A, AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext()), + ArrayRef{loopVarBatchIdx, M, K}, false, 3, true); + + // build loop body + affine::buildAffineLoopNest( + rewriter, loc, {c0}, {N}, kNLen, + [&](OpBuilder &builder, Location loc, ValueRange ivRange) { + auto ivJ = ivRange.front(); + affine::buildAffineLoopNest( + builder, loc, {c0}, {M}, kernelM, + [&](OpBuilder &builder, Location loc, ValueRange ivRange) { + Value ivI = ivRange.front(); + SmallVector cptrs; + + const VectorType vTy = + VectorType::get(vecSize, ATy.getElementType()); + + for (int i = 0; i < kernelM; i++) { + Value fixedIV = builder.create( + loc, + AffineMap::get(1, 1, {d0 + i, s0 - 1}, + builder.getContext()), + SmallVector{ivI, M}); + MemRefType resTy = MemRefType::get( + ATy.getShape(), ATy.getElementType(), + AffineMap::get(3, 3, d1 * s2 + d0 * s1 + s0 + d2)); + auto cptr = builder.create( + loc, resTy, C, + SmallVector{loopVarBatchIdx, fixedIV, c0}, + SmallVector{c1, c1, N}, + SmallVector{c1, c1, c1}); + cptrs.push_back(cptr); + } + affine::buildAffineLoopNest( + builder, loc, {c0}, {K}, 1, + [&](OpBuilder &builder, Location loc, ValueRange ivRange) { + Value ivK = ivRange.front(); + SmallVector bs; + + for (int j = 0; j < kernelN; j++) { + Value fixedJV = ivJ; + if (j != 0) { + fixedJV = builder.create( + loc, AffineMap::get(1, 0, d0 + j * vecSize), ivJ); + } + bs.push_back(builder.create( + loc, vTy, B, + ValueRange{loopVarBatchIdx, ivK, fixedJV})); + } + + for (int i = 0; i < kernelM; ++i) { + Value fixedIV = ivI; + if (i != 0) { + fixedIV = builder.create( + loc, + AffineMap::get(1, 0, {d0 + i}, + builder.getContext()), + SmallVector{ivI}); + } + affine::AffineIfOp mBranchingOp = + builder.create( + loc, + IntegerSet::get(1, 1, {-d0 + s0 - 1}, {false}), + ValueRange{fixedIV, M}, false); + OpBuilder mTrueBranchBuilder = + mBranchingOp.getThenBodyBuilder(); + Value ksubAElement = + mTrueBranchBuilder.create( + loc, A, + ValueRange{loopVarBatchIdx, fixedIV, ivK}); + + for (int j = 0; j < kernelN; j++) { + Value fixedJV = ivJ; + if (j != 0) { + fixedJV = + mTrueBranchBuilder + .create( + loc, + AffineMap::get(1, 0, d0 + j * vecSize), + ivJ); + } + Value vecC = mTrueBranchBuilder.create( + loc, vTy, cptrs[i], ValueRange{c0, c0, fixedJV}); + if (isa(elementType)) { + Value vecA = + mTrueBranchBuilder.create( + loc, vTy, ksubAElement); + Value vecMul = + mTrueBranchBuilder.create( + loc, vTy, vecA, bs[j]); + vecC = mTrueBranchBuilder.create( + loc, vTy, vecMul, vecC); + } else { + Value vecA = + mTrueBranchBuilder.create( + loc, vTy, ksubAElement); + vecC = mTrueBranchBuilder.create( + loc, vTy, vecA, bs[j], vecC); + } + // store vecC + Value tailLength = + mTrueBranchBuilder.create( + loc, AffineMap::get(2, 0, -d0 + d1), + ValueRange{fixedJV, N}); + affine::AffineIfOp nBranchingOp = + mTrueBranchBuilder.create( + loc, + IntegerSet::get(1, 0, {-vecSize + d0}, + {false}), + ValueRange{tailLength}, true); + // Calculate the length of the tail, which might not + // fit in a vector. 
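+                        // The then-branch still has a full vector of columns
+                        // left and uses a plain vector store; the else-branch
+                        // masks the store down to the remaining columns.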
+                          OpBuilder nTrueBranchBuilder =
+                              nBranchingOp.getThenBodyBuilder();
+                          nTrueBranchBuilder.create<vector::StoreOp>(
+                              loc, vecC, cptrs[i],
+                              ValueRange{c0, c0, fixedJV});
+                          OpBuilder nFalseBranchBuilder =
+                              nBranchingOp.getElseBodyBuilder();
+                          // Generate a mask vector based on the tail length
+                          // and store only the valid lanes.
+                          Value maskVector =
+                              nFalseBranchBuilder.create<vector::CreateMaskOp>(
+                                  loc,
+                                  VectorType::get({vecSize},
+                                                  rewriter.getI1Type()),
+                                  ValueRange{tailLength});
+                          nFalseBranchBuilder.create<vector::MaskedStoreOp>(
+                              loc, cptrs[i], ValueRange{c0, c0, fixedJV},
+                              maskVector, vecC);
+                        }
+                      }
+                    });
+              });
+        });
+
+    rewriter.create<affine::AffineYieldOp>(loc);
+
+    // Finalize the loop and erase the original operation.
+    parallelBatchLoop.getRegion().push_back(loopBody);
+    rewriter.setInsertionPointAfter(parallelBatchLoop);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// BatchMatMulTileOptimizePass
+//===----------------------------------------------------------------------===//
+
+/// This is a partial lowering of linalg.batch_matmul operations to a mixture
+/// of affine + vector operations.
+namespace {
+class BatchMatMulTileOptimizePass
+    : public PassWrapper<BatchMatMulTileOptimizePass,
+                         OperationPass<ModuleOp>> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BatchMatMulTileOptimizePass)
+  StringRef getArgument() const final { return "batchmatmul-tile-optimize"; }
+  StringRef getDescription() const final {
+    return "BatchMatMul Tile Optimization.";
+  }
+  BatchMatMulTileOptimizePass() = default;
+  BatchMatMulTileOptimizePass(const BatchMatMulTileOptimizePass &) {}
+  explicit BatchMatMulTileOptimizePass(int64_t vecSizeParam,
+                                       int64_t kernelMParam,
+                                       int64_t kernelNParam) {
+    vecSize = vecSizeParam;
+    kernelM = kernelMParam;
+    kernelN = kernelNParam;
+  }
+
+  void runOnOperation() override;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, affine::AffineDialect,
+                    scf::SCFDialect, memref::MemRefDialect, VectorDialect,
+                    arith::ArithDialect, func::FuncDialect>();
+  }
+
+  Option<int64_t> vecSize{*this, "vec-size",
+                          llvm::cl::desc("Strip mining size."),
+                          llvm::cl::init(16)};
+
+  Option<int64_t> kernelM{*this, "kernel-m",
+                          llvm::cl::desc("Kernel tile size along the M dimension."),
+                          llvm::cl::init(4)};
+
+  Option<int64_t> kernelN{*this, "kernel-n",
+                          llvm::cl::desc("Kernel tile size along the N dimension."),
+                          llvm::cl::init(2)};
+};
+} // end anonymous namespace.
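+
+// Tiling strategy implemented above: each linalg.batch_matmul is rewritten
+// into an affine.parallel loop over the batch dimension. The loop body walks
+// N in steps of vec-size * kernel-n and M in steps of kernel-m, updating a
+// kernel-m x (vec-size * kernel-n) block of C through vector loads and stores
+// (vector.fma for floating-point element types, mul + add for integers).
+// affine.if guards handle the M tail, and a masked store handles the N tail.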
+
+void BatchMatMulTileOptimizePass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  ModuleOp module = getOperation();
+
+  ConversionTarget target(*context);
+  target
+      .addLegalDialect<arith::ArithDialect, affine::AffineDialect,
+                       scf::SCFDialect, memref::MemRefDialect,
+                       VectorDialect>();
+  target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
+  target.addLegalOp<linalg::FillOp>();
+
+  RewritePatternSet patterns(context);
+  patterns.add<BatchMatMulTileOptimizePattern>(context, vecSize, kernelM,
+                                               kernelN);
+
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+// Register the pass; the registration function is declared and called in
+// tools/buddy-opt/buddy-opt.cpp.
+namespace mlir {
+namespace buddy {
+void registerBatchMatMulTileOptimizePass() {
+  PassRegistration<BatchMatMulTileOptimizePass>();
+}
+} // namespace buddy
+} // namespace mlir
diff --git a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt
index 8e726863e..2803af674 100644
--- a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt
+++ b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt
@@ -1,8 +1,10 @@
 add_mlir_library(MatMulOptimization
-  BatchMatMulOptimize.cpp
   MatMulOptimize.cpp
   MatMulVectorization.cpp
   MatMulParallelVectorization.cpp
+  BatchMatMulOptimize.cpp
+  BatchMatMulTileOptimize.cpp
+  BatchMatMulSCFOptimize.cpp
   LINK_LIBS PUBLIC
   BuddyUtils
 )
@@ -11,6 +13,14 @@ add_mlir_library(BatchMatMulOptimization
   BatchMatMulOptimize.cpp
 )
 
+add_mlir_library(BatchMatMulTileOptimization
+  BatchMatMulTileOptimize.cpp
+)
+
+add_mlir_library(BatchMatMulSCFOptimization
+  BatchMatMulSCFOptimize.cpp
+)
+
 add_mlir_library(MatMulParallelVectorization
   MatMulParallelVectorization.cpp
 )
diff --git a/tools/buddy-opt/CMakeLists.txt b/tools/buddy-opt/CMakeLists.txt
index 24bcde935..94109d28d 100644
--- a/tools/buddy-opt/CMakeLists.txt
+++ b/tools/buddy-opt/CMakeLists.txt
@@ -26,9 +26,12 @@ target_link_libraries(buddy-opt
   LowerRVVPass
   MatMulOptimization
   BatchMatMulOptimization
+  BatchMatMulTileOptimization
+  BatchMatMulSCFOptimization
   MatMulParallelVectorization
   TransposeOptimization
   ConvOptimization
+  DepthwiseConvOptimization
   VectorExp
   LowerVectorExpPass
   BuddyGemmini
diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp
index bea9513b5..a40fda18f 100644
--- a/tools/buddy-opt/buddy-opt.cpp
+++ b/tools/buddy-opt/buddy-opt.cpp
@@ -40,31 +40,37 @@
 #include "DAP/DAPOps.h"
 #include "DIP/DIPDialect.h"
 #include "DIP/DIPOps.h"
-#include "RVV/RVVDialect.h"
-#include "VectorExp/VectorExpDialect.h"
-#include "VectorExp/VectorExpOps.h"
 #include "Gemmini/GemminiDialect.h"
 #include "Gemmini/GemminiOps.h"
+#include "RVV/RVVDialect.h"
 #include "Sche/ScheDialect.h"
 #include "Sche/ScheOps.h"
+#include "VectorExp/VectorExpDialect.h"
+#include "VectorExp/VectorExpOps.h"
 
 namespace mlir {
 namespace buddy {
 void registerConvVectorizationPass();
 void registerPointwiseConvToGemmPass();
+void registerPointwiseConvToGemmForNhwcFhwcPass();
 void registerPoolingVectorizationPass();
 void registerLowerBudPass();
 void registerLowerDIPPass();
+void registerBatchMatMulOptimizePass();
+void registerBatchMatMulTileOptimizePass();
+void registerBatchMatMuSCFOptimize();
 void registerLowerDAPPass();
 void registerExtendDAPPass();
 void registerDAPVectorizePass();
 void registerLowerRVVPass();
-void registerBatchMatMulOptimizePass();
 void registerMatMulOptimizePass();
 void registerMatMulVectorizationPass();
 void registerMatMulParallelVectorizationPass();
 void registerTransposeOptimizationPass();
 void registerConvOptimizePass();
+void registerConvNhwcFhwcOptimizePass();
+void registerConvNhwcFhwcTileOptimizePass();
+void registerDepthwiseConv2DNhwcHwcOptimizePass();
 void registerLowerVectorExpPass();
 void registerLowerGemminiPass();
 void registerLowerLinalgToGemminiPass();
@@ -78,6 +84,7 @@ int main(int argc, char **argv) {
   // Register all MLIR passes.
   mlir::registerAllPasses();
   mlir::buddy::registerPointwiseConvToGemmPass();
+  // mlir::buddy::registerPointwiseConvToGemmForNhwcFhwcPass();
   // Register Vectorization of Convolution.
   mlir::buddy::registerConvVectorizationPass();
   // Register Vectorization of Pooling.
@@ -95,11 +102,16 @@ int main(int argc, char **argv) {
 
   // Register Several Optimize Pass.
   mlir::buddy::registerMatMulOptimizePass();
+  mlir::buddy::registerBatchMatMulOptimizePass();
+  mlir::buddy::registerBatchMatMulTileOptimizePass();
+  mlir::buddy::registerBatchMatMuSCFOptimize();
   mlir::buddy::registerMatMulVectorizationPass();
   mlir::buddy::registerMatMulParallelVectorizationPass();
-  mlir::buddy::registerBatchMatMulOptimizePass();
   mlir::buddy::registerTransposeOptimizationPass();
   mlir::buddy::registerConvOptimizePass();
+  mlir::buddy::registerConvNhwcFhwcOptimizePass();
+  mlir::buddy::registerConvNhwcFhwcTileOptimizePass();
+  mlir::buddy::registerDepthwiseConv2DNhwcHwcOptimizePass();
   mlir::buddy::registerDeviceSchedulePass();
   mlir::buddy::registerLowerSchePass();
   mlir::buddy::registerFuncBufferizeDynamicOffsetPass();
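
Usage sketch (illustrative, not part of the patch): once buddy-opt is rebuilt with
the libraries and registrations added above, the new tiling pass can be driven by
the argument it registers; the input file name below is a placeholder and the
option values shown are the declared defaults.

    buddy-opt batch-matmul-sample.mlir \
        -batchmatmul-tile-optimize="vec-size=16 kernel-m=4 kernel-n=2"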