diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
index 2e50d0797b..d6f0af7e6a 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
@@ -14,6 +14,7 @@ add_onnx_mlir_library(OMONNXToZHigh
   OMNNPACompilerOptions
   OMONNXOps
   OMONNXToKrnl
+  OMShapeInferencePass
   OMZHighOps

   ACCEL_INCLUDE_DIRS PRIVATE
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp
index 82be5e277b..bbb39102c7 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp
@@ -196,7 +196,7 @@ void DevicePlacementPass::runOnOperation() {
   // Call ONNXToZHigh pass for lowering multiple ONNX ops at once to ZHigh.
   // E.g. `onnx.ReLu (onnx.Conv)` to zhigh.Conv.
   RewritePatternSet Patterns2(context);
-  getONNXToZHighOneOpPatterns(Patterns2);
+  getONNXToZHighMultipleOpPatterns(Patterns2);
   (void)applyAnalysisConversion(module, target, std::move(Patterns2),
       ConversionConfig{.legalizableOps = &legalizedOps2});

diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
index 6560388bce..59d0e997e8 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
@@ -22,6 +22,7 @@
 #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
+#include "src/Dialect/ONNX/Transforms/ShapeInference.hpp"

 using namespace mlir;

@@ -328,16 +329,13 @@ void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) {
   patterns.insert(context);
   patterns.insert(context);
   patterns.insert(context);
+  // Shape inference for newly-added operations.
+  getShapeInferencePatterns(patterns);
 }

 void ONNXToZHighLoweringPass::runOnOperation() {
   ModuleOp module = getOperation();

-  // Run the unknown dimension analysis to help check equality of unknown
-  // dimensions at compile time.
-  onnx_mlir::DimAnalysis dimAnalysis(module);
-  dimAnalysis.analyze();
-
   // The first thing to define is the conversion target. This will define the
   // final target for this lowering.
   ConversionTarget target(getContext());
@@ -363,6 +361,11 @@ void ONNXToZHighLoweringPass::runOnOperation() {
   // It's ok to fail.
   (void)applyPatternsAndFoldGreedily(module, std::move(combinedPatterns));

+  // Run the unknown dimension analysis to help check equality of unknown
+  // dimensions at compile time.
+  onnx_mlir::DimAnalysis dimAnalysis(module);
+  dimAnalysis.analyze();
+
   // Single ONNX to ZHigh operation lowering.
   RewritePatternSet patterns(&getContext());
   onnx_mlir::getONNXToZHighOneOpPatterns(patterns);
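Note on the ONNXToZHigh.cpp changes above: the shape-inference patterns now run inside the same greedy rewrite that fuses multiple ONNX ops into one ZHigh op, so fused ops get ranked result types immediately instead of staying tensor<*xf16>, and DimAnalysis (moved to after this rewrite) then analyzes the refined types. A minimal hand-written sketch of the intended effect on the MatMul+Add fusion exercised by the tests below (illustrative IR, not output produced by this patch):

  // Before shape inference: the fused op is created with an unranked result.
  %0 = "zhigh.MatMul"(%x, %y, %b) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16>

  // After the bundled shape-inference patterns run in the same rewrite:
  %0 = "zhigh.MatMul"(%x, %y, %b) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>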
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp
index 1eb25c409c..60e11ca41d 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp
@@ -59,6 +59,8 @@ void ZHighToONNXLoweringPass::runOnOperation() {

   RewritePatternSet patterns(&getContext());
   populateWithGenerated(patterns);
+  zhigh::ZHighStickOp::getCanonicalizationPatterns(patterns, &getContext());
+  zhigh::ZHighUnstickOp::getCanonicalizationPatterns(patterns, &getContext());
   (void)applyPatternsAndFoldGreedily(function, std::move(patterns));
 }

diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td
index 02a06cabe8..3012c2160c 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td
@@ -37,31 +37,52 @@ def CreateONNXMaxOp : NativeCodeCall<"$_builder.create<ONNXMaxOp>($_loc, $0.getT
 //===----------------------------------------------------------------------===//
 // ONNXAddOp %X = ZHighUnstickOp (ZHighAddOp (ZHighStickOp %X),
 //                                (ZHighStickOp %Y))
 //===----------------------------------------------------------------------===//
-def replaceZHighAddPattern : Pat<
-  (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (ONNXAddOp $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighAddPattern1 : Pat<
+  (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (ONNXAddOp $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)]
 >;
+def replaceZHighAddPattern2 : Pat<
+  (ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (ONNXAddOp (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)]
+>;

 //===----------------------------------------------------------------------===//
 // ONNXMulOp %X = ZHighUnstickOp (ZHighMulOp (ZHighStickOp %X),
 //                                (ZHighStickOp %Y))
 //===----------------------------------------------------------------------===//
-def replaceZHighMulPattern : Pat<
-  (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (ONNXMulOp $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighMulPattern1 : Pat<
+  (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (ONNXMulOp $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
+  (addBenefit 1)
+>;
+
+def replaceZHighMulPattern2 : Pat<
+  (ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (ONNXMulOp (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
+  (addBenefit 0)
 >;

 //===----------------------------------------------------------------------===//
 // ONNXSubOp %X = ZHighUnstickOp (ZHighSubOp (ZHighStickOp %X),
 //                                (ZHighStickOp %Y))
 //===----------------------------------------------------------------------===//
-def replaceZHighSubPattern : Pat<
-  (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (ONNXSubOp $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighSubPattern1 : Pat<
+  (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (ONNXSubOp $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
+  (addBenefit 1)
+>;
+
+def replaceZHighSubPattern2 : Pat<
+  (ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (ONNXSubOp (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
+  (addBenefit 0)
 >;

 //===----------------------------------------------------------------------===//
@@ -69,30 +90,54 @@ def replaceZHighSubPattern : Pat<
 // %X),(ZHighStickOp %Y))
 // Note: turn off this pattern since NNPA is faster at this moment.
 //===----------------------------------------------------------------------===//
-// def replaceZHighDivPattern : Pat<
-//   (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-//   (ONNXDivOp $x, $y),
-//   [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
-// >;
+//def replaceZHighDivPattern1 : Pat<
+//  (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), $y)),
+//  (ONNXDivOp $x, (ZHighUnstickOp $y)),
+//  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
+//  (addBenefit 1)
+//>;
+//
+//def replaceZHighDivPattern2 : Pat<
+//  (ZHighUnstickOp (ZHighDivOp $x, (ZHighStickOp:$s_y $y, $_))),
+//  (ONNXDivOp (ZHighUnstickOp $x), $y),
+//  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
+//  (addBenefit 0)
+//>;

 //===----------------------------------------------------------------------===//
 // ONNXMinOp %X = ZHighUnstickOp (ZHighMinOp (ZHighStickOp %X),
 //                                (ZHighStickOp %Y))
 //===----------------------------------------------------------------------===//
-def replaceZHighMinPattern : Pat<
-  (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (CreateONNXMinOp $u, $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighMinPattern1 : Pat<
+  (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (CreateONNXMinOp $u, $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
+  (addBenefit 1)
+>;
+
+def replaceZHighMinPattern2 : Pat<
+  (ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (CreateONNXMinOp $u, (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
+  (addBenefit 0)
 >;

 //===----------------------------------------------------------------------===//
 // ONNXMaxOp %X = ZHighUnstickOp (ZHighMaxOp (ZHighStickOp %X),
 //                                (ZHighStickOp %Y))
 //===----------------------------------------------------------------------===//
-def replaceZHighMaxPattern : Pat<
-  (ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (CreateONNXMaxOp $u, $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighMaxPattern1 : Pat<
+  (ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (CreateONNXMaxOp $u, $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
+  (addBenefit 1)
+>;
+
+def replaceZHighMaxPattern2 : Pat<
+  (ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (CreateONNXMaxOp $u, (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
+  (addBenefit 0)
 >;

 //===----------------------------------------------------------------------===//
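Note on the ZHighToONNX changes above: each pattern now requires only one operand to come from a zhigh.Stick and wraps the other operand in a new zhigh.Unstick, so the rewrite also fires when just one side is stickified. A hand-written sketch of replaceZHighAddPattern1 (illustrative shapes, not output produced by this patch):

  // Input: only the first operand of zhigh.Add comes from a zhigh.Stick.
  %s = "zhigh.Stick"(%x) {layout = "2D"} : (tensor<4x16xf32>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
  %a = "zhigh.Add"(%s, %y) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
  %r = "zhigh.Unstick"(%a) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32>

  // After replaceZHighAddPattern1: compute on CPU, unstickify the other operand.
  %u = "zhigh.Unstick"(%y) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32>
  %r = "onnx.Add"(%x, %u) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>

When both operands are stickified, the addBenefit 1 pattern is tried first and the zhigh.Unstick it inserts lands directly on a zhigh.Stick; the Stick/Unstick canonicalization patterns newly registered in ZHighToONNX.cpp are presumably there to fold that Unstick(Stick %y) pair away, recovering the old two-Stick behavior without a dedicated pattern.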
diff --git a/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir b/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir
index f85e8021e7..059c3bcfb8 100644
--- a/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir
+++ b/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir
@@ -5,7 +5,8 @@
 func.func @test_instrument_add_onnx_zhigh(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x1xf32>) -> tensor<*xf32> {
   %0 = "onnx.Add"(%arg0, %arg1) {onnx_node_name = "onnx.Add1"} : (tensor<10x10xf32>, tensor<10x1xf32>) -> tensor<*xf32>
   %1 = "onnx.Add"(%arg0, %0) {onnx_node_name = "onnx.Add2"} : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
-  "onnx.Return"(%1) : (tensor<*xf32>) -> ()
+  %2 = "onnx.Relu"(%1) {onnx_node_name = "onnx.Relu"} : (tensor<*xf32>) -> tensor<*xf32>
+  "onnx.Return"(%2) : (tensor<*xf32>) -> ()
 }

@@ -21,6 +22,9 @@ func.func @test_instrument_add_onnx_zhigh(%arg0 : tensor<10x10xf32>, %arg1 : ten
 // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 5 : i64} : () -> ()
 // CHECK: "zhigh.Add"
 // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 6 : i64} : () -> ()
+// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Relu", tag = 5 : i64} : () -> ()
+// CHECK: "zhigh.Relu"
+// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Relu", tag = 6 : i64} : () -> ()
 // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 5 : i64} : () -> ()
 // CHECK: "zhigh.Unstick"
 // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 6 : i64} : () -> ()
"onnx.Add1"} : (tensor<10x10xf32>, tensor<10x1xf32>) -> tensor<*xf32> %1 = "onnx.Add"(%arg0, %0) {onnx_node_name = "onnx.Add2"} : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32> - "onnx.Return"(%1) : (tensor<*xf32>) -> () + %2 = "onnx.Relu"(%1) {onnx_node_name = "onnx.Relu"} : (tensor<*xf32>) -> tensor<*xf32> + "onnx.Return"(%2) : (tensor<*xf32>) -> () } // CHECK-LABEL: func.func @test_instrument_add_onnx_zhigh @@ -21,6 +22,9 @@ func.func @test_instrument_add_onnx_zhigh(%arg0 : tensor<10x10xf32>, %arg1 : ten // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 5 : i64} : () -> () // CHECK: "zhigh.Add" // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 6 : i64} : () -> () +// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Relu", tag = 5 : i64} : () -> () +// CHECK: "zhigh.Relu" +// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Relu", tag = 6 : i64} : () -> () // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 5 : i64} : () -> () // CHECK: "zhigh.Unstick" // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 6 : i64} : () -> () diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir index 562193bbc4..c7857a7588 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir @@ -107,8 +107,8 @@ func.func @test_fuse_onnx_relu_conv2d(%arg0: tensor<5x3x32x32xf32>, %arg1 : tens // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[VAR_2_]]) {layout = "HWCK"} : (tensor<2x2x3x2xf32>) -> tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>> // CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<2xf32>) -> tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_5_:%.+]] = "zhigh.Conv2D"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) {act_func = "ACT_RELU", kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1]} : (tensor<5x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> -// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<*xf16>) -> tensor<5x2x31x31xf32> +// CHECK: [[VAR_5_:%.+]] = "zhigh.Conv2D"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) {act_func = "ACT_RELU", kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1]} : (tensor<5x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<5x31x31x2xf16, #zhigh.layout<{dataLayout = "NHWC"}>> +// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<5x31x31x2xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<5x2x31x31xf32> // CHECK: return [[VAR_6_]] : tensor<5x2x31x31xf32> // CHECK: } } diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir index fe3b4d1737..857baf98f6 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir @@ -79,8 +79,8 @@ func.func @test_onnx_matmul_add_to_zhigh_1D_bias( // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_1_:%.+]] = 
"zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> -// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<4x16xf32> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32> // CHECK: return [[VAR_4_]] : tensor<4x16xf32> // CHECK: } // CHECK-NOT: "onnx.Add" @@ -105,8 +105,8 @@ func.func @test_onnx_matmul_add_to_zhigh_1D_bias_normalized( // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> -// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<4x16xf32> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32> // CHECK: return [[VAR_4_]] : tensor<4x16xf32> // CHECK: } // CHECK-NOT: "onnx.Add" diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir index 2fdd0891fb..70fbeda772 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir @@ -39,11 +39,11 @@ func.func @test_onnx_logsoftmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { // CHECK-LABEL: func @test_onnx_logsoftmax // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { -// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<10x10xf32>) -> tensor<*xf32> -// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<*xf32>) -> tensor<*xf16> -// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<*xf16>) -> tensor<*xf16> -// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<*xf32> -// CHECK: [[VAR_4_:%.+]] = 
"onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<*xf32>) -> tensor<10x10xf32> +// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<10x10xf32>) -> tensor<1x10x10xf32> +// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<1x10x10xf32>) -> tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x10x10xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<1x10x10xf32>) -> tensor<10x10xf32> // CHECK: return [[VAR_4_]] : tensor<10x10xf32> // CHECK: } } @@ -57,11 +57,11 @@ func.func @test_onnx_logsoftmax_dyn(%arg0 : tensor) -> tensor<*xf32> { // CHECK-LABEL: func @test_onnx_logsoftmax_dyn // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { -// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor) -> tensor<*xf32> -// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<*xf32>) -> tensor<*xf16> -// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<*xf16>) -> tensor<*xf16> -// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<*xf32> -// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<*xf32>) -> tensor +// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor) -> tensor<1x?x?xf32> +// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<1x?x?xf32>) -> tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x?x?xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<1x?x?xf32>) -> tensor // CHECK: return [[VAR_4_]] : tensor // CHECK: } }