From c585f0365be105698032edf4bc93a9d146ab3533 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Thu, 11 Apr 2024 22:29:17 +0900 Subject: [PATCH] [Cherry-pick] Fixing the location of DimAnalysis in onnx-to-zhigh pass and some rules in zhigh-to-onnx pass The onnx-to-zhigh pass has two phases: 1) converting multiple onnx ops into a single zhigh op, and 2) converting a single onnx op to a single zhigh op, where the second phase uses DimAnalysis (patterns in the 1st phase at this moment do not use DimAnalysis). The problem is that DimAnalysis is currently called before the 1st phase, which is not good because the 1st phase may change the IR, so the information from DimAnalysis becomes obsolete for the 2nd phase. The correct position for DimAnalysis is just before the 2nd phase. Other than that, this PR slightly changes the rules in the zhigh-to-onnx pass so that, for binary ops, having only one input (instead of two) that comes from a stick op is enough to trigger the rule to convert a zhigh op back to an onnx op. Resolves #2789 --------- Signed-off-by: Tung D. 
Le (cherry picked from commit 80a63f2c1f497f50c1f3ebfb3372dfc3ef50ff2d) Signed-off-by: Charles Volzka --- .../Conversion/ONNXToZHigh/CMakeLists.txt | 1 + .../ONNXToZHigh/DevicePlacement.cpp | 2 +- .../Conversion/ONNXToZHigh/ONNXToZHigh.cpp | 13 ++- .../Conversion/ONNXToZHigh/ZHighToONNX.cpp | 2 + .../Conversion/ONNXToZHigh/ZHighToONNX.td | 95 ++++++++++++++----- .../instrument/add-onnx-zhigh-level.mlir | 6 +- .../nnpa/conversion/onnx-to-zhigh/conv.mlir | 4 +- .../nnpa/conversion/onnx-to-zhigh/matmul.mlir | 8 +- .../conversion/onnx-to-zhigh/softmax.mlir | 20 ++-- 9 files changed, 103 insertions(+), 48 deletions(-) diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt index 2e50d0797b..d6f0af7e6a 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt @@ -14,6 +14,7 @@ add_onnx_mlir_library(OMONNXToZHigh OMNNPACompilerOptions OMONNXOps OMONNXToKrnl + OMShapeInferencePass OMZHighOps ACCEL_INCLUDE_DIRS PRIVATE diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp index 82be5e277b..bbb39102c7 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp @@ -196,7 +196,7 @@ void DevicePlacementPass::runOnOperation() { // Call ONNXToZHigh pass for lowering multiple ONNX ops at once to ZHigh. // E.g. `onnx.ReLu (onnx.Conv)` to zhigh.Conv. 
RewritePatternSet Patterns2(context); - getONNXToZHighOneOpPatterns(Patterns2); + getONNXToZHighMultipleOpPatterns(Patterns2); (void)applyAnalysisConversion(module, target, std::move(Patterns2), ConversionConfig{.legalizableOps = &legalizedOps2}); diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp index 6560388bce..59d0e997e8 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp @@ -22,6 +22,7 @@ #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include "src/Dialect/ONNX/Transforms/ShapeInference.hpp" using namespace mlir; @@ -328,16 +329,13 @@ void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) { patterns.insert(context); patterns.insert(context); patterns.insert(context); + // Shape inference for newly-added operations. + getShapeInferencePatterns(patterns); } void ONNXToZHighLoweringPass::runOnOperation() { ModuleOp module = getOperation(); - // Run the unknown dimension analysis to help check equality of unknown - // dimensions at compile time. - onnx_mlir::DimAnalysis dimAnalysis(module); - dimAnalysis.analyze(); - // The first thing to define is the conversion target. This will define the // final target for this lowering. ConversionTarget target(getContext()); @@ -363,6 +361,11 @@ void ONNXToZHighLoweringPass::runOnOperation() { // It's ok to fail. (void)applyPatternsAndFoldGreedily(module, std::move(combinedPatterns)); + // Run the unknown dimension analysis to help check equality of unknown + // dimensions at compile time. + onnx_mlir::DimAnalysis dimAnalysis(module); + dimAnalysis.analyze(); + // Single ONNX to ZHigh operation lowering. 
RewritePatternSet patterns(&getContext()); onnx_mlir::getONNXToZHighOneOpPatterns(patterns); diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp index 1eb25c409c..60e11ca41d 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp @@ -59,6 +59,8 @@ void ZHighToONNXLoweringPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); + zhigh::ZHighStickOp::getCanonicalizationPatterns(patterns, &getContext()); + zhigh::ZHighUnstickOp::getCanonicalizationPatterns(patterns, &getContext()); (void)applyPatternsAndFoldGreedily(function, std::move(patterns)); } diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td index 02a06cabe8..3012c2160c 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td @@ -37,31 +37,52 @@ def CreateONNXMaxOp : NativeCodeCall<"$_builder.create($_loc, $0.getT // ONNXAddOp %X = ZHighUnstickOp (ZHighAddOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighAddPattern : Pat< - (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))), - (ONNXAddOp $x, $y), - [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)] +def replaceZHighAddPattern1 : Pat< + (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), $y)), + (ONNXAddOp $x, (ZHighUnstickOp $y)), + [(NotBlockArgument:$x), (HasOneUse:$s_x)] >; +def replaceZHighAddPattern2 : Pat< + (ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_))), + (ONNXAddOp (ZHighUnstickOp $x), $y), + [(NotBlockArgument:$y), (HasOneUse:$s_y)] +>; //===----------------------------------------------------------------------===// // 
ONNXMulOp %X = ZHighUnstickOp (ZHighMulOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighMulPattern : Pat< - (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))), - (ONNXMulOp $x, $y), - [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)] +def replaceZHighMulPattern1 : Pat< + (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), $y)), + (ONNXMulOp $x, (ZHighUnstickOp $y)), + [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], + (addBenefit 1) +>; + +def replaceZHighMulPattern2 : Pat< + (ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_))), + (ONNXMulOp (ZHighUnstickOp $x), $y), + [(NotBlockArgument:$y), (HasOneUse:$s_y)], [], + (addBenefit 0) >; //===----------------------------------------------------------------------===// // ONNXSubOp %X = ZHighUnstickOp (ZHighSubOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighSubPattern : Pat< - (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))), - (ONNXSubOp $x, $y), - [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)] +def replaceZHighSubPattern1 : Pat< + (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), $y)), + (ONNXSubOp $x, (ZHighUnstickOp $y)), + [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], + (addBenefit 1) +>; + +def replaceZHighSubPattern2 : Pat< + (ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_))), + (ONNXSubOp (ZHighUnstickOp $x), $y), + [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], + (addBenefit 0) >; //===----------------------------------------------------------------------===// @@ -69,30 +90,54 @@ def replaceZHighSubPattern : Pat< // %X),(ZHighStickOp %Y)) // Note: turn off this pattern since NNPA is faster at this moment. 
//===----------------------------------------------------------------------===// -// def replaceZHighDivPattern : Pat< -// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))), -// (ONNXDivOp $x, $y), -// [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)] -// >; +//def replaceZHighDivPattern1 : Pat< +// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), $y)), +// (ONNXDivOp $x, (ZHighUnstickOp $y)), +// [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], +// (addBenefit 1) +//>; +// +//def replaceZHighDivPattern2 : Pat< +// (ZHighUnstickOp (ZHighDivOp $x, (ZHighStickOp:$s_y $y, $_))), +// (ONNXDivOp (ZHighUnstickOp $x), $y), +// [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], +// (addBenefit 0) +//>; //===----------------------------------------------------------------------===// // ONNXMinOp %X = ZHighUnstickOp (ZHighMinOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighMinPattern : Pat< - (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))), - (CreateONNXMinOp $u, $x, $y), - [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)] +def replaceZHighMinPattern1 : Pat< + (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), $y)), + (CreateONNXMinOp $u, $x, (ZHighUnstickOp $y)), + [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], + (addBenefit 1) +>; + +def replaceZHighMinPattern2 : Pat< + (ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_))), + (CreateONNXMinOp $u, (ZHighUnstickOp $x), $y), + [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], + (addBenefit 0) >; //===----------------------------------------------------------------------===// // ONNXMaxOp %X = ZHighUnstickOp (ZHighMaxOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighMaxPattern : Pat< - (ZHighUnstickOp:$u (ZHighMaxOp 
(ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))), - (CreateONNXMaxOp $u, $x, $y), - [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)] +def replaceZHighMaxPattern1 : Pat< + (ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), $y)), + (CreateONNXMaxOp $u, $x, (ZHighUnstickOp $y)), + [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], + (addBenefit 1) +>; + +def replaceZHighMaxPattern2 : Pat< + (ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_))), + (CreateONNXMaxOp $u, (ZHighUnstickOp $x), $y), + [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], + (addBenefit 0) >; //===----------------------------------------------------------------------===// diff --git a/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir b/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir index f85e8021e7..059c3bcfb8 100644 --- a/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir +++ b/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir @@ -5,7 +5,8 @@ func.func @test_instrument_add_onnx_zhigh(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x1xf32>) -> tensor<*xf32> { %0 = "onnx.Add"(%arg0, %arg1) {onnx_node_name = "onnx.Add1"} : (tensor<10x10xf32>, tensor<10x1xf32>) -> tensor<*xf32> %1 = "onnx.Add"(%arg0, %0) {onnx_node_name = "onnx.Add2"} : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32> - "onnx.Return"(%1) : (tensor<*xf32>) -> () + %2 = "onnx.Relu"(%1) {onnx_node_name = "onnx.Relu"} : (tensor<*xf32>) -> tensor<*xf32> + "onnx.Return"(%2) : (tensor<*xf32>) -> () } // CHECK-LABEL: func.func @test_instrument_add_onnx_zhigh @@ -21,6 +22,9 @@ func.func @test_instrument_add_onnx_zhigh(%arg0 : tensor<10x10xf32>, %arg1 : ten // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 5 : i64} : () -> () // CHECK: "zhigh.Add" // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 6 : i64} : () -> () +// CHECK: "krnl.runtime_instrument"() {opName = 
"zhigh.Relu", tag = 5 : i64} : () -> () +// CHECK: "zhigh.Relu" +// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Relu", tag = 6 : i64} : () -> () // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 5 : i64} : () -> () // CHECK: "zhigh.Unstick" // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 6 : i64} : () -> () diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir index 562193bbc4..c7857a7588 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir @@ -107,8 +107,8 @@ func.func @test_fuse_onnx_relu_conv2d(%arg0: tensor<5x3x32x32xf32>, %arg1 : tens // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[VAR_2_]]) {layout = "HWCK"} : (tensor<2x2x3x2xf32>) -> tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>> // CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<2xf32>) -> tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_5_:%.+]] = "zhigh.Conv2D"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) {act_func = "ACT_RELU", kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1]} : (tensor<5x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> -// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<*xf16>) -> tensor<5x2x31x31xf32> +// CHECK: [[VAR_5_:%.+]] = "zhigh.Conv2D"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) {act_func = "ACT_RELU", kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1]} : (tensor<5x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<5x31x31x2xf16, #zhigh.layout<{dataLayout = 
"NHWC"}>> +// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<5x31x31x2xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<5x2x31x31xf32> // CHECK: return [[VAR_6_]] : tensor<5x2x31x31xf32> // CHECK: } } diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir index fe3b4d1737..857baf98f6 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir @@ -79,8 +79,8 @@ func.func @test_onnx_matmul_add_to_zhigh_1D_bias( // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> -// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<4x16xf32> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32> // CHECK: return [[VAR_4_]] : tensor<4x16xf32> // CHECK: } // CHECK-NOT: "onnx.Add" @@ -105,8 +105,8 @@ func.func 
@test_onnx_matmul_add_to_zhigh_1D_bias_normalized( // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> -// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<4x16xf32> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32> // CHECK: return [[VAR_4_]] : tensor<4x16xf32> // CHECK: } // CHECK-NOT: "onnx.Add" diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir index 2fdd0891fb..70fbeda772 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir @@ -39,11 +39,11 @@ func.func @test_onnx_logsoftmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { // CHECK-LABEL: func @test_onnx_logsoftmax // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { -// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : 
(tensor<10x10xf32>) -> tensor<*xf32> -// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<*xf32>) -> tensor<*xf16> -// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<*xf16>) -> tensor<*xf16> -// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<*xf32> -// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<*xf32>) -> tensor<10x10xf32> +// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<10x10xf32>) -> tensor<1x10x10xf32> +// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<1x10x10xf32>) -> tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x10x10xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<1x10x10xf32>) -> tensor<10x10xf32> // CHECK: return [[VAR_4_]] : tensor<10x10xf32> // CHECK: } } @@ -57,11 +57,11 @@ func.func @test_onnx_logsoftmax_dyn(%arg0 : tensor) -> tensor<*xf32> { // CHECK-LABEL: func @test_onnx_logsoftmax_dyn // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { -// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor) -> tensor<*xf32> -// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<*xf32>) -> tensor<*xf16> -// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<*xf16>) -> tensor<*xf16> -// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<*xf32> -// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<*xf32>) -> tensor +// CHECK: [[VAR_0_:%.+]] = 
"onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor) -> tensor<1x?x?xf32> +// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<1x?x?xf32>) -> tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x?x?xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<1x?x?xf32>) -> tensor // CHECK: return [[VAR_4_]] : tensor // CHECK: } }