Use dim_params in dynamic dimension analysis #2620

Merged · 6 commits · Nov 15, 2023
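This change teaches DimAnalysis to use the onnx.dim_params attribute attached to function arguments and results when seeding the analysis: dynamic dimensions annotated with the same dim_param (for example, two inputs carrying "0:M") are placed into a single set up front, instead of each starting in its own singleton set, so their equality is known even when it cannot be derived from the ops alone. Existing tests change only in group_id numbering, because argument and result dimensions now claim set IDs in a different order.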
127 changes: 121 additions & 6 deletions src/Dialect/ONNX/ONNXDimAnalysis.cpp
@@ -504,16 +504,104 @@ DimAnalysis::DimAnalysis(ArrayRef<Value> vals) {
DimAnalysis::DimAnalysis(ModuleOp moduleOp) {
moduleOp.walk([&](Operation *op) {
if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
for (Value arg : funcOp.getArguments())
build(arg);
// Build dimensions for function arguments and results.
buildFunctionArgsRes(funcOp);
} else {
// Build dimensions for normal operation results.
for (Value output : op->getResults())
if (!isNoneValue(output))
build(output);
}
for (Value output : op->getResults())
build(output);
});
LLVM_DEBUG(llvm::dbgs() << "The number of dynamic dims in the IR: "
<< numOfDynamicDims << "\n");
}

int64_t DimAnalysis::build(DimT d, int64_t setID) {
if (setID >= 0) {
if (dimSetMap.contains(setID)) {
dimSetMap[setID].insert(d);
LLVM_DEBUG(llvm::dbgs()
<< "Build a new dim(" << d.first << ", " << d.second
<< ") and insert it into the existing set " << setID << "\n");
}
} else {
setID = setCounter;
DimSetT dimSet;
dimSet.insert(d);
dimSetMap[setID] = dimSet;
setCounter++;
LLVM_DEBUG(llvm::dbgs()
<< "Build a new dim(" << d.first << ", " << d.second
<< ") and insert it into a new set " << setID << "\n");
}
if (setID >= 0)
numOfDynamicDims++;
return setID;
}
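Usage note: the new two-argument build is meant to be threaded through a loop, as done at the end of buildFunctionArgsRes below. A minimal sketch of the call pattern:

  // Group all dimensions that share one dim_param into a single set:
  // the first call (setID == -1) creates the set, later calls reuse its ID.
  int64_t setID = -1;
  for (DimT d : dimSet)
    setID = build(d, setID);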

void DimAnalysis::buildFunctionArgsRes(func::FuncOp funcOp) {
// If dim_params are available, try to group dims using dim_params because
// dimensions with the same dim_param are supposed to be the same at runtime.

// Keep dynamic dimensions with the same dim_param.
std::map<std::string, DimSetT> paramSetMap;

auto buildFor = [&paramSetMap, this](ValueRange args, ArrayAttr argAttrs) {
for (size_t argPos = 0; argPos < args.size(); ++argPos) {
Value arg = args[argPos];
auto tensorType = arg.getType().dyn_cast<RankedTensorType>();
if (!tensorType)
continue;
// Get dim_params if they exist.
std::map<unsigned, std::string> indexParamMap;
getONNXDimParams(indexParamMap, argAttrs, argPos);
// Check and build each dynamic dimension.
for (int64_t dimPos = 0; dimPos < tensorType.getRank(); ++dimPos) {
if (!tensorType.isDynamicDim(dimPos))
continue;
DimT ti(arg, dimPos);
if (auto dp = indexParamMap.find(dimPos); dp != indexParamMap.end()) {
// This arg has a dim_param; build it later with other args of the
// same dim_param.
if (paramSetMap.find(dp->second) == paramSetMap.end()) {
DimSetT dimSet;
dimSet.insert(ti);
paramSetMap[dp->second] = dimSet;
} else {
paramSetMap[dp->second].insert(ti);
}
} else {
// This arg does not have dim_param, build it now.
build(ti);
}
}
}
};

// Build internal mappings for arguments.
ArrayRef<BlockArgument> args = funcOp.getArguments();
ArrayAttr argAttrs = funcOp.getArgAttrsAttr();
buildFor(args, argAttrs);

// Build internal mappings for results.
Operation *terminator = funcOp.getRegion().back().getTerminator();
ValueRange resVals;
if (auto returnOp = dyn_cast<func::ReturnOp>(terminator))
resVals = returnOp.getOperands();
else if (auto returnOp = dyn_cast<ONNXReturnOp>(terminator))
resVals = returnOp.getOperands();
ArrayAttr resAttrs = funcOp.getResAttrsAttr();
buildFor(resVals, resAttrs);

// Build dynamic dimensions using dim_param.
for (const auto &[param, dimSet] : paramSetMap) {
int64_t setID = -1;
for (DimT d : dimSet)
setID = build(d, setID);
}
}

void DimAnalysis::build(Value val) {
if (auto tensorType = val.getType().dyn_cast<RankedTensorType>()) {
for (unsigned i = 0; i < tensorType.getRank(); ++i) {
@@ -524,8 +612,10 @@ void DimAnalysis::build(Value val) {
dimSet.insert(ti);
dimSetMap[setCounter++] = dimSet;
numOfDynamicDims++;
LLVM_DEBUG(llvm::dbgs() << "Build a new dim(" << ti.first << ", "
<< ti.second << ")\n");
LLVM_DEBUG(llvm::dbgs()
<< "Build a new dim(" << ti.first << ", " << ti.second
<< ") and insert it into a new set " << (setCounter - 1)
<< "\n");
}
}
}
@@ -638,6 +728,31 @@ void DimAnalysis::dump() const {
}
}

void DimAnalysis::getONNXDimParams(
std::map<unsigned, std::string> &indexParamMap, ArrayAttr argResAttr,
unsigned index) {
if (!argResAttr)
return;
if (index >= argResAttr.size())
return;
DictionaryAttr dictAttr = llvm::dyn_cast<DictionaryAttr>(argResAttr[index]);
if (dictAttr && dictAttr.contains("onnx.dim_params")) {
// onnx.dim_params = dimIndex:dimParam,dimIndex:dimParam,...
StringRef dimParams = dictAttr.getNamed("onnx.dim_params")
.value()
.getValue()
.cast<StringAttr>()
.getValue();
SmallVector<StringRef, 4> splittedDimParams;
dimParams.split(splittedDimParams, ',');
for (size_t k = 0; k < splittedDimParams.size(); ++k) {
StringRef s = splittedDimParams[k];
std::pair<StringRef, StringRef> indexParam = s.split(':');
indexParamMap[stoi(indexParam.first.str())] = indexParam.second.str();
}
}
}
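For reference, here is a self-contained sketch of the dimIndex:dimParam parsing done by getONNXDimParams above, written in plain C++17 without the MLIR attribute API (the function name and sample strings are illustrative only, not part of the patch):

  #include <iostream>
  #include <map>
  #include <sstream>
  #include <string>

  // Parse "0:batch,2:seq_len" into {0 -> "batch", 2 -> "seq_len"}.
  std::map<unsigned, std::string> parseDimParams(const std::string &attr) {
    std::map<unsigned, std::string> indexParamMap;
    std::stringstream ss(attr);
    std::string entry;
    while (std::getline(ss, entry, ',')) {
      auto colon = entry.find(':');
      if (colon == std::string::npos)
        continue; // skip malformed entries; the real code assumes valid input
      indexParamMap[std::stoul(entry.substr(0, colon))] = entry.substr(colon + 1);
    }
    return indexParamMap;
  }

  int main() {
    for (const auto &[index, param] : parseDimParams("0:batch,2:seq_len"))
      std::cout << index << " -> " << param << "\n"; // 0 -> batch, 2 -> seq_len
  }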

void DimAnalysis::analyze() {
// Build sets of the same dynamic dimensions and merge them until a fixed
// point where there is no update on each set.
17 changes: 17 additions & 0 deletions src/Dialect/ONNX/ONNXDimAnalysis.hpp
@@ -14,6 +14,7 @@

#pragma once

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Value.h"

@@ -90,6 +91,16 @@ class DimAnalysis {
/// Each dynamic dimension is initially assigned to a singleton set.
void build(mlir::Value val);

/// Initializes the internal mappings for a single dynamic dimension.
/// The dynamic dimension is initially assigned to a newly-created set or an
/// existing set, depending on whether `setID` is -1 or not.
/// This method returns the set ID that contains the dimension.
int64_t build(DimT d, int64_t setID = -1);

/// Initializes the internal mappings for function arguments and results.
void buildFunctionArgsRes(mlir::func::FuncOp funcOp);

/// Update each set of dynamic dimensions to include the same dynamic
/// dimensions. This is a local update in the sense that the search space
/// includes dynamic dimensions that directly link to the dimensions in the
@@ -103,6 +114,12 @@
/// Visit a dynamic dimension and find new same dynamic dimensions.
void visitDim(DimT &dim, DimSetT &sameDims) const;

/// Get onnx.dim_params value from a function argument/result and put it into
/// a map.
/// TODO: find a new home for this function.
void getONNXDimParams(std::map<unsigned, std::string> &indexParamMap,
mlir::ArrayAttr argResAttr, unsigned index);

private:
int64_t setCounter = 0;
int64_t numOfDynamicDims = 0;
66 changes: 52 additions & 14 deletions test/mlir/onnx/onnx_dim_analysis.mlir
@@ -1,5 +1,43 @@
// RUN: onnx-mlir-opt --onnx-dim-analysis %s -split-input-file | FileCheck %s

// Check if dim_analysis takes into account the relationship between inputs via dim_params.
func.func @test_dim_params_onnx_return(%arg0: tensor<?x?xf32> {onnx.dim_params = "0:M,1:N", onnx.name = "X"}, %arg1: tensor<?x?xf32> {onnx.dim_params = "0:M,1:P", onnx.name = "Y"}) -> (tensor<?x?xf32> {onnx.dim_params = "0:M,1:N", onnx.name = "Z"}) {
%0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
onnx.Return %0: tensor<?x?xf32>

// CHECK-LABEL: func.func @test_dim_params_onnx_return
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?xf32> {onnx.dim_params = "0:M,1:N", onnx.name = "X"}, [[PARAM_1_:%.+]]: tensor<?x?xf32> {onnx.dim_params = "0:M,1:P", onnx.name = "Y"}) -> (tensor<?x?xf32> {onnx.dim_params = "0:M,1:N", onnx.name = "Z"}) {
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 4 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?xf32>) -> ()
// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?xf32>) -> ()
// CHECK: onnx.Return [[VAR_0_]] : tensor<?x?xf32>
// CHECK: }
}

// -----

// Check if dim_analysis takes into account the relationship between inputs via dim_params.
func.func @test_dim_params_std_return(%arg0: tensor<?x?xf32> {onnx.dim_params = "0:M,1:N", onnx.name = "X"}, %arg1: tensor<?x?xf32> {onnx.dim_params = "0:M,1:P", onnx.name = "Y"}) -> (tensor<?x?xf32> {onnx.dim_params = "0:M,1:N", onnx.name = "Z"}) {
%0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0: tensor<?x?xf32>

// CHECK-LABEL: func.func @test_dim_params_std_return
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?xf32> {onnx.dim_params = "0:M,1:N", onnx.name = "X"}, [[PARAM_1_:%.+]]: tensor<?x?xf32> {onnx.dim_params = "0:M,1:P", onnx.name = "Y"}) -> (tensor<?x?xf32> {onnx.dim_params = "0:M,1:N", onnx.name = "Z"}) {
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 4 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?xf32>) -> ()
// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?xf32>) -> ()
// CHECK: return [[VAR_0_]] : tensor<?x?xf32>
// CHECK: }
}

// -----

// This test is an excerpt of BertSquad-12 model in the model zoo.
Expand Down Expand Up @@ -112,15 +150,15 @@ func.func @test_binary_elementwise(%arg0 : tensor<?x3x?xf32>) -> tensor<?x3x?xf3

// CHECK-LABEL: func.func @test_binary_elementwise
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x3x?xf32>) -> tensor<?x3x?xf32> {
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 2 : si64, group_id = 1 : si64} : (tensor<?x3x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 2 : si64, group_id = 7 : si64} : (tensor<?x3x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x3x?xf32>) -> ()

// CHECK: [[VAR_0_:%.+]] = "onnx.Sigmoid"([[PARAM_0_]]) : (tensor<?x3x?xf32>) -> tensor<?x3x?xf32>
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 2 : si64, group_id = 1 : si64} : (tensor<?x3x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 2 : si64, group_id = 7 : si64} : (tensor<?x3x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x3x?xf32>) -> ()

// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_0_]]) : (tensor<?x3x?xf32>, tensor<?x3x?xf32>) -> tensor<?x3x?xf32>
// CHECK-DAG: "onnx.DimGroup"([[VAR_1_]]) {axis = 2 : si64, group_id = 1 : si64} : (tensor<?x3x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_1_]]) {axis = 2 : si64, group_id = 7 : si64} : (tensor<?x3x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x3x?xf32>) -> ()
// CHECK: onnx.Return [[VAR_1_]] : tensor<?x3x?xf32>
// CHECK: }
@@ -288,12 +326,12 @@ func.func @test_center_crop_pad_1(%arg0: tensor<?x?x8xf32>, %arg1: tensor<?x?xf3
// CHECK-LABEL: func.func @test_center_crop_pad_1
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x8xf32>, [[PARAM_1_:%.+]]: tensor<?x?xf32>) -> tensor<?x?x8xf32> {
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_1_]]) {axis = 0 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_1_]]) {axis = 1 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
// CHECK: [[VAR_3_:%.+]] = "onnx.CenterCropPad"([[PARAM_0_]], [[VAR_2_]]) {axes = [0, -2]} : (tensor<?x?x8xf32>, tensor<2xi64>) -> tensor<?x?x8xf32>
// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?x8xf32>) -> ()
// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor<?x?x8xf32>) -> ()
// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?x8xf32>) -> ()
// CHECK: return [[VAR_3_]] : tensor<?x?x8xf32>
// CHECK: }
@@ -310,13 +348,13 @@ func.func @test_center_crop_pad_2(%arg0: tensor<?x8x?xf32>, %arg1: tensor<?x?xf3

// CHECK-LABEL: func.func @test_center_crop_pad_2
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x8x?xf32>, [[PARAM_1_:%.+]]: tensor<?x?xf32>) -> tensor<?x8x?xf32> {
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?xf32>) -> ()
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_1_]]) {axis = 0 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_1_]]) {axis = 1 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
// CHECK: [[VAR_3_:%.+]] = "onnx.CenterCropPad"([[PARAM_0_]], [[VAR_2_]]) {axes = [-3, 2]} : (tensor<?x8x?xf32>, tensor<2xi64>) -> tensor<?x8x?xf32>
// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 2 : si64, group_id = 1 : si64} : (tensor<?x8x?xf32>) -> ()
// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 2 : si64, group_id = 7 : si64} : (tensor<?x8x?xf32>) -> ()
// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x8x?xf32>) -> ()
// CHECK: return [[VAR_3_]] : tensor<?x8x?xf32>
// CHECK: }
Expand All @@ -336,7 +374,7 @@ func.func @test_max_unpool(%arg0: tensor<1x1x2x2xf32>, %arg1: tensor<1x1x2x2xi64
// CHECK-LABEL: func.func @test_max_unpool
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x2x2xf32>, [[PARAM_1_:%.+]]: tensor<1x1x2x2xi64>, [[PARAM_2_:%.+]]: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 1 : si64, group_id = 9 : si64} : (tensor<?x?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor<?x?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 3 : si64, group_id = 7 : si64} : (tensor<?x?x?x?xf32>) -> ()
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_2_]]) {axis = 0 : si64} : (tensor<?x?x?x?xf32>) -> tensor<1xi64>
Expand All @@ -346,7 +384,7 @@ func.func @test_max_unpool(%arg0: tensor<1x1x2x2xf32>, %arg1: tensor<1x1x2x2xi64
// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
// CHECK: [[VAR_5_:%.+]] = "onnx.MaxUnpool"([[PARAM_0_]], [[PARAM_1_]], [[VAR_4_]]) {kernel_shape = [2, 2], strides = [2, 2]} : (tensor<1x1x2x2xf32>, tensor<1x1x2x2xi64>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 3 : si64, group_id = 7 : si64} : (tensor<?x?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 1 : si64, group_id = 9 : si64} : (tensor<?x?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor<?x?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?x?x?xf32>) -> ()
// CHECK: return [[VAR_5_]] : tensor<?x?x?x?xf32>
@@ -432,13 +470,13 @@ func.func @test_concat_input_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?x
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 4 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 1 : si64, group_id = 10 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 2 : si64, group_id = 14 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 2 : si64, group_id = 14 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 2 : si64, group_id = 14 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK: [[VAR_0_:%.+]] = "onnx.Concat"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 13 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 2 : si64, group_id = 14 : si64} : (tensor<?x?x?xf32>) -> ()
// CHECK: return [[VAR_0_]] : tensor<?x?x?xf32>
// CHECK: }
}