diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
index dc458d39e0..c6d405a30e 100644
--- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
+++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/Passes.h"
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
@@ -215,6 +216,10 @@ void addPassesNNPA(mlir::OwningOpRef &module,
     addKrnlToAffinePasses(pm);
     // Optimizations at ZLow that needs affine map in MemRef.
     pm.addPass(zlow::createZLowRewritePass());
+    // Late generation of code for stick/unstick; needs to run after a
+    // ZLowRewrite pass.
+    if (nnpaEnableCompilerStickUnstick)
+      pm.addPass(zlow::createZLowStickExpansionPass(enableParallel));
     pm.addPass(mlir::createCanonicalizerPass());
     // Normalize MemRefs.
     normalizeMemRefsPasses(pm);
@@ -223,6 +228,11 @@ void addPassesNNPA(mlir::OwningOpRef &module,
     addKrnlToAffinePasses(pm);
     // Optimizations at ZLow after normalizing MemRefs.
     pm.addPass(zlow::createZLowRewritePass());
+    // The createZLowStickExpansion pass may create parallel constructs,
+    // which need to be handled here.
+    if (nnpaEnableCompilerStickUnstick && enableParallel)
+      pm.addPass(mlir::createConvertSCFToOpenMPPass());
+
     pm.addPass(mlir::createCanonicalizerPass());
     // Constant folding for std.alloc.
     pm.addNestedPass(onnx_mlir::createFoldStdAllocPass());
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
index 9f8c44efdd..a46ef6746a 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
@@ -981,3 +981,11 @@ bool isSuitableForZDNN(
   return true;
 }
+
+/// Check legality for ONNXReshapeOp.
+template <>
+bool isSuitableForZDNN(
+    ONNXReshapeOp op, const DimAnalysis *dimAnalysis) {
+  // A no-op Reshape is suitable for zAIU, as this pass removes such reshape
+  // ops.
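+  // For illustration only, a minimal sketch of the shape-based part of such
+  // a check (hypothetical helper; the real isIdentityReshape can additionally
+  // use DimAnalysis to prove dynamic dimensions equal):
+  //   bool isIdentityShape(ArrayRef<int64_t> in, ArrayRef<int64_t> out) {
+  //     return in == out && llvm::none_of(in, ShapedType::isDynamic);
+  //   }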
+  return isIdentityReshape(op, dimAnalysis);
+}
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp
index c2aac98faf..943c577181 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp
@@ -34,6 +34,7 @@
 #include "src/Dialect/ONNX/ElementsAttr/WideNum.hpp"
 #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
+#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
 #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
 #include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp"
 #include "src/Support/TypeUtilities.hpp"
@@ -467,6 +468,31 @@ class AddSubWithRHSZeroExpandPattern : public OpRewritePattern {
   }
 };

+class RemoveReshapeWithIdentityPattern
+    : public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  DimAnalysis *dimAnalysis;
+
+  RemoveReshapeWithIdentityPattern(
+      MLIRContext *context, DimAnalysis *dimAnalysis)
+      : OpRewritePattern(context, 1001),
+        dimAnalysis(dimAnalysis) {}
+
+  LogicalResult matchAndRewrite(
+      ONNXReshapeOp reshapeOp, PatternRewriter &rewriter) const override {
+    if (!isIdentityReshape(reshapeOp, dimAnalysis))
+      return failure();
+
+    // Rewrite
+    Operation *op = reshapeOp.getOperation();
+    Value data = reshapeOp.getData();
+    rewriter.replaceOp(op, data);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Rewrite ONNX ops to ZHigh ops and ONNX ops for ZHigh.
 //===----------------------------------------------------------------------===//
@@ -482,6 +508,8 @@ void getRewriteONNXForZHighPatterns(
       patterns.getContext(), dimAnalysis);
   patterns.insert>(
       patterns.getContext(), dimAnalysis);
+  patterns.insert(
+      patterns.getContext(), dimAnalysis);
 }

 void getRewriteONNXForZHighDynamicallyLegal(
@@ -643,6 +671,13 @@ void getRewriteONNXForZHighDynamicallyLegal(
         return isSuitableForZDNN(op) || !canInferencePadsForNNPAConv(op);
       });
+  addDynamicallyLegalOpFor(target, dimAnalysis,
+      [](ONNXReshapeOp op, const DimAnalysis *dimAnalysis) {
+        // Get rid of identity reshapes here, as they impact stick/unstick.
+        // So all reshapes are legal, unless the reshape is an identity, in
+        // which case a rewrite rule removes it.
+        return !isIdentityReshape(op, dimAnalysis);
+      });
 }

 struct RewriteONNXForZHighPass
diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
index 0f5eeb8aba..98e7235f1d 100644
--- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
+++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
@@ -32,11 +32,7 @@
 #include "src/Support/TypeUtilities.hpp"

 #define DEBUG_TYPE "zhigh-to-zlow"
-#define ENABLE_CSU_PAR true /* Allow parallel compiler gen Stick/Unstick.
*/ -#define CS_N 2 /* Tiling for Stick */ -#define CS_M 2 /* Tiling for Stick */ -#define CU_N 2 /* Tiling for Unstick */ -#define CU_M 2 /* Tiling for Unstick */ + using namespace mlir; using namespace onnx_mlir::zlow; @@ -497,14 +493,9 @@ ZMemRefType convertZTensorToMemRefType(Type type) { // Support for flatten ztensor struct ZHighToZLowStickOpLowering : public ConversionPattern { - ZHighToZLowStickOpLowering(TypeConverter &typeConverter, MLIRContext *ctx, - bool enableParallel, bool enableCompilerStickUnstickCodeGen) + ZHighToZLowStickOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) : ConversionPattern( - typeConverter, ZHighStickOp::getOperationName(), 1, ctx), - enableParallel(enableParallel), - enableCompilerCodeGen(enableCompilerStickUnstickCodeGen) {} - bool enableParallel; - bool enableCompilerCodeGen; + typeConverter, ZHighStickOp::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { @@ -529,201 +520,11 @@ struct ZHighToZLowStickOpLowering : public ConversionPattern { if (isNHWCLayout(layout)) layout = getNCHWLayoutAttr(rewriter); - if (enableCompilerCodeGen) { - // Generic way to handle all formats listed below. - // Think we only come in here when condition below is true. - if (layout.getValue().equals_insensitive("4D") || - layout.getValue().equals_insensitive("3D") || - layout.getValue().equals_insensitive("2D") || - // layout.getValue().equals_insensitive("1D") || - // layout.getValue().equals_insensitive("2DS") || - layout.getValue().equals_insensitive("3DS")) { - return generateStickCode( - rewriter, op, shapeHelper, alloc, input, layout); - } - } - // Else, emit a ZLow operation. rewriter.create(loc, input, alloc, layout); rewriter.replaceOp(op, alloc); return success(); } - - /* Generic version that create code for all normal layouts. Code could work - for 2DS and 1D also, but at this time there is an issue with the affine - maps of these formats (see issue #1940), so disable for the moment being. - If enabled, it would need to be tested again. - */ - LogicalResult generateStickCode(ConversionPatternRewriter &rewriter, - Operation *op, ZHighStickOpShapeHelper &shapeHelper, Value alloc, - Value input, StringAttr layout) const { - Location loc = op->getLoc(); - MDBuilder create(rewriter, loc); - - // Compute output dims and rank. - IndexExprScope allocScope(create.krnl, shapeHelper.getScope()); - DimsExpr outputDims; - getIndexExprList(shapeHelper.getOutputDims(), outputDims); - int64_t rank = outputDims.size(); - bool is2DS = layout.getValue().equals_insensitive("2DS"); - - // Tiling for Stick in the E2 x E1 dim: N x 64M. - int64_t N = CS_N; - int64_t M = CS_M; - if (rank == 1 || is2DS) - N = 1; // No tiling on E2 dim for - assert(32 % N == 0 && "Tiling by N (along E2) must divide 32"); - - // Info for SIMD Vector Length (VL) and associated types. - int64_t VL = 8; // FP16 VL. - int64_t VLHalf = VL / 2; // FP32 VL. - assert(64 % VL == 0 && "SIMD vector length must divide 64"); - Type f16Type = rewriter.getF16Type(); - Type f32Type = rewriter.getF32Type(); - VectorType vecF32Type = VectorType::get({VLHalf}, f32Type); - - // Type for buffer (write N*64 continuously back in the output). - MemRefType bufferType = MemRefType::get({M, N, 64}, f16Type); - - // Define useful literals. 
- IndexExpr litZero = LiteralIndexExpr(0); - IndexExpr lit1 = LiteralIndexExpr(1); - IndexExpr litN = LiteralIndexExpr(N); - IndexExpr litM = LiteralIndexExpr(M); - IndexExpr litVLHalf = LiteralIndexExpr(VLHalf); - IndexExpr lit64 = LiteralIndexExpr(64); - - // Useful references for indexing dimensions (neg val are not used). - int64_t E4 = rank - 4, E3 = rank - 3, E2 = rank - 2, E1 = rank - 1; - - // Create loop iterations. Note that we iterate over E1 as tiles of 64 - // elements. - ValueRange loopDefs = create.krnl.defineLoops(rank); - ValueRange tiledDefE1 = create.krnl.block(loopDefs[E1], M); - DimsExpr lbs(rank, litZero); - DimsExpr ubs = outputDims; - IndexExpr T1 = outputDims[E1].ceilDiv(64); - ubs[E1] = T1; // E1 dim is over tiles. - SmallVector optLoopDefs; - if (rank == 4) { - ValueRange tiledDefE2 = create.krnl.block(loopDefs[E2], N); - // 4D loop order: E4, E3, E2 tiled by N, E1 tiled by M, followed by E1 - // (inside M), then unused. - // clang-format off - create.krnl.permute( - {/*E4*/loopDefs[E4],/*E3*/loopDefs[E3],/*E2*/tiledDefE2[0],tiledDefE2[1],/*E1*/tiledDefE1[0],tiledDefE1[1]}, - {/*E4*/0, /*E3*/1, /*E2*/2, 5, /*E1*/3, 4}); - // clang-format on - optLoopDefs = {loopDefs[E4], loopDefs[E3], tiledDefE2[0], tiledDefE1[0]}; - } else if (rank == 3) { - ValueRange tiledDefE2 = create.krnl.block(loopDefs[E2], N); - // 3D/3DS loop order: E3, E2 tiled by N, E1 tiled by M, followed by E1 - // (inside M), then unused. Order does not change for 3D vs 3DS. - create.krnl.permute({/*E3*/ loopDefs[E3], - /*E2*/ tiledDefE2[0], tiledDefE2[1], - /*E1*/ tiledDefE1[0], tiledDefE1[1]}, - {/*E3*/ 0, /*E2*/ 1, 4, /*E1*/ 2, 3}); - optLoopDefs = {loopDefs[E3], tiledDefE2[0], tiledDefE1[0]}; - } else if (rank == 2) { - ValueRange tiledDefE2 = create.krnl.block(loopDefs[E2], N); - // 2D/2DS loop order: E2 tiled by N, E1 tiled by M, followed by E1 - // (inside M), then unused. Order does not change for 2D vs 2DS. 2DS has a - // E2 tile of 1... which we can safely ignore here. - create.krnl.permute({/*E2*/ tiledDefE2[0], tiledDefE2[1], - /*E1*/ tiledDefE1[0], tiledDefE1[1]}, - {/*E2*/ 0, 3, /*E1*/ 1, 2}); - optLoopDefs = {tiledDefE2[0], tiledDefE1[0]}; - } else if (rank == 1) { - // 1D loop order: E2 tiled by N, E1 tiled by M, followed by E1 - // (inside M), then unused. Order does not change for 2D vs 2DS. 2DS has a - // E2 tile of 1... which we can safely ignore here. - create.krnl.permute({/*E1*/ tiledDefE1[0], tiledDefE1[1]}, {/*E1*/ 0, 1}); - optLoopDefs = {tiledDefE1[0]}; - } else { - llvm_unreachable("rank 1 to 4 only"); - } - - // Parallel... - if (enableParallel) { - int64_t parId; - // TODO: may want to check if ub of rank makes sense here. - if (findSuitableParallelDimension(lbs, ubs, 0, rank, parId, 8)) { - create.krnl.parallel(optLoopDefs[parId]); - onnxToKrnlParallelReport(op, true, parId, lbs[parId], ubs[parId], - "compiler-generated stickify"); - } else { - onnxToKrnlParallelReport(op, false, -1, -1, - "no dim with enough work in compiler-generated stickify"); - } - } - - // Outer loop (E4, E3, E2 tiled by N, E1 tiled by M) - create.krnl.iterateIE(loopDefs, optLoopDefs, lbs, ubs, - [&](KrnlBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - IndexExprScope outerScope(create.krnl, &allocScope); - DimsExpr outerIndices; - getIndexExprList(loopInd, outerIndices); - // Create buffer (for parallel, must be inside loop). - Value buffer = create.mem.alignedAlloc(bufferType, {}); - // Iterate over N, M, and 64. Manage iterations explicitly. 
- DimsExpr lbs2(3, litZero); - DimsExpr ubs2 = {litN, litM, lit64}; - SmallVector steps2 = {1, 1, VL}; - // Analysis of assembly showed that the inner loop was fully unrolled. - create.affine.forIE( - lbs2, ubs2, steps2, [&](AffineBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - DimsExpr inputAF; - IndexExprScope innerScope(create.krnl, &outerScope); - SymbolIndexExpr n(loopInd[0]), m(loopInd[1]), l(loopInd[2]); - getIndexExprList(outerIndices, inputAF); - // If E2 is unrolled, must add the "n" local E2 offset. - if (rank > 1 && N > 1) - inputAF[E2] = inputAF[E2] + n; - // Translate the tile index t1 to the actual targetted data: e1 - // => 64 (e1+m) and add the "l" local E1 offset. - inputAF[E1] = ((inputAF[E1] + m) * 64) + l; - Value vecF32H = - create.vec.loadIE(vecF32Type, input, inputAF, {}); - Value vecF32L = create.vec.loadIE( - vecF32Type, input, inputAF, {litVLHalf.getValue()}); - Value vecF16 = rewriter.create( - loc, vecF32H, vecF32L); - create.vec.storeIE(vecF16, buffer, {m, n, l}, {}); - }); - // Perform copy: E1 Tiled by 64 (inside tile by M) - create.krnl.iterate({}, {tiledDefE1[1]}, {}, {}, - [&](KrnlBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - DimsExpr outputAF; - IndexExprScope innerScope(create.krnl, &outerScope); - SymbolIndexExpr t1(loopInd[0]); - getIndexExprList(outerIndices, outputAF); - // Compute m, the current m * 64 tile being processed by this - // inner loop. - IndexExpr m = t1 - outputAF[E1]; - // E1 is tiled, multiply by 64 to get the tile start. - outputAF[E1] = t1 * 64; - Value allocOffset = - create.krnl.getLinearOffsetIndexIE(alloc, outputAF); - DimsExpr reallocTileDims = {litN, lit64}; - Value allocAsNx64 = create.mem.reinterpretCast( - alloc, litZero.getValue(), reallocTileDims); - // Calculate buffer offset - int64_t num = N * 64; - IndexExpr bufferOffset = m * num; - // Amount of values to copy - Type intType = rewriter.getIntegerType(64); - Value numVal = create.math.constant(intType, num); - // Mem copy - create.krnl.memcpy(allocAsNx64, buffer, numVal, allocOffset, - bufferOffset.getValue()); - }); - }); - rewriter.replaceOp(op, alloc); - return success(); - } }; //===----------------------------------------------------------------------===// @@ -804,14 +605,9 @@ struct ZHighToZLowStickForGRUOpLowering : public ConversionPattern { //===----------------------------------------------------------------------===// struct ZHighToZLowUnstickOpLowering : public ConversionPattern { - ZHighToZLowUnstickOpLowering(TypeConverter &typeConverter, MLIRContext *ctx, - bool enableParallel, bool enableCompilerStickUnstickCodeGen) + ZHighToZLowUnstickOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) : ConversionPattern( - typeConverter, ZHighUnstickOp::getOperationName(), 1, ctx), - enableParallel(enableParallel), - enableCompilerCodeGen(enableCompilerStickUnstickCodeGen) {} - bool enableParallel; - bool enableCompilerCodeGen; + typeConverter, ZHighUnstickOp::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { @@ -843,655 +639,11 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern { if (isNHWCLayout(layout)) layout = getNCHWLayoutAttr(rewriter); - if (enableCompilerCodeGen) { - // Generic way to handle all formats listed below. - // Think we only come in here when condition below is true. 
- if (layout.getValue().equals_insensitive("4D") || - layout.getValue().equals_insensitive("3D") || - layout.getValue().equals_insensitive("2D") || - layout.getValue().equals_insensitive("3DS")) { - // Alternative versions of the code are: - // o generateUnstickCode (preferred) - // o generateUnstickCodeE1E2 (different loop order) - // o generateUnstickCodeSimple (simple) - // As the code matures, we will keep only one. - return generateUnstickCode( - rewriter, op, shapeHelper, alloc, input, layout); - } - } - // Emit a ZLow operation. rewriter.create(loc, input, alloc, layout); rewriter.replaceOp(op, alloc); return success(); } - - /* - This version tile E1 by 64, convert f16 to f32 64 values at a time. - Then it simply write individual f32 to its destination 1 value at a time. - */ - LogicalResult generateUnstickCodeSimple(ConversionPatternRewriter &rewriter, - Operation *op, ZHighUnstickOpShapeHelper &shapeHelper, Value alloc, - Value input, StringAttr layout) const { - Location loc = op->getLoc(); - MDBuilder create(rewriter, loc); - - // Compute output dims and rank. - IndexExprScope allocScope(create.krnl, shapeHelper.getScope()); - DimsExpr outputDims; - getIndexExprList(shapeHelper.getOutputDims(), outputDims); - int64_t rank = outputDims.size(); - - // Info for SIMD Vector Length (VL) and associated types. - int64_t VL = 8; // FP16 VL. - int64_t VLHalf = VL / 2; // FP32 VL. - assert(64 % VL == 0 && "SIMD vector length must divide 64"); - Type f16Type = rewriter.getF16Type(); - Type f32Type = rewriter.getF32Type(); - VectorType vecF16Type = VectorType::get({VL}, f16Type); - - // Type for buffer (write 64 continuously back in the output). - MemRefType bufferType = MemRefType::get({64}, f32Type); - - // Define useful literals. - IndexExpr litZero = LiteralIndexExpr(0); - IndexExpr lit1 = LiteralIndexExpr(1); - IndexExpr litVLHalf = LiteralIndexExpr(VLHalf); - IndexExpr litVL = LiteralIndexExpr(VL); - IndexExpr lit64 = LiteralIndexExpr(64); - - // Useful references for indexing dimensions (neg val are not used). - int64_t E4 = rank - 4, E3 = rank - 3, E2 = rank - 2, E1 = rank - 1; - - // Create loop iterations. Note that we iterate over E1 as tiles of 64 - // elements. - ValueRange loopDefs = create.krnl.defineLoops(rank); - ValueRange tiledDefE1 = create.krnl.block(loopDefs[E1], 64); - DimsExpr lbs(rank, litZero); - DimsExpr ubs = outputDims; - SmallVector optLoopDefs; - if (rank == 4) { - // 4D loop order: E4, E3, E2, E1 tiled by64, followed by E1 - // clang-format off - create.krnl.permute( - {/*E4*/loopDefs[E4],/*E3*/loopDefs[E3],/*E2*/loopDefs[E2],/*E1*/tiledDefE1[0],tiledDefE1[1]}, - {/*E4*/0, /*E3*/1, /*E2*/2, /*E1*/3, 4}); - // clang-format on - optLoopDefs = {loopDefs[E4], loopDefs[E3], loopDefs[E2], tiledDefE1[0]}; - } else if (rank == 3) { - // 3D/3DS loop order: - // clang-format off - create.krnl.permute( - {/*E3*/loopDefs[E3],/*E2*/loopDefs[E2],/*E1*/tiledDefE1[0],tiledDefE1[1]}, - {/*E3*/0, /*E2*/1, /*E1*/2, 3}); - // clang-format on - optLoopDefs = {loopDefs[E3], loopDefs[E2], tiledDefE1[0]}; - } else if (rank == 2) { - // 2D/2DS loop order: - // clang-format off - create.krnl.permute( - {/*E2*/loopDefs[E2],/*E1*/tiledDefE1[0],tiledDefE1[1]}, - {/*E2*/0, /*E1*/1, 2}); - // clang-format on - optLoopDefs = {loopDefs[E2], tiledDefE1[0]}; - } else { - llvm_unreachable("rank 2 to 4 only"); - } - - // Parallel... - if (enableParallel) { - int64_t parId; - // TODO: may want to check if ub of rank makes sense here. 
- if (findSuitableParallelDimension(lbs, ubs, 0, rank, parId, 8)) { - create.krnl.parallel(optLoopDefs[parId]); - onnxToKrnlParallelReport(op, true, parId, lbs[parId], ubs[parId], - "compiler-generated stickify"); - } else { - onnxToKrnlParallelReport(op, false, -1, -1, - "no dim with enough work in compiler-generated stickify"); - } - } - - // Compute max tiles. It is actually not easy to compute the max number of - // tiles; since we don't allocate, just need to index by the "tile size", it - // is sufficient to assume 2 or more. - IndexExpr T = LiteralIndexExpr(2); - DimsExpr reallocTileDims = {T, lit64}; - Value inputAsTx64 = - create.mem.reinterpretCast(input, litZero.getValue(), reallocTileDims); - - // Outer loop (E4, E3, E2 tiled by N, E1 tiled by M) - create.krnl.iterateIE(loopDefs, optLoopDefs, lbs, ubs, - [&](KrnlBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - IndexExprScope outerScope(create.krnl, &allocScope); - DimsExpr outerIndices; - getIndexExprList(loopInd, outerIndices); - // Create buffer [64] (for parallel, must be inside loop). - Value buffer = create.mem.alignedAlloc(bufferType, {}); - Value inputOffset = - create.krnl.getLinearOffsetIndexIE(input, outerIndices); - IndexExpr inputDataOffset = SymbolIndexExpr(inputOffset); - IndexExpr inputTileOffset = inputDataOffset.floorDiv(lit64); - - create.affine.forIE( - litZero, lit64, VL, [&](AffineBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - IndexExprScope innerScope(create.krnl, &outerScope); - DimIndexExpr l(loopInd[0]); - IndexExpr ii = SymbolIndexExpr(inputTileOffset); - Value vecF16 = - create.vec.loadIE(vecF16Type, inputAsTx64, {ii, l}, {}); - auto convertOp = - rewriter.create(loc, vecF16); - Value vecF32H = convertOp.getResult(0); - Value vecF32L = convertOp.getResult(1); - // Store f32 values back to the buffer[N][M][64]. - DimsExpr bufferAF = {l}; - create.vec.storeIE(vecF32H, buffer, bufferAF, {}); - create.vec.storeIE( - vecF32L, buffer, bufferAF, {litVLHalf.getValue()}); - }); - create.krnl.iterate({}, {tiledDefE1[1]}, {}, {}, - [&](KrnlBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - DimsExpr outputAF; - IndexExprScope innerScope(create.krnl, &outerScope); - DimIndexExpr e1(loopInd[0]); - getIndexExprList(outerIndices, outputAF); - outputAF[E1] = e1; - IndexExpr l = e1 % lit64; - DimsExpr bufferAF = {l}; - Value t = create.krnl.loadIE(buffer, bufferAF); - create.krnl.storeIE(t, alloc, outputAF); - }); - }); - rewriter.replaceOp(op, alloc); - return success(); - } - - /* - Generic version that create code for all normal layouts but for 1D and 2DS. - This version tiles E2 and E1 by N and 64 M. - For conversions from 16 to 32, it process N * 64 values in a row - (effectively covering a partial Nx64 tile of the 32x64 tile). - Then, it memcpy M*64 values back to the output. Partial tiles are handled - one element at a time. - */ - LogicalResult generateUnstickCode(ConversionPatternRewriter &rewriter, - Operation *op, ZHighUnstickOpShapeHelper &shapeHelper, Value alloc, - Value input, StringAttr layout) const { - Location loc = op->getLoc(); - MDBuilder create(rewriter, loc); - - // Compute output dims and rank. - IndexExprScope allocScope(create.krnl, shapeHelper.getScope()); - DimsExpr outputDims; - getIndexExprList(shapeHelper.getOutputDims(), outputDims); - int64_t rank = outputDims.size(); - - // Tiling Unstick in the E2 x E1 dim: N x 64M. 
- int64_t N = CU_N; - int64_t M = CU_M; - assert(32 % N == 0 && "Tiling by N (along E2) must divide 32"); - - // Info for SIMD Vector Length (VL) and associated types. - int64_t VL = 8; // FP16 VL. - int64_t VLHalf = VL / 2; // FP32 VL. - assert(64 % VL == 0 && "SIMD vector length must divide 64"); - Type f16Type = rewriter.getF16Type(); - Type f32Type = rewriter.getF32Type(); - VectorType vecF16Type = VectorType::get({VL}, f16Type); - - // Type for buffer (write M*64 continuously back in the output). - MemRefType bufferType = MemRefType::get({N, M, 64}, f32Type); - - // Define useful literals. - IndexExpr litZero = LiteralIndexExpr(0); - IndexExpr lit1 = LiteralIndexExpr(1); - IndexExpr litN = LiteralIndexExpr(N); - IndexExpr litM = LiteralIndexExpr(M); - IndexExpr litVLHalf = LiteralIndexExpr(VLHalf); - IndexExpr litVL = LiteralIndexExpr(VL); - IndexExpr lit64 = LiteralIndexExpr(64); - - // Useful references for indexing dimensions (neg val are not used). - int64_t E4 = rank - 4, E3 = rank - 3, E2 = rank - 2, E1 = rank - 1; - - // Create loop iterations. Note that we iterate over E1 as tiles of 64 - // elements. - ValueRange loopDefs = create.krnl.defineLoops(rank); - ValueRange tiledDefE2 = create.krnl.block(loopDefs[E2], N); - ValueRange tiledDefE1 = create.krnl.block(loopDefs[E1], M); - DimsExpr lbs(rank, litZero); - DimsExpr ubs = outputDims; - IndexExpr T1 = outputDims[E1].ceilDiv(64); - ubs[E1] = T1; // E1 dim is over tiles. - SmallVector optLoopDefs; - if (rank == 4) { - // 4D loop order: E4, E3, E2 tiled by N, E1 tiled by M, followed by E1 - // (inside M), then unused. - // clang-format off - create.krnl.permute( - {/*E4*/loopDefs[E4],/*E3*/loopDefs[E3],/*E2*/tiledDefE2[0],tiledDefE2[1],/*E1*/tiledDefE1[0],tiledDefE1[1]}, - {/*E4*/0, /*E3*/1, /*E2*/2, 4, /*E1*/3, 5}); - // clang-format on - optLoopDefs = {loopDefs[E4], loopDefs[E3], tiledDefE2[0], tiledDefE1[0]}; - } else if (rank == 3) { - // 3D/3DS loop order: E3, E2 tiled by N, E1 tiled by M, followed by E2 - // (inside N), then unused. Order does not change for 3D vs 3DS. - create.krnl.permute({/*E3*/ loopDefs[E3], - /*E2*/ tiledDefE2[0], tiledDefE2[1], - /*E1*/ tiledDefE1[0], tiledDefE1[1]}, - {/*E3*/ 0, /*E2*/ 1, 3, /*E1*/ 2, 4}); - optLoopDefs = {loopDefs[E3], tiledDefE2[0], tiledDefE1[0]}; - } else if (rank == 2) { - // 2D/2DS loop order: E2 tiled by N, E1 tiled by M, followed by E2 - // (inside N), then unused. Order does not change for 2D vs 2DS. 2DS has a - // E2 tile of 1... which we can safely ignore here. - create.krnl.permute({/*E2*/ tiledDefE2[0], tiledDefE2[1], - /*E1*/ tiledDefE1[0], tiledDefE1[1]}, - {/*E2*/ 0, 2, /*E1*/ 1, 3}); - optLoopDefs = {tiledDefE2[0], tiledDefE1[0]}; - } else { - llvm_unreachable("rank 2 to 4 only"); - } - - // Parallel... - if (enableParallel) { - int64_t parId; - // TODO: may want to check if ub of rank makes sense here. - if (findSuitableParallelDimension(lbs, ubs, 0, rank, parId, 8)) { - create.krnl.parallel(optLoopDefs[parId]); - onnxToKrnlParallelReport(op, true, parId, lbs[parId], ubs[parId], - "compiler-generated stickify"); - } else { - onnxToKrnlParallelReport(op, false, -1, -1, - "no dim with enough work in compiler-generated stickify"); - } - } - - // Compute max tiles. It is actually not easy to compute the max number of - // tiles. Since we don't allocate, it is just a "view", we only need to - // index by the "tile size", it is sufficient to assume 2 or more. Tiles are - // N x 64. 
- IndexExpr T = LiteralIndexExpr(2); - DimsExpr reallocTileDims = {T, litN, lit64}; - Value inputAsTxNx64 = - create.mem.reinterpretCast(input, litZero.getValue(), reallocTileDims); - - // Outer loop (E4, E3, E2 tiled by N, E1 tiled by M) - create.krnl.iterateIE(loopDefs, optLoopDefs, lbs, ubs, - [&](KrnlBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - IndexExprScope outerScope(create.krnl, &allocScope); - DimsExpr outerIndices; - getIndexExprList(loopInd, outerIndices); - // Create buffer [N][M][64] (for parallel, must be inside loop). - Value buffer = create.mem.alignedAlloc(bufferType, {}); - // Iterate over M, N, and 64. Manage iterations explicitly. - // Analysis of assembly showed that the inner loop was fully unrolled. - create.affine.forIE( // M - litZero, litM, 1, [&](AffineBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - DimsExpr inputAF; - IndexExprScope innerScope(create.krnl, &outerScope); - DimIndexExpr m(loopInd[0]); - getIndexExprList(outerIndices, inputAF); - // Translate the tile index t1 to the actual targetted data: e1 - // => 64 (e1+m). Don't use n & l in inputAF as we are mapping a - // [m=1][N][64] chunk of memory from the input. - inputAF[E1] = ((inputAF[E1] + m) * 64); - // May have to migrate this out. It is constant in Ms. - // e2 is by N, e1 is by 64. - Value inputOffset = - create.krnl.getLinearOffsetIndexIE(input, inputAF); - IndexExpr inputDataOffset = SymbolIndexExpr(inputOffset); - IndexExpr inputTileOffset = inputDataOffset.floorDiv(N * 64); - DimsExpr lbs2(2, litZero); - DimsExpr ubs2 = {litN, lit64}; - SmallVector steps2 = {1, VL}; - create.affine.forIE(lbs2, ubs2, steps2, // N, 64 - [&](AffineBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - IndexExprScope innermostScope(create.krnl, &innerScope); - DimIndexExpr n(loopInd[0]), l(loopInd[1]); - Value vecF16 = - create.vec.loadIE(vecF16Type, inputAsTxNx64, - {SymbolIndexExpr(inputTileOffset), n, l}, {}); - auto convertOp = - rewriter.create( - loc, vecF16); - Value vecF32H = convertOp.getResult(0); - Value vecF32L = convertOp.getResult(1); - // Store f32 values back to the buffer[N][M][64]. - DimsExpr bufferAF = {n, SymbolIndexExpr(m), l}; - create.vec.storeIE(vecF32H, buffer, bufferAF, {}); - create.vec.storeIE( - vecF32L, buffer, bufferAF, {litVLHalf.getValue()}); - }); - }); - // Perform copy: E2 Tiled by N (inside tile by M); will copy here - // chunks of m * 64 values for a given n. - create.krnl.iterate({}, {tiledDefE2[1]}, {}, {}, - [&](KrnlBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - DimsExpr outputAF; - IndexExprScope innerScope(create.krnl, &outerScope); - DimIndexExpr e2(loopInd[0]); - getIndexExprList(outerIndices, outputAF); - // Compute n, the current m * 64 tile being processed by this - // inner loop. - IndexExpr n = e2 - outputAF[E2]; - // E1 is tiled, multiply by 64 to get the tile start. - IndexExpr e1 = outputAF[E1] * 64; - - // I may process here up to [e1 ... e1 + m*64), make sure its - // not going out of bound, i.e. beyond outputDIms[E1]; - IndexExpr ub1 = SymbolIndexExpr(outputDims[E1]); - IndexExpr lit64M = LiteralIndexExpr(64 * M); - IndexExpr isFull = create.krnlIE.isTileFull(e1, lit64M, ub1); - IndexExpr isFullLogical = isFull >= 0; - create.scf.ifThenElse( - // Condition - isFullLogical.getValue(), - // Then (is full). 
- [&](SCFBuilder b) { - MDBuilder create(b); - DimsExpr outputAF; - IndexExprScope innermostScope(create.krnl, &innerScope); - getIndexExprList(outerIndices, outputAF); - SymbolIndexExpr nn(n), ee1(e1), ee2(e2); - // Has full M*64 tiles here, use memcpy. - // Calculate offset for output (alloc) - outputAF[E2] = ee2; - outputAF[E1] = ee1; - Value allocOffset = - create.krnl.getLinearOffsetIndexIE(alloc, outputAF); - // Calculate buffer offset: buffer is [N][M][64] - int64_t num = M * 64; - IndexExpr bufferOffset = nn * num; - // Amount of values to copy - Type intType = rewriter.getIntegerType(64); - Value numVal = create.math.constant(intType, num); - // Mem copy - create.krnl.memcpy(alloc, buffer, numVal, allocOffset, - bufferOffset.getValue()); - }, - // Else (is not full). - [&](SCFBuilder b) { - MDBuilder create(b); - Value lb1v = e1.getValue(); - Value ub1v = ub1.getValue(); - Value e2v = e2.getValue(); - Value nv = n.getValue(); - create.scf.forLoop( - lb1v, ub1v, 1, [&](SCFBuilder b, ValueRange loopInd) { - MDBuilder create(b); - SmallVector outputAF; - Value e1v = loopInd[0]; - // Compute access function for output. - IndexExpr::getValues(outerIndices, outputAF); - outputAF[E2] = e2v; - outputAF[E1] = e1v; - // Compute access function for buffer. - Value lv = create.math.rem(e1v, lit64.getValue()); - Value mv = create.math.sub(e1v, lb1v); - mv = create.math.floorDiv(mv, lit64.getValue()); - SmallVector bufferAF = {nv, mv, lv}; - Value t = create.krnl.load(buffer, bufferAF); - create.krnl.store(t, alloc, outputAF); - }); // For. - }); // Else. - }); // Iterate over n. - }); - rewriter.replaceOp(op, alloc); - return success(); - } - - void swapE1E2(DimsExpr &array) const { - int64_t rank = array.size(); - IndexExpr t = array[rank - 1]; - array[rank - 1] = array[rank - 2]; - array[rank - 2] = t; - } - - /* - Same version as the previous one, but the outer loop is in E3, E1 by 64M, E2 - to better use the tiled data locality. - */ - LogicalResult generateUnstickCodeE1E2(ConversionPatternRewriter &rewriter, - Operation *op, ZHighUnstickOpShapeHelper &shapeHelper, Value alloc, - Value input, StringAttr layout) const { - Location loc = op->getLoc(); - MDBuilder create(rewriter, loc); - - // Compute output dims and rank. - IndexExprScope allocScope(create.krnl, shapeHelper.getScope()); - DimsExpr outputDims; - getIndexExprList(shapeHelper.getOutputDims(), outputDims); - int64_t rank = outputDims.size(); - - // Tiling Unstick in the E2 x E1 dim: N x 64M. - int64_t N = CU_N; - int64_t M = CU_M; - assert(32 % N == 0 && "Tiling by N (along E2) must divide 32"); - - // Info for SIMD Vector Length (VL) and associated types. - int64_t VL = 8; // FP16 VL. - int64_t VLHalf = VL / 2; // FP32 VL. - assert(64 % VL == 0 && "SIMD vector length must divide 64"); - Type f16Type = rewriter.getF16Type(); - Type f32Type = rewriter.getF32Type(); - VectorType vecF16Type = VectorType::get({VL}, f16Type); - - // Type for buffer (write M*64 continuously back in the output). - MemRefType bufferType = MemRefType::get({N, M, 64}, f32Type); - - // Define useful literals. - IndexExpr litZero = LiteralIndexExpr(0); - IndexExpr lit1 = LiteralIndexExpr(1); - IndexExpr litN = LiteralIndexExpr(N); - IndexExpr litM = LiteralIndexExpr(M); - IndexExpr litVLHalf = LiteralIndexExpr(VLHalf); - IndexExpr litVL = LiteralIndexExpr(VL); - IndexExpr lit64 = LiteralIndexExpr(64); - - // Useful references for indexing dimensions (neg val are not used). 
- int64_t E4 = rank - 4, E3 = rank - 3, E2 = rank - 2, E1 = rank - 1; - - // Create loop iterations. Note that we iterate over E1 as tiles of 64 - // elements. - ValueRange loopDefs = create.krnl.defineLoops(rank); - ValueRange tiledDefE2 = create.krnl.block(loopDefs[E2], N); - ValueRange tiledDefE1 = create.krnl.block(loopDefs[E1], M); - DimsExpr lbs(rank, litZero); - DimsExpr ubs = outputDims; - IndexExpr T1 = outputDims[E1].ceilDiv(64); - ubs[E1] = T1; // E1 dim is over tiles. - SmallVector optLoopDefs; - if (rank == 4) { - // 4D loop order: E4, E3, E2 tiled by N, E1 tiled by M, followed by E1 - // (inside M), then unused. - // clang-format off - create.krnl.permute( - {/*E4*/loopDefs[E4],/*E3*/loopDefs[E3],/*E2*/tiledDefE2[0],tiledDefE2[1],/*E1*/tiledDefE1[0],tiledDefE1[1]}, - {/*E4*/0, /*E3*/1, /*E2*/3, 4, /*E1*/2, 5}); - // clang-format on - optLoopDefs = {loopDefs[E4], loopDefs[E3], tiledDefE1[0], tiledDefE2[0]}; - } else if (rank == 3) { - // 3D/3DS loop order: E3, E2 tiled by N, E1 tiled by M, followed by E2 - // (inside N), then unused. Order does not change for 3D vs 3DS. - create.krnl.permute({/*E3*/ loopDefs[E3], - /*E2*/ tiledDefE2[0], tiledDefE2[1], - /*E1*/ tiledDefE1[0], tiledDefE1[1]}, - {/*E3*/ 0, /*E2*/ 2, 3, /*E1*/ 1, 4}); - optLoopDefs = {loopDefs[E3], tiledDefE1[0], tiledDefE2[0]}; - } else if (rank == 2) { - // 2D/2DS loop order: E2 tiled by N, E1 tiled by M, followed by E2 - // (inside N), then unused. Order does not change for 2D vs 2DS. 2DS has a - // E2 tile of 1... which we can safely ignore here. - create.krnl.permute({/*E2*/ tiledDefE2[0], tiledDefE2[1], - /*E1*/ tiledDefE1[0], tiledDefE1[1]}, - {/*E2*/ 1, 2, /*E1*/ 0, 3}); - optLoopDefs = {tiledDefE1[0], tiledDefE2[0]}; - } else { - llvm_unreachable("rank 2 to 4 only"); - } - - // Parallel... - if (enableParallel) { - int64_t parId; - // TODO: may want to check if ub of rank makes sense here. - if (findSuitableParallelDimension(lbs, ubs, 0, rank, parId, 8)) { - create.krnl.parallel(optLoopDefs[parId]); - onnxToKrnlParallelReport(op, true, parId, lbs[parId], ubs[parId], - "compiler-generated stickify"); - } else { - onnxToKrnlParallelReport(op, false, -1, -1, - "no dim with enough work in compiler-generated stickify"); - } - } - - // Compute max tiles. It is actually not easy to compute the max number of - // tiles. Since we don't allocate, it is just a "view", we only need to - // index by the "tile size", it is sufficient to assume 2 or more. Tiles are - // N x 64. - IndexExpr T = LiteralIndexExpr(2); - DimsExpr reallocTileDims = {T, litN, lit64}; - Value inputAsTxNx64 = - create.mem.reinterpretCast(input, litZero.getValue(), reallocTileDims); - - // Outer loop (E4, E3, E2 tiled by N, E1 tiled by M) - create.krnl.iterateIE(loopDefs, optLoopDefs, lbs, ubs, - [&](KrnlBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - IndexExprScope outerScope(create.krnl, &allocScope); - DimsExpr outerIndices; - getIndexExprList(loopInd, outerIndices); - // Because of permute, swap E2 and E1 - swapE1E2(outerIndices); - // Create buffer [N][M][64] (for parallel, must be inside loop). - Value buffer = create.mem.alignedAlloc(bufferType, {}); - // Iterate over M, N, and 64. Manage iterations explicitly. - // Analysis of assembly showed that the inner loop was fully unrolled. 
- create.affine.forIE( // M - litZero, litM, 1, [&](AffineBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - DimsExpr inputAF; - IndexExprScope innerScope(create.krnl, &outerScope); - DimIndexExpr m(loopInd[0]); - getIndexExprList(outerIndices, inputAF); - // Translate the tile index t1 to the actual targetted data: e1 - // => 64 (e1+m). Don't use n & l in inputAF as we are mapping a - // [m=1][N][64] chunk of memory from the input. - inputAF[E1] = ((inputAF[E1] + m) * 64); - // May have to migrate this out. It is constant in Ms. - // e2 is by N, e1 is by 64. - Value inputOffset = - create.krnl.getLinearOffsetIndexIE(input, inputAF); - IndexExpr inputDataOffset = SymbolIndexExpr(inputOffset); - IndexExpr inputTileOffset = inputDataOffset.floorDiv(N * 64); - DimsExpr lbs2(2, litZero); - DimsExpr ubs2 = {litN, lit64}; - SmallVector steps2 = {1, VL}; - create.affine.forIE(lbs2, ubs2, steps2, // N, 64 - [&](AffineBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - IndexExprScope innermostScope(create.krnl, &innerScope); - DimIndexExpr n(loopInd[0]), l(loopInd[1]); - Value vecF16 = - create.vec.loadIE(vecF16Type, inputAsTxNx64, - {SymbolIndexExpr(inputTileOffset), n, l}, {}); - auto convertOp = - rewriter.create( - loc, vecF16); - Value vecF32H = convertOp.getResult(0); - Value vecF32L = convertOp.getResult(1); - // Store f32 values back to the buffer[N][M][64]. - DimsExpr bufferAF = {n, SymbolIndexExpr(m), l}; - create.vec.storeIE(vecF32H, buffer, bufferAF, {}); - create.vec.storeIE( - vecF32L, buffer, bufferAF, {litVLHalf.getValue()}); - }); - }); - // Perform copy: E2 Tiled by N (inside tile by M); will copy here - // chunks of m * 64 values for a given n. - create.krnl.iterate({}, {tiledDefE2[1]}, {}, {}, - [&](KrnlBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - DimsExpr outputAF; - IndexExprScope innerScope(create.krnl, &outerScope); - DimIndexExpr e2(loopInd[0]); - getIndexExprList(outerIndices, outputAF); - // Compute n, the current m * 64 tile being processed by this - // inner loop. - IndexExpr n = e2 - outputAF[E2]; - // E1 is tiled, multiply by 64 to get the tile start. - IndexExpr e1 = outputAF[E1] * 64; - - // I may process here up to [e1 ... e1 + m*64), make sure its - // not going out of bound, i.e. beyond outputDIms[E1]; - IndexExpr ub1 = SymbolIndexExpr(outputDims[E1]); - IndexExpr lit64M = LiteralIndexExpr(64 * M); - IndexExpr isFull = create.krnlIE.isTileFull(e1, lit64M, ub1); - IndexExpr isFullLogical = isFull >= 0; - create.scf.ifThenElse( - // Condition - isFullLogical.getValue(), - // Then (is full). - [&](SCFBuilder b) { - MDBuilder create(b); - DimsExpr outputAF; - IndexExprScope innermostScope(create.krnl, &innerScope); - getIndexExprList(outerIndices, outputAF); - SymbolIndexExpr nn(n), ee1(e1), ee2(e2); - // Has full M*64 tiles here, use memcpy. - // Calculate offset for output (alloc) - outputAF[E2] = ee2; - outputAF[E1] = ee1; - Value allocOffset = - create.krnl.getLinearOffsetIndexIE(alloc, outputAF); - // Calculate buffer offset: buffer is [N][M][64] - int64_t num = M * 64; - IndexExpr bufferOffset = nn * num; - // Amount of values to copy - Type intType = rewriter.getIntegerType(64); - Value numVal = create.math.constant(intType, num); - // Mem copy - create.krnl.memcpy(alloc, buffer, numVal, allocOffset, - bufferOffset.getValue()); - }, - // Else (is not full). 
- [&](SCFBuilder b) { - MDBuilder create(b); - Value lb1v = e1.getValue(); - Value ub1v = ub1.getValue(); - Value e2v = e2.getValue(); - Value nv = n.getValue(); - create.scf.forLoop( - lb1v, ub1v, 1, [&](SCFBuilder b, ValueRange loopInd) { - MDBuilder create(b); - SmallVector outputAF; - Value e1v = loopInd[0]; - // Compute access function for output. - IndexExpr::getValues(outerIndices, outputAF); - outputAF[E2] = e2v; - outputAF[E1] = e1v; - // Compute access function for buffer. - Value lv = create.math.rem(e1v, lit64.getValue()); - Value mv = create.math.sub(e1v, lb1v); - mv = create.math.floorDiv(mv, lit64.getValue()); - SmallVector bufferAF = {nv, mv, lv}; - Value t = create.krnl.load(buffer, bufferAF); - create.krnl.store(t, alloc, outputAF); - }); // For. - }); // Else. - }); // Iterate over n. - }); - rewriter.replaceOp(op, alloc); - return success(); - } }; //===----------------------------------------------------------------------===// @@ -2578,17 +1730,15 @@ struct ZHighToZLowDataConversionLowering void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns, mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx, - bool enableParallel, bool enableCompilerStickUnstickCodeGen) { + bool enableParallel) { // Stickify and unstickify operations. patterns.insert(typeConverter, ctx); - patterns.insert(typeConverter, ctx, - ENABLE_CSU_PAR && enableParallel, enableCompilerStickUnstickCodeGen); + patterns.insert(typeConverter, ctx); patterns.insert(typeConverter, ctx); patterns.insert(typeConverter, ctx); patterns.insert( typeConverter, ctx); - patterns.insert(typeConverter, ctx, - ENABLE_CSU_PAR && enableParallel, enableCompilerStickUnstickCodeGen); + patterns.insert(typeConverter, ctx); patterns.insert>( typeConverter, ctx, /*fromF32=*/false, enableParallel); patterns.insert>( diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp index d232828368..5529c132b5 100644 --- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp @@ -53,7 +53,7 @@ mlir::Value insertAllocForZMemRef(ZMemRefType zType, /// Populate all conversion patterns for ZHigh Ops. 
 void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns,
     mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx,
-    bool enableParallel, bool enableCompilerStickUnstickCodeGen);
+    bool enableParallel);

 } // namespace zhigh
 } // namespace onnx_mlir
diff --git a/src/Accelerators/NNPA/NNPAAccelerator.cpp b/src/Accelerators/NNPA/NNPAAccelerator.cpp
index 775944d854..d0f1bc0456 100644
--- a/src/Accelerators/NNPA/NNPAAccelerator.cpp
+++ b/src/Accelerators/NNPA/NNPAAccelerator.cpp
@@ -101,6 +101,10 @@ void NNPAAccelerator::registerPasses(int optLevel) const {
     return onnx_mlir::zlow::createZLowRewritePass();
   });

+  mlir::registerPass([]() -> std::unique_ptr {
+    return onnx_mlir::zlow::createZLowStickExpansionPass();
+  });
+
   mlir::registerPass([]() -> std::unique_ptr {
     return onnx_mlir::zlow::createZLowDummyOpForMultiDerefPass();
   });
@@ -156,8 +160,8 @@ void NNPAAccelerator::conversionTargetONNXToKrnl(
 void NNPAAccelerator::rewritePatternONNXToKrnl(
     mlir::RewritePatternSet &patterns, mlir::TypeConverter &typeConverter,
     mlir::MLIRContext *ctx) const {
-  onnx_mlir::zhigh::populateZHighToZLowConversionPattern(patterns,
-      typeConverter, ctx, enableParallel, nnpaEnableCompilerStickUnstick);
+  onnx_mlir::zhigh::populateZHighToZLowConversionPattern(
+      patterns, typeConverter, ctx, enableParallel);
 }

 void NNPAAccelerator::conversionTargetKrnlToLLVM(
diff --git a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp
index 9e25e44fa0..3a826345b4 100644
--- a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp
+++ b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp
@@ -60,6 +60,10 @@ namespace zlow {
 /// Add pass for rewriting ZLow ops.
 std::unique_ptr createZLowRewritePass();

+/// Add pass for expanding ZLow stick/unstick ops into explicit loops.
+std::unique_ptr createZLowStickExpansionPass(
+    bool enableParallel = false);
+
 /// Add pass for rewriting ZLow ops.
 std::unique_ptr createZLowDummyOpForMultiDerefPass();

diff --git a/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt b/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt
index cae09469b4..3710f0929b 100644
--- a/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt
+++ b/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_onnx_mlir_library(OMZLowRewrite
   ZLowRewrite.cpp
+  ZLowStickExpansion.cpp

   DEPENDS
   OMLayoutHelper
@@ -12,8 +13,10 @@
   MLIRRewrite
   MLIRTransformUtils
   MLIRViewLikeInterface
+  OMONNXToKrnl
   OMZLowOps

+  ACCEL_INCLUDE_DIRS PRIVATE
   ${NNPA_INCLUDE_PATH}
   )
diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp
index 331ca77012..f4a8a4b2fc 100644
--- a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp
+++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp
@@ -224,7 +224,7 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern {
 ///             #map3D>
 /// ```
 /// `%stick` memref is unstickified and shuffled by the pair of (affine.load,affine.store),
-/// then stickified again. It said data are transfered from a stickified memref
+/// then stickified again. That is, data are transferred from a stickified memref
 /// into another stickified memref via a chain of affine transformation.
 ///
 /// The above code can be rewritten into the following code:
@@ -245,7 +245,7 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern {
 /// maintain an affine map that maps one element in a memref to an element in
 /// another memref. Those maps are `#map2D` and `#map3D` in the above example.
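 /// For instance, such a map has the general shape (illustrative only; the
 /// exact maps depend on the data layout):
 ///   affine_map<(e2, e1) -> (..., e2 floordiv 32, e2 mod 32, e1 mod 64)>
 /// where the innermost (e2 mod 32, e1 mod 64) pair addresses a position
 /// inside one 32 x 64 stickified tile.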
 /// Combined with affine.load and affine.store, one element in a stickified
-/// memref can be forwarded directly into an element in another stickifired
+/// memref can be forwarded directly into an element in another stickified
 /// memref without `zlow.stick` and `zlow.unstick`.
 ///
 /// - The shape of the input and output memrefs of `zlow.stick`/`zlow.unstick`
@@ -314,7 +314,7 @@ class UnstickLoadStoreStickRemovalPattern
       // incorrect: https://github.com/onnx/onnx-mlir/issues/1940
       if ((unstickLayout == LAYOUT_1D) || (unstickLayout == LAYOUT_2DS))
         return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
-          diag << "Unsupport layout 1D and 2DS";
+          diag << "Unsupported layout 1D and 2DS";
         });

       // 1. Match pattern: data flows from zlow.unstick to zlow.stick via
diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
new file mode 100644
index 0000000000..1769292016
--- /dev/null
+++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
@@ -0,0 +1,482 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===--- ZLowStickExpansion.cpp - ZLow Stick/Unstick Expansion Patterns ---===//
+//
+// Copyright 2024 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This pass implements optimizations for ZLow operations by substituting calls
+// to stick / unstick with explicit code that performs the transformation, when
+// applicable.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/Debug.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp"
+#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
+#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
+#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+#include "src/Dialect/Krnl/DialectBuilder.hpp"
+#include "src/Dialect/Krnl/KrnlHelper.hpp"
+#include "src/Dialect/Mlir/DialectBuilder.hpp"
+
+#include
+
+#define DEBUG_TYPE "zlow-stick-expansion"
+
+// Todo: cleanup after we are done experimenting.
+#define ENABLE_CSU_PAR true /* Allow parallel compiler gen Stick/Unstick. */
+#define PREFETCH_CSU_DIST 0
+#define PREFETCH_CSU 1
+
+using namespace mlir;
+
+namespace onnx_mlir {
+namespace zlow {
+
+using MDBuilder = MultiDialectBuilder;
+
+/// Expand unstick operation to compiler generated code for suitable patterns,
+/// aka all but the 1D and 2DS data layouts at this time.
+class UnstickExpansionPattern : public OpRewritePattern {
+public:
+  UnstickExpansionPattern(MLIRContext *context, bool enableParallelism = false)
+      : OpRewritePattern(context, 1),
+        enableParallel(enableParallelism) {}
+
+  bool enableParallel = false;
+
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(
+      ZLowUnstickOp unstickOp, PatternRewriter &rewriter) const override {
+
+    // Generic way to handle all formats listed below.
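+    // The 1D and 2DS layouts are deliberately excluded: their affine layout
+    // maps are currently incorrect (see issue #1940, also referenced in
+    // ZLowRewrite.cpp above). The dispatch below amounts to a predicate like
+    // this hypothetical sketch:
+    //   bool isExpandableLayout(StringRef l) {
+    //     return l.equals_insensitive("4D") || l.equals_insensitive("3D") ||
+    //            l.equals_insensitive("2D") || l.equals_insensitive("3DS");
+    //   }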
+    StringAttr layout = unstickOp.getLayoutAttr();
+    if (layout.getValue().equals_insensitive("4D") ||
+        layout.getValue().equals_insensitive("3D") ||
+        layout.getValue().equals_insensitive("2D") ||
+        layout.getValue().equals_insensitive("3DS")) {
+      return generateUnstickCodeNoBuffer(rewriter, unstickOp, layout);
+    }
+    // Otherwise, we don't replace and keep the zdnn call.
+    return failure();
+  }
+
+  LogicalResult generateUnstickCodeNoBuffer(PatternRewriter &rewriter,
+      ZLowUnstickOp unstickOp, StringAttr layout) const {
+    Operation *op = unstickOp.getOperation();
+    Location loc = unstickOp.getLoc();
+    MDBuilder create(rewriter, loc);
+    IndexExprScope allocScope(create.krnl);
+
+    // Compute output dims and rank.
+    Value input = unstickOp.getX();
+    Value alloc = unstickOp.getOut();
+    DimsExpr outputDims;
+    create.krnlIE.getShapeAsSymbols(alloc, outputDims);
+    int64_t rank = outputDims.size();
+
+    // Info for SIMD Vector Length (VL) and associated types.
+    int64_t VL = 8;          // FP16 VL.
+    int64_t VLHalf = VL / 2; // FP32 VL.
+    assert(64 % VL == 0 && "SIMD vector length must divide 64");
+    Type f16Type = rewriter.getF16Type();
+    Type f32Type = rewriter.getF32Type();
+    VectorType vecF16Type = VectorType::get({VL}, f16Type);
+    MemRefType bufferType = MemRefType::get({VL}, f32Type);
+
+    // Define useful literals.
+    IndexExpr litZero = LiteralIndexExpr(0);
+    IndexExpr lit1 = LiteralIndexExpr(1);
+    IndexExpr litVLHalf = LiteralIndexExpr(VLHalf);
+    IndexExpr litVL = LiteralIndexExpr(VL);
+    IndexExpr lit64 = LiteralIndexExpr(64);
+
+    // Useful references for indexing dimensions (neg val are not used).
+    int64_t E1 = rank - 1;
+
+    // Create loop iterations. Note that we iterate over E1 as tiles of 64
+    // elements.
+    ValueRange loopDefs = create.krnl.defineLoops(rank);
+    DimsExpr lbs(rank, litZero);
+    DimsExpr ubs = outputDims;
+    IndexExpr T1 = outputDims[E1].ceilDiv(64);
+    ubs[E1] = T1; // E1 dim is over tiles.
+
+    // Parallel...
+    if (enableParallel) {
+      int64_t parId;
+      // TODO: may want to check if ub of rank makes sense here.
+      if (findSuitableParallelDimension(lbs, ubs, 0, rank, parId, 8)) {
+        create.krnl.parallel(loopDefs[parId]);
+        onnxToKrnlParallelReport(op, true, parId, lbs[parId], ubs[parId],
+            "compiler-generated stickify");
+      } else {
+        onnxToKrnlParallelReport(op, false, -1, -1,
+            "no dim with enough work in compiler-generated stickify");
+      }
+    }
+
+    // Compute max tiles. It is actually not easy to compute the max number of
+    // tiles. Since we don't allocate, it is just a "view", we only need to
+    // index by the "tile size", it is sufficient to assume 2 or more. Tiles
+    // are 64.
+    IndexExpr T = LiteralIndexExpr(2);
+    DimsExpr reallocTileDims = {T, lit64};
+    Value inputAsTx64 =
+        create.mem.reinterpretCast(input, litZero.getValue(), reallocTileDims);
+
+    // Outer loop (E4, E3, E2, E1 iterates over tiles of 64 elements)
+    create.krnl.iterateIE(
+        loopDefs, loopDefs, lbs, ubs, [&](KrnlBuilder &b, ValueRange loopInd) {
+          MDBuilder create(b);
+          IndexExprScope outerScope(create.krnl, &allocScope);
+          DimsExpr outerIndices = DimListIE(loopInd);
+          // Computation for reading inputs.
+          DimsExpr inputAF = outerIndices;
+          IndexExpr e1 = outerIndices[E1] * 64;
+          inputAF[E1] = e1;
+          // Translate the tile index t1 to the actual targeted data.
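+          // The linear offset below counts f16 elements in the stickified
+          // buffer; each stick holds 64 contiguous elements, so dividing by
+          // 64 yields the stick (tile) index used with the reinterpreted
+          // {tile, element} view. A plain-arithmetic sketch (hypothetical
+          // helper name):
+          //   int64_t stickIndex(int64_t linearF16Offset) {
+          //     return linearF16Offset / 64; // offsets are non-negative
+          //   }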
+          Value inputOffset =
+              create.krnl.getLinearOffsetIndexIE(input, inputAF);
+          IndexExpr inputDataOffset = SymbolIndexExpr(inputOffset);
+          IndexExpr inputTileOffset = inputDataOffset.floorDiv(64);
+
+// Prefetch
+#if PREFETCH_CSU
+          DimsExpr prefetchAF = inputAF;
+          // Prefetch current line
+          create.krnl.prefetchIE(input, prefetchAF, /*isWrite*/ false,
+              /*locality*/ 1);
+          create.krnl.prefetchIE(alloc, prefetchAF, /*isWrite*/ true,
+              /*locality*/ 1);
+#if PREFETCH_CSU_DIST > 0
+          // Prefetch line in advance.
+          prefetchAF[E1] = prefetchAF[E1] + (PREFETCH_CSU_DIST * 64);
+          create.krnl.prefetchIE(input, prefetchAF, /*isWrite*/ false,
+              /*locality*/ 1);
+          create.krnl.prefetchIE(alloc, prefetchAF, /*isWrite*/ true,
+              /*locality*/ 1);
+#endif
+#endif
+
+          // We may process up to [e1 ... e1 + 64) here; make sure we do not
+          // go out of bounds, i.e., beyond outputDims[E1].
+          IndexExpr ub1 = SymIE(outputDims[E1]);
+          IndexExpr lit64Bis = LiteralIndexExpr(64);
+          IndexExpr isFull = create.krnlIE.isTileFull(e1, lit64, ub1);
+          IndexExpr isFullLogical = isFull >= 0;
+          create.scf.ifThenElse(
+              // Condition
+              isFullLogical.getValue(),
+              // Then (is full).
+              [&](SCFBuilder b) {
+                MDBuilder create(b);
+                // Loop
+                const int64_t U = 4;
+                assert(U * VL <= 64 && "bad unroll");
+                create.scf.forLoop(litZero.getValue(), lit64.getValue(), U * VL,
+                    [&](SCFBuilder b, Value loopIndex) {
+                      MDBuilder create(b);
+                      IndexExprScope innerScope(b, &outerScope);
+                      IndexExpr l = DimIE(loopIndex);
+                      Value vecF16[U], vecF32H[U], vecF32L[U];
+                      // Load f16 values from input via reinterpreted data
+                      // tile.
+                      for (int64_t i = 0; i < U; ++i) {
+                        vecF16[i] = create.vec.loadIE(vecF16Type, inputAsTx64,
+                            {SymIE(inputTileOffset), l + (i * VL)}, {});
+                      }
+                      // Convert back to f32.
+                      for (int64_t i = 0; i < U; ++i) {
+                        auto convertOp =
+                            rewriter.create(
+                                loc, vecF16[i]);
+                        vecF32H[i] = convertOp.getResult(0);
+                        vecF32L[i] = convertOp.getResult(1);
+                      }
+                      // Store f32 values back to the (normal layout) output.
+                      DimsExpr outputAF = SymListIE(inputAF);
+                      outputAF[E1] = outputAF[E1] + l;
+                      for (int64_t i = 0; i < U; ++i) {
+                        LiteralIndexExpr iH(i * VL), iL(i * VL + VL / 2);
+                        create.vec.storeIE(
+                            vecF32H[i], alloc, outputAF, {iH.getValue()});
+                        create.vec.storeIE(
+                            vecF32L[i], alloc, outputAF, {iL.getValue()});
+                      }
+                    });
+              },
+              // Else, we don't have a full (64 e1) tile.
+              [&](SCFBuilder b) {
+                MDBuilder create(b);
+                IndexExprScope middleScope(b, &outerScope);
+                IndexExpr tripCount = SymIE(ub1) - SymIE(e1);
+                // Note: if tripCount is a multiple of VL, the loop below
+                // handles all the values, as we subtract (VL-1). E.g., if
+                // VL=8 and tripCount=16, tripCountWithoutPartialLastVL is
+                // 16 - 7 = 9; thus we iterate over i=0 and i=8, as both are
+                // < 9.
+                IndexExpr tripCountWithoutPartialLastVL = tripCount - (VL - 1);
+                create.scf.forLoop(litZero.getValue(),
+                    tripCountWithoutPartialLastVL.getValue(), VL,
+                    [&](SCFBuilder b, Value loopIndex) {
+                      MDBuilder create(b);
+                      IndexExprScope innerScope(b, &middleScope);
+                      IndexExpr l = DimIE(loopIndex);
+                      // Load f16 values from input via reinterpreted data
+                      // tile.
+                      Value vecF16 = create.vec.loadIE(vecF16Type, inputAsTx64,
+                          {SymIE(inputTileOffset), l}, {});
+                      // Convert back to f32.
+                      auto convertOp =
+                          rewriter.create(
+                              loc, vecF16);
+                      Value vecF32H = convertOp.getResult(0);
+                      Value vecF32L = convertOp.getResult(1);
+                      // Store f32 values back to the (normal layout) output.
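+                      // Each 8-lane f16 vector yields two 4-lane f32
+                      // vectors: vecF32H covers lanes [0, VLHalf) and
+                      // vecF32L covers lanes [VLHalf, VL), hence the two
+                      // stores below at element offsets 0 and VLHalf, e.g.,
+                      // with VL = 8:
+                      //   out[base + 0 .. base + 3] <- vecF32H
+                      //   out[base + 4 .. base + 7] <- vecF32L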
+ DimsExpr outputAF = SymListIE(inputAF); + outputAF[E1] = outputAF[E1] + l; + create.vec.storeIE(vecF32H, alloc, outputAF, {}); + create.vec.storeIE( + vecF32L, alloc, outputAF, {litVLHalf.getValue()}); + }); + // Deal with the last values: compute f32 using simd. + IndexExpr remainingScalarValues = tripCount % VL; + IndexExpr lastL = tripCount - remainingScalarValues; + Value vecF16 = create.vec.loadIE(vecF16Type, inputAsTx64, + {SymIE(inputTileOffset), lastL}, {}); + // Convert back to f32. + auto convertOp = + rewriter.create(loc, vecF16); + Value vecF32H = convertOp.getResult(0); + Value vecF32L = convertOp.getResult(1); + // Save into VL value buffer. + Value bufferF32 = create.mem.alignedAlloca(bufferType); + create.vec.storeIE(vecF32H, bufferF32, {litZero}, {}); + create.vec.storeIE(vecF32L, bufferF32, {litVLHalf}, {}); + // Save the remaining values as scalars. + create.scf.forLoop(litZero.getValue(), + remainingScalarValues.getValue(), 1, + [&](SCFBuilder b, Value loopIndex) { + MDBuilder create(b); + IndexExprScope innerScope(b, &middleScope); + IndexExpr l = DimIE(loopIndex); + // Load converted value. + Value f32 = create.krnl.loadIE(bufferF32, {l}); + DimsExpr outputAF = SymListIE(inputAF); + outputAF[E1] = outputAF[E1] + SymIE(lastL); + outputAF[E1] = outputAF[E1] + l; + create.krnl.storeIE(f32, alloc, outputAF); + }); + }); + }); + rewriter.eraseOp(unstickOp); + return success(); + } +}; + +/// Expand stick operation to compiler generated code for suitable patterns, aka +/// all but the 1D and 2DS data layouts at this time. +class StickExpansionPattern : public OpRewritePattern { +public: + StickExpansionPattern(MLIRContext *context, bool enableParallelism = false) + : OpRewritePattern(context, 1), + enableParallel(enableParallelism) {} + + bool enableParallel; + + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite( + ZLowStickOp stickOp, PatternRewriter &rewriter) const override { + + StringAttr layout = stickOp.getLayoutAttr(); + + // Generic way to handle all formats listed below. + if (layout.getValue().equals_insensitive("4D") || + layout.getValue().equals_insensitive("3D") || + layout.getValue().equals_insensitive("2D") || + layout.getValue().equals_insensitive("3DS")) { + return generateStickCodeNoBuffer(rewriter, stickOp, layout); + } + // Otherwise, we don't replace and keep the zdnn call. + return failure(); + } + + /* Version without buffer, more like zdnn */ + LogicalResult generateStickCodeNoBuffer( + PatternRewriter &rewriter, ZLowStickOp stickOp, StringAttr layout) const { + Operation *op = stickOp.getOperation(); + Location loc = stickOp.getLoc(); + MDBuilder create(rewriter, loc); + IndexExprScope allocScope(create.krnl); + + // Compute output dims and rank. + Value input = stickOp.getX(); + Value alloc = stickOp.getOut(); + + DimsExpr outputDims; + create.krnlIE.getShapeAsSymbols(alloc, outputDims); + int64_t rank = outputDims.size(); + + // Info for SIMD Vector Length (VL) and associated types. + int64_t VL = 8; // FP16 VL. + int64_t VLHalf = VL / 2; // FP32 VL. + assert(64 % VL == 0 && "SIMD vector length must divide 64"); + Type f32Type = rewriter.getF32Type(); + VectorType vecF32Type = VectorType::get({VLHalf}, f32Type); + + // Define useful literals. + IndexExpr litZero = LiteralIndexExpr(0); + IndexExpr lit1 = LiteralIndexExpr(1); + IndexExpr litVLHalf = LiteralIndexExpr(VLHalf); + IndexExpr lit64 = LiteralIndexExpr(64); + + // Useful references for indexing dimensions (neg val are not used). 
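+    // E1 is the innermost dimension in the e4..e1 numbering used throughout
+    // this file (E4 is the outermost), so for a rank-r memref E1 maps to
+    // index r - 1; e.g., for a 3D tensor indexed as [e3][e2][e1],
+    // E1 = rank - 1 = 2.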
+    int64_t E1 = rank - 1;
+
+    // Create loop iterations. Note that we iterate over E1 as tiles of 64
+    // elements.
+    ValueRange loopDefs = create.krnl.defineLoops(rank);
+    DimsExpr lbs(rank, litZero);
+    DimsExpr ubs = outputDims;
+    IndexExpr T1 = outputDims[E1].ceilDiv(64);
+    ubs[E1] = T1; // E1 dim is over tiles.
+
+    // Parallel...
+    if (enableParallel) {
+      int64_t parId;
+      // TODO: may want to check if ub of rank makes sense here.
+      if (findSuitableParallelDimension(lbs, ubs, 0, rank, parId, 8)) {
+        create.krnl.parallel(loopDefs[parId]);
+        onnxToKrnlParallelReport(op, true, parId, lbs[parId], ubs[parId],
+            "compiler-generated stickify");
+      } else {
+        onnxToKrnlParallelReport(op, false, -1, -1,
+            "no dim with enough work in compiler-generated stickify");
+      }
+    }
+
+    // Compute max tiles. It is actually not easy to compute the maximum
+    // number of tiles; but since we don't allocate, this is just a "view"
+    // that we only index by the "tile size", so it is sufficient to assume
+    // 2 or more. Tiles are 64 elements.
+    IndexExpr T = LiteralIndexExpr(2);
+    DimsExpr reallocTileDims = {T, lit64};
+    Value allocAsTx64 =
+        create.mem.reinterpretCast(alloc, litZero.getValue(), reallocTileDims);
+
+    // Outer loop (E1 iterates over tiles of 64 elements).
+    create.krnl.iterateIE(
+        loopDefs, loopDefs, lbs, ubs, [&](KrnlBuilder &b, ValueRange loopInd) {
+          MDBuilder create(b);
+          IndexExprScope outerScope(create.krnl, &allocScope);
+          DimsExpr outerIndices;
+          getIndexExprList<DimIndexExpr>(loopInd, outerIndices);
+          DimsExpr memAF = outerIndices;
+          memAF[E1] = memAF[E1] * 64; // Loop index for E1 is in tiles of 64.
+          Value allocOffset = create.krnl.getLinearOffsetIndexIE(alloc, memAF);
+          IndexExpr allocTileIndex = SymIE(allocOffset).floorDiv(64);
+#if PREFETCH_CSU
+          DimsExpr prefetchAF = memAF;
+          // Prefetch current lines.
+          create.krnl.prefetchIE(input, prefetchAF, /*isWrite*/ false,
+              /*locality*/ 1);
+          create.krnl.prefetchIE(alloc, prefetchAF, /*isWrite*/ true,
+              /*locality*/ 1);
+#if PREFETCH_CSU_DIST > 0
+          // Prefetch line in advance.
+          prefetchAF[E1] = prefetchAF[E1] + (PREFETCH_CSU_DIST * 64);
+          create.krnl.prefetchIE(input, prefetchAF, /*isWrite*/ false,
+              /*locality*/ 1);
+          create.krnl.prefetchIE(alloc, prefetchAF, /*isWrite*/ true,
+              /*locality*/ 1);
+#endif
+#endif
+
+          const int64_t U = 4;
+          assert(U * VL <= 64 && "bad unroll");
+          create.affine.forIE(litZero, lit64, U * VL,
+              [&](AffineBuilder &b, ValueRange loopInd) {
+                MDBuilder create(b);
+                DimsExpr inputAF;
+                IndexExprScope innerScope(create.krnl, &outerScope);
+                SymbolIndexExpr l(loopInd[0]);
+                getIndexExprList<SymbolIndexExpr>(memAF, inputAF);
+                // E1: add the "l" local E1 offset.
+                inputAF[E1] = inputAF[E1] + l;
+                Value vecF32H[U], vecF32L[U], vecF16[U];
+                for (int64_t i = 0; i < U; ++i) {
+                  LiteralIndexExpr iH(i * VL), iL(i * VL + VL / 2);
+                  vecF32H[i] = create.vec.loadIE(
+                      vecF32Type, input, inputAF, {iH.getValue()});
+                  vecF32L[i] = create.vec.loadIE(
+                      vecF32Type, input, inputAF, {iL.getValue()});
+                }
+                for (int64_t i = 0; i < U; ++i) {
+                  vecF16[i] = rewriter.create<ZLowConvertF32ToDLF16VectorOp>(
+                      loc, vecF32H[i], vecF32L[i]);
+                }
+                for (int64_t i = 0; i < U; ++i) {
+                  create.vec.storeIE(vecF16[i], allocAsTx64,
+                      {SymIE(allocTileIndex), l + (i * VL)}, {});
+                }
+              });
+        });
+
+    rewriter.eraseOp(stickOp);
+    return success();
+  }
+};
+
+/*!
+ * Function pass that optimizes ZLow IR.
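+ * Currently it expands ZLow stick/unstick data-conversion operations into
+ * compiler-generated loops, using the two patterns defined above.
+ *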
+ */
+class ZLowStickExpansionPass
+    : public PassWrapper<ZLowStickExpansionPass, OperationPass<func::FuncOp>> {
+
+public:
+  ZLowStickExpansionPass(bool enableParallel)
+      : PassWrapper<ZLowStickExpansionPass, OperationPass<func::FuncOp>>(),
+        enableParallel(enableParallel) {}
+
+  bool enableParallel;
+
+  StringRef getArgument() const override { return "zlow-stick-expansion"; }
+
+  StringRef getDescription() const override {
+    return "ZLow Stick/Unstick Ops expansion pass.";
+  }
+
+  void runOnOperation() override {
+    Operation *function = getOperation();
+
+    ConversionTarget target(getContext());
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<StickExpansionPattern>(&getContext(), enableParallel);
+    patterns.insert<UnstickExpansionPattern>(&getContext(), enableParallel);
+    // patterns.insert(&getContext());
+
+    if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+std::unique_ptr<Pass> createZLowStickExpansionPass(bool enableParallel) {
+  return std::make_unique<ZLowStickExpansionPass>(enableParallel);
+}
+
+} // namespace zlow
+} // namespace onnx_mlir
diff --git a/src/Conversion/KrnlToAffine/CMakeLists.txt b/src/Conversion/KrnlToAffine/CMakeLists.txt
index 7d8fe798a8..eff3975f6a 100644
--- a/src/Conversion/KrnlToAffine/CMakeLists.txt
+++ b/src/Conversion/KrnlToAffine/CMakeLists.txt
@@ -8,6 +8,7 @@ add_onnx_mlir_library(OMKrnlToAffine
   KrnlLoad.cpp
   KrnlMatmul.cpp
   KrnlMemset.cpp
+  KrnlPrefetch.cpp
   KrnlStore.cpp
   KrnlTerminator.cpp
   KrnlToAffineHelper.cpp
diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
index 6ee8697a1a..1ef6b0467a 100644
--- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
+++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
@@ -967,6 +967,7 @@ void ConvertKrnlToAffinePass::runOnOperation() {
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<KrnlPrefetchOp>();
   target.addLegalOp();
   target.addLegalOp();
   target.addLegalOp();
@@ -1029,6 +1030,7 @@ void populateKrnlToAffineConversion(TypeConverter &typeConverter,
       typeConverter, patterns, ctx);
   krnl::populateLoweringKrnlMatmultOpPattern(typeConverter, patterns, ctx);
   krnl::populateLoweringKrnlMemsetOpPattern(typeConverter, patterns, ctx);
+  krnl::populateLoweringKrnlPrefetchOpPattern(typeConverter, patterns, ctx);
   krnl::populateLoweringKrnlTerminatorOpPattern(typeConverter, patterns, ctx);
 }
 
diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp
index 847fc865de..45d5b211a9 100644
--- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp
+++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp
@@ -81,6 +81,9 @@ void populateLoweringKrnlMatmultOpPattern(mlir::TypeConverter &typeConverter,
 void populateLoweringKrnlMemsetOpPattern(mlir::TypeConverter &typeConverter,
     mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
 
+void populateLoweringKrnlPrefetchOpPattern(mlir::TypeConverter &typeConverter,
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
+
 void populateLoweringKrnlTerminatorOpPattern(mlir::TypeConverter &typeConverter,
     mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
 
diff --git a/src/Conversion/KrnlToAffine/KrnlPrefetch.cpp b/src/Conversion/KrnlToAffine/KrnlPrefetch.cpp
new file mode 100644
index 0000000000..43ce6d4a50
--- /dev/null
+++ b/src/Conversion/KrnlToAffine/KrnlPrefetch.cpp
@@ -0,0 +1,60 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===--------------------- KrnlPrefetch.cpp - -----------------------------===//
+//
+// Copyright 2024- The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the KrnlPrefetchOp operator.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+#include "src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp"
+#include "src/Dialect/Krnl/KrnlOps.hpp"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "krnl_to_affine"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+namespace krnl {
+
+class KrnlPrefetchOpLowering : public ConversionPattern {
+public:
+  explicit KrnlPrefetchOpLowering(
+      TypeConverter &typeConverter, MLIRContext *context)
+      : ConversionPattern(
+            typeConverter, KrnlPrefetchOp::getOperationName(), 1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    MultiDialectBuilder<AffineBuilder> create(rewriter, loc);
+
+    auto krnlOp = llvm::cast<KrnlPrefetchOp>(op);
+    KrnlPrefetchOpAdaptor operandAdaptor(krnlOp);
+
+    Operation *affineOp = create.affine.prefetch(operandAdaptor.getMemref(),
+        operandAdaptor.getMap(), operandAdaptor.getIndices(),
+        operandAdaptor.getIsWrite(), operandAdaptor.getLocalityHint(),
+        operandAdaptor.getIsDataCache());
+
+    rewriter.replaceOp(op, affineOp);
+    return success();
+  }
+};
+
+void populateLoweringKrnlPrefetchOpPattern(TypeConverter &typeConverter,
+    RewritePatternSet &patterns, MLIRContext *ctx) {
+  patterns.insert<KrnlPrefetchOpLowering>(typeConverter, ctx);
+}
+
+} // namespace krnl
+} // namespace onnx_mlir
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp b/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp
index 357639974f..473f1f0938 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp
@@ -37,27 +37,12 @@ struct ONNXReshapeOpLowering : public OpConversionPattern<ONNXReshapeOp> {
     ValueRange operands = adaptor.getOperands();
     Value data = adaptor.getData();
 
-    Value inputTensor = reshapeOp.getData();
-    Value outputTensor = reshapeOp.getReshaped();
-    int64_t inputRank = getRank(inputTensor.getType());
-    int64_t outputRank = getRank(outputTensor.getType());
-
     // If reshape does not change dimensions or it is an identity, just replace
     // the output with the input.
-    // It is an identity if at least (N-1) out of N dimensions are equal. We
-    // don't need to care about the different dimension, it is maybe because of
-    // DimAnalysis failed to handle it.
-    if (inputRank == outputRank) {
-      int nSameDims = 0;
-      for (int64_t i = 0; i < inputRank; ++i) {
-        if (dimAnalysis->sameDim(inputTensor, i, outputTensor, i))
-          nSameDims++;
-      }
-      if (nSameDims >= inputRank - 1) {
-        LLVM_DEBUG(llvm::dbgs() << "Lowering reshape to identity\n");
-        rewriter.replaceOp(op, data);
-        return success();
-      }
+    if (isIdentityReshape(reshapeOp, dimAnalysis)) {
+      LLVM_DEBUG(llvm::dbgs() << "Lowering reshape to identity\n");
+      rewriter.replaceOp(op, data);
+      return success();
     }
 
     // Convert the output type to MemRefType.
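(For orientation: the KrnlBuilder prefetch helpers added in the next file are what the stick/unstick expansion above calls. A minimal usage sketch, assuming an MDBuilder `create`, a source memref `input`, a destination memref `alloc`, and an access function `af` made of IndexExpr, all as in the patterns above:

    // Request the line holding input[af] for reading and alloc[af] for
    // writing. localityHint follows the LLVM prefetch convention (0 = no
    // temporal locality ... 3 = keep in cache if possible); isDataCache
    // defaults to true, i.e., prefetch into the data cache.
    create.krnl.prefetchIE(input, af, /*isWrite*/ false, /*locality*/ 1);
    create.krnl.prefetchIE(alloc, af, /*isWrite*/ true, /*locality*/ 1);

These build krnl.prefetch ops, which the pass above then lowers to affine.prefetch.)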
diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp
index 5ad1b9c5c3..53475325a6 100644
--- a/src/Dialect/Krnl/DialectBuilder.cpp
+++ b/src/Dialect/Krnl/DialectBuilder.cpp
@@ -104,6 +104,20 @@ Value KrnlBuilder::getLinearOffsetIndexIE(
   return b().create<KrnlGetLinearOffsetIndexOp>(loc(), memref, indexValues);
 }
 
+void KrnlBuilder::prefetch(Value memref, ValueRange indices, bool isWrite,
+    unsigned localityHint, bool isDataCache) {
+  b().create<KrnlPrefetchOp>(
+      loc(), memref, indices, isWrite, localityHint, isDataCache);
+}
+
+void KrnlBuilder::prefetchIE(Value memref, ArrayRef<IndexExpr> indices,
+    bool isWrite, unsigned localityHint, bool isDataCache) {
+  SmallVector<Value, 4> indexValues;
+  IndexExpr::getValues(indices, indexValues);
+  b().create<KrnlPrefetchOp>(
+      loc(), memref, indexValues, isWrite, localityHint, isDataCache);
+}
+
 void KrnlBuilder::seqstore(
     mlir::Value element, mlir::Value seq, mlir::Value index) const {
   b().create<KrnlSeqStoreOp>(loc(), element, seq, index);
diff --git a/src/Dialect/Krnl/DialectBuilder.hpp b/src/Dialect/Krnl/DialectBuilder.hpp
index 2856486c0e..1a50768542 100644
--- a/src/Dialect/Krnl/DialectBuilder.hpp
+++ b/src/Dialect/Krnl/DialectBuilder.hpp
@@ -43,11 +43,18 @@ struct KrnlBuilder : public DialectBuilder {
   void storeIE(mlir::Value val, mlir::Value memref,
       mlir::ArrayRef<IndexExpr> indices) const;
 
+  // Get linear offset for the given memref at the given index values.
   mlir::Value getLinearOffsetIndex(
       mlir::Value memref, mlir::ValueRange indices = {}) const;
   mlir::Value getLinearOffsetIndexIE(
       mlir::Value memref, mlir::ArrayRef<IndexExpr> indices) const;
 
+  // Prefetch with identity map.
+  void prefetch(mlir::Value memref, mlir::ValueRange indices, bool isWrite,
+      unsigned localityHint, bool isDataCache = true);
+  void prefetchIE(mlir::Value memref, mlir::ArrayRef<IndexExpr> indices,
+      bool isWrite, unsigned localityHint, bool isDataCache = true);
+
   void seqstore(mlir::Value element, mlir::Value seq, mlir::Value index) const;
   void seqstore(mlir::Value element, mlir::Value seq, IndexExpr index) const;
 
diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td
index 4e8b7376bf..025cce55f7 100644
--- a/src/Dialect/Krnl/Krnl.td
+++ b/src/Dialect/Krnl/Krnl.td
@@ -753,10 +753,69 @@ def KrnlGetLinearOffsetIndexOp : Op<Krnl_Dialect, "get_linear_offset_index",
 
+def KrnlPrefetchOp : Op<Krnl_Dialect, "prefetch",
+    [DeclareOpInterfaceMethods<AffineMapAccessInterface>]> {
+  let summary = "A Krnl operation to prefetch the memory at a given N-D index.";
+
+  let description = [{
+    Given a MemRef and an N-D index (id_1, id_2, ..., id_n), prefetch the
+    memory location pointed to by this memory reference.
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "memref to prefetch">:$memref,
+      Variadic<Index>:$indices,
+      BoolAttr:$isWrite,
+      ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<3>]>:$localityHint,
+      BoolAttr:$isDataCache,
+      AffineMapAttr:$map
+  );
+  // let assemblyFormat = [{$memref `[` $indices `]` attr-dict `:` type($memref)}];
+  let builders = [
+    /// Builds an op with an identity map and operands.
+    OpBuilder<(ins "Value":$memref, "ValueRange":$indices, "bool":$isWrite,
+        "unsigned":$localityHint, "bool":$isDataCache)>,
+    OpBuilder<(ins "Value":$memref, "bool":$isWrite,
+        "unsigned":$localityHint, "bool":$isDataCache)>,
+    /// Builds an op with the specified map and its operands.
+    OpBuilder<(ins "Value":$memref, "AffineMap":$map,
+        "ValueRange":$mapOperands, "bool":$isWrite,
+        "unsigned":$localityHint, "bool":$isDataCache)>
+  ];
+  let extraClassDeclaration = [{
+    /// Returns the operand index of the memref.
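+    /// (The memref is operand 0; the affine-map index operands follow it.)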
+    unsigned getMemRefOperandIndex() { return 0; }
+    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
+    void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
+    MemRefType getMemRefType() {
+      return getMemref().getType().cast<MemRefType>();
+    }
+
+    /// Implements the AffineMapAccessInterface.
+    /// Returns the AffineMapAttr associated with 'memref'.
+    // Default version is fine.
+
+    /// Returns the affine map used to index the memref for this operation.
+    AffineMapAttr getAffineMapAttr() { return getMapAttr(); }
+    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
+
+    /// Get affine map operands.
+    operand_range getMapOperands() { return getIndices(); }
+
+    static StringRef getMapAttrStrName() { return "map"; }
+    static StringRef getIsWriteAttrStrName() { return "isWrite"; }
+    static StringRef getLocalityHintAttrStrName() { return "localityHint"; }
+    static StringRef getIsDataCacheAttrStrName() { return "isDataCache"; }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 def KrnlMovableOp : Op<Krnl_Dialect, "movable"> {
   let summary = "Krnl movable operation";
   let description = [{
diff --git a/src/Dialect/Krnl/KrnlOps.cpp b/src/Dialect/Krnl/KrnlOps.cpp
index c3ce481185..17f7693d92 100644
--- a/src/Dialect/Krnl/KrnlOps.cpp
+++ b/src/Dialect/Krnl/KrnlOps.cpp
@@ -1118,6 +1118,111 @@ void KrnlGetLinearOffsetIndexOp::print(OpAsmPrinter &p) {
   p << " : " << getMemRefType();
 }
 
+//===----------------------------------------------------------------------===//
+// KrnlPrefetchOp
+//===----------------------------------------------------------------------===//
+
+void KrnlPrefetchOp::build(OpBuilder &builder, OperationState &result,
+    Value memref, AffineMap map, ValueRange mapOperands, bool isWrite,
+    unsigned localityHint, bool isDataCache) {
+  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
+  result.addOperands(memref);
+  result.addOperands(mapOperands);
+  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
+  result.addAttribute(getIsWriteAttrStrName(), builder.getBoolAttr(isWrite));
+  result.addAttribute(
+      getLocalityHintAttrStrName(), builder.getI32IntegerAttr(localityHint));
+  result.addAttribute(
+      getIsDataCacheAttrStrName(), builder.getBoolAttr(isDataCache));
+}
+
+void KrnlPrefetchOp::build(OpBuilder &builder, OperationState &result,
+    Value memref, ValueRange indices, bool isWrite, unsigned localityHint,
+    bool isDataCache) {
+  auto memrefType = llvm::cast<MemRefType>(memref.getType());
+  int64_t rank = memrefType.getRank();
+  // Create identity map for memrefs with at least one dimension or () -> ()
+  // for zero-dimensional memrefs.
+  auto map =
+      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
+  build(builder, result, memref, map, indices, isWrite, localityHint,
+      isDataCache);
+}
+
+void KrnlPrefetchOp::build(OpBuilder &builder, OperationState &result,
+    Value memref, bool isWrite, unsigned localityHint, bool isDataCache) {
+  build(builder, result, memref, {}, isWrite, localityHint, isDataCache);
+}
+
+//
+// krnl.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
+// Code lifted from affine prefetch as is.
+// I have seen parsing errors when multiple '#x' values are used in the
+// indices, but could not tell why.
+// krnl.prefetch %arg0[%1#0, %1#1, %3], read, locality<3>, data :
+//     memref<8x256x512xf32>
+// With only one, it works.
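+// The custom parser and printer below otherwise mirror affine.prefetch.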
+//
+
+ParseResult KrnlPrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
+  auto &builder = parser.getBuilder();
+  auto indexTy = builder.getIndexType();
+
+  MemRefType type;
+  OpAsmParser::UnresolvedOperand memrefInfo;
+  IntegerAttr hintInfo;
+  auto i32Type = parser.getBuilder().getIntegerType(32);
+  StringRef readOrWrite, cacheType;
+
+  AffineMapAttr mapAttr;
+  SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
+  if (parser.parseOperand(memrefInfo) ||
+      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
+          KrnlPrefetchOp::getMapAttrStrName(), result.attributes) ||
+      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
+      parser.parseComma() || parser.parseKeyword("locality") ||
+      parser.parseLess() ||
+      parser.parseAttribute(hintInfo, i32Type,
+          KrnlPrefetchOp::getLocalityHintAttrStrName(), result.attributes) ||
+      parser.parseGreater() || parser.parseComma() ||
+      parser.parseKeyword(&cacheType) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(type) ||
+      parser.resolveOperand(memrefInfo, type, result.operands) ||
+      parser.resolveOperands(mapOperands, indexTy, result.operands))
+    return failure();
+
+  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
+    return parser.emitError(
+        parser.getNameLoc(), "rw specifier has to be 'read' or 'write'");
+  result.addAttribute(KrnlPrefetchOp::getIsWriteAttrStrName(),
+      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
+
+  if (!cacheType.equals("data") && !cacheType.equals("instr"))
+    return parser.emitError(
+        parser.getNameLoc(), "cache type has to be 'data' or 'instr'");
+
+  result.addAttribute(KrnlPrefetchOp::getIsDataCacheAttrStrName(),
+      parser.getBuilder().getBoolAttr(cacheType.equals("data")));
+
+  return success();
+}
+
+void KrnlPrefetchOp::print(OpAsmPrinter &p) {
+  p << " " << getMemref() << '[';
+  AffineMapAttr mapAttr =
+      (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
+  if (mapAttr)
+    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
+  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
+    << "locality<" << getLocalityHint() << ">, "
+    << (getIsDataCache() ? "data" : "instr");
+  p.printOptionalAttrDict((*this)->getAttrs(),
+      /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
+          getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
+  p << " : " << getMemRefType();
+}
+
 //===----------------------------------------------------------------------===//
 // KrnlMemcpyOp
 //===----------------------------------------------------------------------===//
diff --git a/src/Dialect/Mlir/DialectBuilder.hpp b/src/Dialect/Mlir/DialectBuilder.hpp
index ad3f6e72cb..41449d40fd 100644
--- a/src/Dialect/Mlir/DialectBuilder.hpp
+++ b/src/Dialect/Mlir/DialectBuilder.hpp
@@ -524,11 +524,9 @@ struct GenericAffineBuilder final : DialectBuilder {
   void storeIE(mlir::Value val, mlir::Value memref,
       llvm::ArrayRef<IndexExpr> indices, mlir::ValueRange offsets) const;
 
-  void prefetch(mlir::Value memref, mlir::AffineMap map,
-      llvm::ArrayRef<mlir::Value> operands, bool isWrite,
-      unsigned localityHint, bool isDataCache = true) const;
-  void prefetchIE(mlir::Value memref, llvm::ArrayRef<IndexExpr> indices,
-      bool isWrite, unsigned localityHint, bool isDataCache = true) const;
+  mlir::Operation *prefetch(mlir::Value memref, mlir::AffineMap map,
+      mlir::ValueRange indices, bool isWrite, unsigned localityHint,
+      bool isDataCache = true);
 
   void forIE(IndexExpr lb, IndexExpr ub, int64_t step,
       mlir::function_ref<void(GenericAffineBuilder &, mlir::Value)> builderFn)
diff --git a/src/Dialect/Mlir/DialectBuilder.hpp.inc b/src/Dialect/Mlir/DialectBuilder.hpp.inc
index 48fffe5143..bf424ae920 100644
--- a/src/Dialect/Mlir/DialectBuilder.hpp.inc
+++ b/src/Dialect/Mlir/DialectBuilder.hpp.inc
@@ -64,6 +64,15 @@ inline void GenericAffineBuilder<LOAD_OP, STORE_OP>::storeIE(mlir::Value val,
   store(val, memref, computedIndices);
 }
 
+template <class LOAD_OP, class STORE_OP>
+inline mlir::Operation *GenericAffineBuilder<LOAD_OP, STORE_OP>::prefetch(
+    mlir::Value memref, mlir::AffineMap map, mlir::ValueRange indices,
+    bool isWrite, unsigned localityHint, bool isDataCache) {
+  llvm::SmallVector<mlir::Value> indexArray(indices);
+  return b().template create<mlir::affine::AffinePrefetchOp>(
+      loc(), memref, map, indexArray, isWrite, localityHint, isDataCache);
+}
+
 template <class LOAD_OP, class STORE_OP>
 inline void GenericAffineBuilder<LOAD_OP, STORE_OP>::forIE(IndexExpr lb,
     IndexExpr ub, int64_t step,
@@ -85,24 +94,6 @@ inline void GenericAffineBuilder<LOAD_OP, STORE_OP>::forIE(IndexExpr lb,
   });
 }
 
-template <class LOAD_OP, class STORE_OP>
-void GenericAffineBuilder<LOAD_OP, STORE_OP>::prefetch(mlir::Value memref,
-    mlir::AffineMap map, llvm::ArrayRef<mlir::Value> operands, bool isWrite,
-    unsigned localityHint, bool isDataCache) const {
-  b().template create<mlir::affine::AffinePrefetchOp>(
-      loc(), memref, map, operands, isWrite, localityHint, isDataCache);
-}
-
-template <class LOAD_OP, class STORE_OP>
-inline void GenericAffineBuilder<LOAD_OP, STORE_OP>::prefetchIE(
-    mlir::Value memref, llvm::ArrayRef<IndexExpr> indices, bool isWrite,
-    unsigned localityHint, bool isDataCache) const {
-  llvm::SmallVector<mlir::Value, 4> operands;
-  mlir::AffineMap map;
-  IndexExpr::getAffineMapAndOperands(indices, map, operands);
-  prefetch(memref, map, operands, isWrite, localityHint, isDataCache);
-}
-
 template <class LOAD_OP, class STORE_OP>
 inline void GenericAffineBuilder<LOAD_OP, STORE_OP>::forIE(
     llvm::SmallVectorImpl<IndexExpr> &lbs,
diff --git a/src/Dialect/Mlir/IndexExpr.cpp b/src/Dialect/Mlir/IndexExpr.cpp
index 62929bb569..fe08d84486 100644
--- a/src/Dialect/Mlir/IndexExpr.cpp
+++ b/src/Dialect/Mlir/IndexExpr.cpp
@@ -454,55 +454,10 @@ IndexExprKind IndexExpr::getKind() const { return getObj().getKind(); }
 
 void IndexExpr::debugPrint(const std::string &msg) const {
   LLVM_DEBUG({
-    llvm::dbgs() << msg.c_str();
-    if (!isDefined()) {
-      llvm::dbgs() << " undefined\n";
-      return;
-    }
-    if (isLiteral()) {
-      if (isFloat())
-        llvm::dbgs() << " floatLiteral(" << getFloatLiteral() << ")";
-      else
-        llvm::dbgs() << " literal(" << (long long)getLiteral() << ")";
-    }
-    if (isFloat())
-      llvm::dbgs() << " isFloat";
-    if (hasAffineExpr())
-      llvm::dbgs() << " hasAffine";
-    if (hasValue()) {
-      llvm::dbgs() << " hasValue";
-      auto op = getValue().getDefiningOp();
-      if (op) {
-        std::string str;
-        llvm::raw_string_ostream os(str);
-        op->print(os);
-        llvm::dbgs() << "( \"" << str.c_str() << "\")";
-      } else
-        llvm::dbgs() << "(op not found)";
-    }
-    if (isAffine())
-      llvm::dbgs() << " is affine";
-    switch (getKind()) {
-    case IndexExprKind::NonAffine:
-      llvm::dbgs() << " kind(non-affine)";
-      break;
-    case IndexExprKind::Questionmark:
-      llvm::dbgs() << " kind(questionmark)";
-      break;
-    case IndexExprKind::Predicate:
-      llvm::dbgs() << " kind(predicate)";
-      break;
-    case IndexExprKind::Affine:
-      llvm::dbgs() << " kind(affine)";
-      break;
-    case IndexExprKind::Dim:
-      llvm::dbgs() << " kind(dim)";
-      break;
-    case IndexExprKind::Symbol:
-      llvm::dbgs() << " kind(symbol)";
-      break;
-    }
-    llvm::dbgs() << " scope(0x " << (long long unsigned)getScopePtr() << ")\n";
+    if (!indexExprObj)
+      llvm::dbgs() << msg.c_str() << " undefined\n";
+    else
+      indexExprObj->debugPrint(msg);
   });
 }
 
diff --git a/src/Dialect/Mlir/IndexExpr.hpp b/src/Dialect/Mlir/IndexExpr.hpp
index 048419d168..dba5aeef5a 100644
--- a/src/Dialect/Mlir/IndexExpr.hpp
+++ b/src/Dialect/Mlir/IndexExpr.hpp
@@ -365,7 +365,7 @@ class IndexExprScope {
   int getNumDims() const { return dims.size(); }
   int getNumSymbols() const { return symbols.size(); }
 
-  // Debug (enable using --debug-only=index_expr, for example).
+  // Debug (enable using --debug-only=index-expr, for example).
   void debugPrint(const std::string &msg) const;
 
 private:
@@ -820,6 +820,14 @@ class SymbolIndexExpr : public IndexExpr {
   SymbolIndexExpr(IndexExprImpl *otherObjPtr);
 };
 
+//===----------------------------------------------------------------------===//
+// Shortcuts for IndexExpr subclasses, to render code more readable.
+//===----------------------------------------------------------------------===//
+
+using LitIE = LiteralIndexExpr;
+using SymIE = SymbolIndexExpr;
+using DimIE = DimIndexExpr;
+
 //===----------------------------------------------------------------------===//
 // Additional operators with integer values in first position
 //===----------------------------------------------------------------------===//
@@ -857,6 +865,18 @@ void getIndexExprList(
     outputList.emplace_back(INDEX_EXPR(item));
 }
 
+inline llvm::SmallVector<IndexExpr, 4> DimListIE(mlir::ValueRange range) {
+  llvm::SmallVector<IndexExpr, 4> outputList;
+  getIndexExprList<DimIndexExpr>(range, outputList);
+  return outputList;
+}
+
+inline llvm::SmallVector<IndexExpr, 4> SymListIE(mlir::ValueRange range) {
+  llvm::SmallVector<IndexExpr, 4> outputList;
+  getIndexExprList<SymbolIndexExpr>(range, outputList);
+  return outputList;
+}
+
 // Create a list of IndexExpr of kind INDEX_EXPR from another list of IndexExpr.
 template <class INDEX_EXPR>
 void getIndexExprList(llvm::SmallVectorImpl<IndexExpr> &inputList,
@@ -866,6 +886,20 @@ void getIndexExprList(llvm::SmallVectorImpl<IndexExpr> &inputList,
     outputList.emplace_back(INDEX_EXPR(item));
 }
 
+inline llvm::SmallVector<IndexExpr, 4> DimListIE(
+    llvm::SmallVectorImpl<IndexExpr> &inputList) {
+  llvm::SmallVector<IndexExpr, 4> outputList;
+  getIndexExprList<DimIndexExpr>(inputList, outputList);
+  return outputList;
+}
+
+inline llvm::SmallVector<IndexExpr, 4> SymListIE(
+    llvm::SmallVectorImpl<IndexExpr> &inputList) {
+  llvm::SmallVector<IndexExpr, 4> outputList;
+  getIndexExprList<SymbolIndexExpr>(inputList, outputList);
+  return outputList;
+}
+
 // Create a list of IndexExpr of kind LiteralIndexExpr from a list of integers.
 void getIndexExprListFromInt(mlir::ArrayRef<int64_t> inputList,
     llvm::SmallVectorImpl<IndexExpr> &outputList);
diff --git a/src/Dialect/Mlir/IndexExprBuilder.cpp b/src/Dialect/Mlir/IndexExprBuilder.cpp
index 4fbe8b3fe0..16491953a8 100644
--- a/src/Dialect/Mlir/IndexExprBuilder.cpp
+++ b/src/Dialect/Mlir/IndexExprBuilder.cpp
@@ -434,7 +434,8 @@ IndexExpr IndexExprBuilder::isTileFull(
   }
   // True if i <= (UB - block), namely UB - block - i >= 0.
   // Affine expressions compared to >= 0.
-  IndexExpr res = UB - block - i;
+  IndexExpr res = UB - block;
+  res = res - i;
   return res;
 }
 
diff --git a/src/Dialect/Mlir/IndexExprDetail.cpp b/src/Dialect/Mlir/IndexExprDetail.cpp
index 556e84e011..aa65f6771a 100644
--- a/src/Dialect/Mlir/IndexExprDetail.cpp
+++ b/src/Dialect/Mlir/IndexExprDetail.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
 
 #include
 
@@ -375,7 +376,9 @@ AffineExpr IndexExprImpl::getAffineExpr() {
     return affineExpr;
   }
 
-  assert(isInCurrentScope() &&
+  // Literals never have to be in scope, so bypass the in-scope test when that
+  // is the case.
+  assert((isLiteral() || isInCurrentScope()) &&
       "create an affine expression only for index exprs in current scope");
 
   if (isLiteral()) {
@@ -525,4 +528,58 @@ void IndexExprImpl::setLiteral(const IndexExprImpl &obj) {
   setLiteral(obj.getLiteral());
 }
 
+void IndexExprImpl::debugPrint(const std::string &msg) {
+  LLVM_DEBUG({
+    llvm::dbgs() << msg.c_str();
+    if (!isDefined()) {
+      llvm::dbgs() << " undefined\n";
+      return;
+    }
+    if (isLiteral()) {
+      if (isFloatType())
+        llvm::dbgs() << " floatLiteral(" << getFloatLiteral() << ")";
+      else
+        llvm::dbgs() << " literal(" << (long long)getLiteral() << ")";
+    }
+    if (isFloatType())
+      llvm::dbgs() << " isFloat";
+    if (hasAffineExpr())
+      llvm::dbgs() << " hasAffine";
+    if (hasValue()) {
+      llvm::dbgs() << " hasValue";
+      auto op = getValue().getDefiningOp();
+      if (op) {
+        std::string str;
+        llvm::raw_string_ostream os(str);
+        op->print(os);
+        llvm::dbgs() << "( \"" << str.c_str() << "\")";
+      } else
+        llvm::dbgs() << "(op not found)";
+    }
+    if (isAffine())
+      llvm::dbgs() << " is affine";
+    switch (getKind()) {
+    case IndexExprKind::NonAffine:
+      llvm::dbgs() << " kind(non-affine)";
+      break;
+    case IndexExprKind::Questionmark:
+      llvm::dbgs() << " kind(questionmark)";
+      break;
+    case IndexExprKind::Predicate:
+      llvm::dbgs() << " kind(predicate)";
+      break;
+    case IndexExprKind::Affine:
+      llvm::dbgs() << " kind(affine)";
+      break;
+    case IndexExprKind::Dim:
+      llvm::dbgs() << " kind(dim)";
+      break;
+    case IndexExprKind::Symbol:
+      llvm::dbgs() << " kind(symbol)";
+      break;
+    }
+    llvm::dbgs() << " scope(0x " << (long long unsigned)getScopePtr() << ")\n";
+  });
+}
+
 } // namespace onnx_mlir
diff --git a/src/Dialect/Mlir/IndexExprDetail.hpp b/src/Dialect/Mlir/IndexExprDetail.hpp
index 389a5ee345..9a0ba29d9a 100644
--- a/src/Dialect/Mlir/IndexExprDetail.hpp
+++ b/src/Dialect/Mlir/IndexExprDetail.hpp
@@ -88,6 +88,8 @@ class IndexExprImpl {
   void setLiteral(double val);
   void setLiteral(const IndexExprImpl &obj);
 
+  void debugPrint(const std::string &msg);
+
 private:
   // Init for internal use only.
 void init(bool isDefined, bool isIntLit, bool isFloatLit, IndexExprKind type,
diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.td b/src/Dialect/ONNX/ONNXOps/Canonicalize.td
index e665f19c35..9e50b07e01 100644
--- a/src/Dialect/ONNX/ONNXOps/Canonicalize.td
+++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.td
@@ -144,6 +144,10 @@ def HaveSameStaticShape: Constraint<
   CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
   "Two tensors have the same static shape">;
 
+def IsIdentityReshape: Constraint<
+  CPred<"onnx_mlir::isIdentityReshape($0, $1)">,
+  "Reshape is an identity operation">;
+
 // Create a unit constant that will be used as none input.
 def CreateNoneValue : NativeCodeCall<"$_builder.create<ONNXNoneOp>($_loc).getResult()">;
 
@@ -662,7 +666,7 @@ def RemoveIdentityReshapePattern2: Pat<
   // Remove the reshape.
   (replaceWithValue $val),
   // Check that val and out have the same static shape.
-  [(HaveSameStaticShape $out, $val)]>;
+  [(IsIdentityReshape $out, $val)]>;
 
def GetReturnTypeForMatMulOpND2D: NativeCodeCall<
  "onnx_mlir::getReturnTypeForMatMulOpND2D($0, $1)"
diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
index 8f4292b39c..056cac495a 100644
--- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
+++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
@@ -731,6 +731,54 @@ bool hasIntegerPowerExponent(ONNXPowOp *op, int64_t &exponentValue) {
   return false;
 }
 
+//===----------------------------------------------------------------------===//
+// Support for ReshapeOp.
+//===----------------------------------------------------------------------===//
+
+// Return true if the reshape does nothing, i.e., it returns the same as its
+// input. Use dimAnalysis if provided.
+
+bool isIdentityReshape(
+    Value inputTensor, Value outputTensor, const DimAnalysis *dimAnalysis) {
+  if (!hasShapeAndRank(inputTensor) || !hasShapeAndRank(outputTensor))
+    return false;
+  // Check if same rank.
+  Type inputType = inputTensor.getType();
+  Type outputType = outputTensor.getType();
+  int64_t inputRank = getRank(inputType);
+  int64_t outputRank = getRank(outputType);
+  if (inputRank != outputRank)
+    return false;
+
+  // The reshape is an identity if at least (N-1) out of N dimensions are
+  // equal. We don't need to care about the one remaining dimension: when all
+  // others match, it must match too (reshape preserves the element count),
+  // even if DimAnalysis failed to prove it.
+  int nSameDims = 0;
+  ArrayRef<int64_t> inputShape = getShape(inputType);
+  ArrayRef<int64_t> outputShape = getShape(outputType);
+  for (int64_t i = 0; i < inputRank; ++i) {
+    if (inputShape[i] != ShapedType::kDynamic &&
+        inputShape[i] == outputShape[i])
+      nSameDims++;
+    else if (dimAnalysis &&
+             dimAnalysis->sameDim(inputTensor, i, outputTensor, i))
+      nSameDims++;
+  }
+  // It is basically OK to miss one, as it must then be equal.
+  if (nSameDims >= inputRank - 1)
+    return true;
+
+  return false;
+}
+
+bool isIdentityReshape(
+    ONNXReshapeOp reshapeOp, const DimAnalysis *dimAnalysis) {
+  // Check if ranked and shaped.
+  Value inputTensor = reshapeOp.getData();
+  Value outputTensor = reshapeOp.getReshaped();
+  return isIdentityReshape(inputTensor, outputTensor, dimAnalysis);
+}
+
 //===----------------------------------------------------------------------===//
 // Support for location.
 //===----------------------------------------------------------------------===//
diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
index 84722c2dbb..2c2d089162 100644
--- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
+++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
@@ -286,6 +286,19 @@ bool areDimsFromConcat(mlir::Value val);
 /// Get all dimensions that are stored by the value.
 void getDims(mlir::Value val, llvm::SmallVectorImpl<mlir::Value> &dims);
 
+//===----------------------------------------------------------------------===//
+// Support for ReshapeOp.
+//===----------------------------------------------------------------------===//
+
+// Return true if the reshape does nothing, i.e., it returns the same as its
+// input. Use dimAnalysis if provided.
+
+bool isIdentityReshape(
+    mlir::ONNXReshapeOp reshapeOp, const DimAnalysis *dimAnalysis = nullptr);
+
+bool isIdentityReshape(mlir::Value input, mlir::Value output,
+    const DimAnalysis *dimAnalysis = nullptr);
+
 //===----------------------------------------------------------------------===//
 // Support for location.
 //===----------------------------------------------------------------------===//
diff --git a/test/mlir/conversion/krnl_to_affine/krnl_to_affine_with_canonicalize.mlir b/test/mlir/conversion/krnl_to_affine/krnl_to_affine_with_canonicalize.mlir
index 4db502807c..33346613fe 100644
--- a/test/mlir/conversion/krnl_to_affine/krnl_to_affine_with_canonicalize.mlir
+++ b/test/mlir/conversion/krnl_to_affine/krnl_to_affine_with_canonicalize.mlir
@@ -102,3 +102,37 @@ func.func @krnl_get_linear_offset_index_2(%arg0: memref, %arg1:
 // CHECK:           return [[VAR_0_]] : index
 // CHECK:         }
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0 + 64)>
+func.func @prefetch(%arg0: memref<256x512xf32>) -> memref<256x512xf32> attributes {input_names = ["x"], output_names = ["output"]} {
+  %alloc = memref.alloc() {alignment = 16 : i64} : memref<256x512xf32>
+  %0:2 = krnl.define_loops 2
+  krnl.iterate(%0#0, %0#1) with (%0#0 -> %arg1 = 0 to 256, %0#1 -> %arg2 = 0 to 512){
+    %1:2 = krnl.get_induction_var_value(%0#0, %0#1) : (!krnl.loop, !krnl.loop) -> (index, index)
+    %2 = krnl.load %arg0[%1#0, %1#1] : memref<256x512xf32>
+    %3 = affine.apply #map(%1#1)
+    krnl.prefetch %arg0[%1#0, %3], read, locality<3>, data : memref<256x512xf32>
+    %4 = krnl.load %arg0[%1#0, %1#1] : memref<256x512xf32>
+    %5 = arith.addf %2, %4 : f32
+    krnl.store %5, %alloc[%1#0, %1#1] : memref<256x512xf32>
+  }
+  return %alloc : memref<256x512xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL:  func.func @prefetch
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: memref<256x512xf32>) -> memref<256x512xf32> attributes {input_names = ["x"], llvm.emit_c_interface, output_names = ["output"]} {
+// CHECK:           [[RES_:%.+]] = memref.alloc() {{.*}}: memref<256x512xf32>
+// CHECK:           affine.for [[I_0_:%.+]] = 0 to 256 {
+// CHECK:             affine.for [[I_1_:%.+]] = 0 to 512 {
+// CHECK:               [[LOAD_PARAM_0_MEM_:%.+]] = affine.load [[PARAM_0_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<256x512xf32>
+// CHECK:               affine.prefetch [[PARAM_0_]]{{.}}[[I_0_]], [[I_1_]] + 64], read, locality<3>, data : memref<256x512xf32>
+// CHECK:               [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.load [[PARAM_0_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<256x512xf32>
+// CHECK:               [[VAR_2_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_0_MEM_1_]] : f32
+// CHECK:               affine.store [[VAR_2_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<256x512xf32>
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return [[RES_]] : memref<256x512xf32>
+// CHECK:         }
+}
diff --git a/utils/make-report.py b/utils/make-report.py
index b9fe8fad93..3e4b369db4 100755
--- a/utils/make-report.py
+++ b/utils/make-report.py
@@ -524,14 +524,14 @@ def make_report(stat_message):
         print("")
     stat_details = ""
     if supported_only:
-        stat_details = ", supported ops"
+        stat_details = " supported ops"
    else:
-        stat_details = ", all ops"
+        stat_details = " all ops"
     if min_percent_reporting > 0:
         stat_details += ", " + str(min_percent_reporting) + "%+ exec time"
-    stat_details += ", ordered_by " + sorting_preference
+    stat_details += " ordered_by " + sorting_preference
     if has_timing:
-        stat_details += ", tot_time {:.7f}".format(tot_time * time_unit)
+        stat_details += ", tot_time, {:.7f}".format(tot_time * time_unit)
     print("Statistics start" + stat_details)
     for key in sorted(sorted_output):
         print(sorted_output[key])
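(Note on the make-report.py change: with timing enabled, the summary header now reads along the lines of "Statistics start all ops ordered_by time, tot_time, 0.1234567", with the values and ordering key illustrative here. Commas are kept only around the tot_time field, while the op-coverage and ordering descriptors are space-separated.)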