diff --git a/src/Dialect/ONNX/ONNXDimAnalysis.cpp b/src/Dialect/ONNX/ONNXDimAnalysis.cpp index 17a939ea10..3f950d8848 100644 --- a/src/Dialect/ONNX/ONNXDimAnalysis.cpp +++ b/src/Dialect/ONNX/ONNXDimAnalysis.cpp @@ -504,16 +504,104 @@ DimAnalysis::DimAnalysis(ArrayRef vals) { DimAnalysis::DimAnalysis(ModuleOp moduleOp) { moduleOp.walk([&](Operation *op) { if (auto funcOp = dyn_cast(op)) { - for (Value arg : funcOp.getArguments()) - build(arg); + // Build dimensions for function arguments and results. + buildFunctionArgsRes(funcOp); + } else { + // Build dimensions for normal operation results. + for (Value output : op->getResults()) + if (!isNoneValue(output)) + build(output); } - for (Value output : op->getResults()) - build(output); }); LLVM_DEBUG(llvm::dbgs() << "The number of dynamic dims in the IR: " << numOfDynamicDims << "\n"); } +int64_t DimAnalysis::build(DimT d, int64_t setID) { + if (setID >= 0) { + if (dimSetMap.contains(setID)) { + dimSetMap[setID].insert(d); + LLVM_DEBUG(llvm::dbgs() + << "Build a new dim(" << d.first << ", " << d.second + << ") and insert it into the existing set " << setID << "\n"); + } + } else { + setID = setCounter; + DimSetT dimSet; + dimSet.insert(d); + dimSetMap[setID] = dimSet; + setCounter++; + LLVM_DEBUG(llvm::dbgs() + << "Build a new dim(" << d.first << ", " << d.second + << ") and insert it into a new set " << setID << "\n"); + } + if (setID >= 0) + numOfDynamicDims++; + return setID; +} + +void DimAnalysis::buildFunctionArgsRes(func::FuncOp funcOp) { + // If dim_params are available, try to group dims using dim_params because + // dimensions with the same dim_param are supposed to be the same at runtime. + + // Keep dynamic dimensions with the same dim_param. 
+ std::map paramSetMap; + + auto buildFor = [&paramSetMap, this](ValueRange args, ArrayAttr argAttrs) { + for (size_t argPos = 0; argPos < args.size(); ++argPos) { + Value arg = args[argPos]; + auto tensorType = arg.getType().dyn_cast(); + if (!tensorType) + continue; + // Get dim_params if they exist. + std::map indexParamMap; + getONNXDimParams(indexParamMap, argAttrs, argPos); + // Check and build each dynamic dimension. + for (int64_t dimPos = 0; dimPos < tensorType.getRank(); ++dimPos) { + if (!tensorType.isDynamicDim(dimPos)) + continue; + DimT ti(arg, dimPos); + if (auto dp = indexParamMap.find(dimPos); dp != indexParamMap.end()) { + // This arg has dim_param, build it later with other args of the + // same dim_param + if (paramSetMap.find(dp->second) == paramSetMap.end()) { + DimSetT dimSet; + dimSet.insert(ti); + paramSetMap[dp->second] = dimSet; + } else { + paramSetMap[dp->second].insert(ti); + } + } else { + // This arg does not have dim_param, build it now. + build(ti); + } + } + } + }; + + // Build internal mappings for arguments. + ArrayRef args = funcOp.getArguments(); + ArrayAttr argAttrs = funcOp.getArgAttrsAttr(); + buildFor(args, argAttrs); + + // Build internal mappings for results. + Operation *terminator = funcOp.getRegion().back().getTerminator(); + ValueRange resVals; + if (auto returnOp = dyn_cast(terminator)) + resVals = returnOp.getOperands(); + else if (auto returnOp = dyn_cast(terminator)) + resVals = returnOp.getOperands(); + ArrayAttr resAttrs = funcOp.getResAttrsAttr(); + buildFor(resVals, resAttrs); + + // Build dynamic dimensions using dim_param. 
+ for (const auto &[param, dimSet] : paramSetMap) { + int64_t setID = -1; + for (DimT d : dimSet) + setID = build(d, setID); + } +} + void DimAnalysis::build(Value val) { if (auto tensorType = val.getType().dyn_cast()) { for (unsigned i = 0; i < tensorType.getRank(); ++i) { @@ -524,8 +612,10 @@ void DimAnalysis::build(Value val) { dimSet.insert(ti); dimSetMap[setCounter++] = dimSet; numOfDynamicDims++; - LLVM_DEBUG(llvm::dbgs() << "Build a new dim(" << ti.first << ", " - << ti.second << ")\n"); + LLVM_DEBUG(llvm::dbgs() + << "Build a new dim(" << ti.first << ", " << ti.second + << ") and insert it into a new set " << (setCounter - 1) + << "\n"); } } } @@ -638,6 +728,31 @@ void DimAnalysis::dump() const { } } +void DimAnalysis::getONNXDimParams( + std::map &indexParamMap, ArrayAttr argResAttr, + unsigned index) { + if (!argResAttr) + return; + if (index >= argResAttr.size()) + return; + DictionaryAttr dictAttr = llvm::dyn_cast(argResAttr[index]); + if (dictAttr && dictAttr.contains("onnx.dim_params")) { + // onnx.dim_params = dimIndex:dimParam,dimIndex:dimParam,... + StringRef dimParams = dictAttr.getNamed("onnx.dim_params") + .value() + .getValue() + .cast() + .getValue(); + SmallVector splittedDimParams; + dimParams.split(splittedDimParams, ','); + for (size_t k = 0; k < splittedDimParams.size(); ++k) { + StringRef s = splittedDimParams[k]; + std::pair indexParam = s.split(':'); + indexParamMap[stoi(indexParam.first.str())] = indexParam.second.str(); + } + } +} + void DimAnalysis::analyze() { // Build sets of the same dynamic dimensions and merge them until a fixed // point where there is no update on each set. 
diff --git a/src/Dialect/ONNX/ONNXDimAnalysis.hpp b/src/Dialect/ONNX/ONNXDimAnalysis.hpp index 0eb2a29932..ec94eaa6df 100644 --- a/src/Dialect/ONNX/ONNXDimAnalysis.hpp +++ b/src/Dialect/ONNX/ONNXDimAnalysis.hpp @@ -14,6 +14,7 @@ #pragma once +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Value.h" @@ -90,6 +91,16 @@ class DimAnalysis { /// Each dynamic dimension is initially assigned to a singleton set. void build(mlir::Value val); + /// Initializes the internal mappings for a single dynamic dimension. + /// The dynamic dimension is initially assigned to a newly-created set or an + /// existing set depending on whether `setID` is -1 or not. + /// This method returns the set ID that contains the dimension. + int64_t build(DimT d, int64_t setID = -1); + + /// Initializes the internal mappings for function arguments and results. + void buildFunctionArgsRes(mlir::func::FuncOp funcOp); + + // Create dims for function arguments. /// Update each set of dynamic dimensions to include the same dynamic /// dimensions. This is a local update in the sense that the search space /// includes dynamic dimensions that directly link to the dimensions in the @@ -103,6 +114,12 @@ class DimAnalysis { /// Visit a dynamic dimension and find new same dynamic dimensions. void visitDim(DimT &dim, DimSetT &sameDims) const; + /// Get onnx.dim_params value from a function argument/result and put it into + /// a map. + /// TODO: find a new home for this function. 
+ void getONNXDimParams(std::map &indexParamMap, + mlir::ArrayAttr argResAttr, unsigned index); + private: int64_t setCounter = 0; int64_t numOfDynamicDims = 0; diff --git a/test/mlir/onnx/onnx_dim_analysis.mlir b/test/mlir/onnx/onnx_dim_analysis.mlir index 858bd104f5..6ec01096ed 100644 --- a/test/mlir/onnx/onnx_dim_analysis.mlir +++ b/test/mlir/onnx/onnx_dim_analysis.mlir @@ -1,5 +1,43 @@ // RUN: onnx-mlir-opt --onnx-dim-analysis %s -split-input-file | FileCheck %s +// Check if dim_analysis takes into account the relationship between inputs via dim_params. +func.func @test_dim_params_onnx_return(%arg0: tensor {onnx.dim_params = "0:M,1:N", onnx.name = "X"}, %arg1: tensor {onnx.dim_params = "0:M,1:P", onnx.name = "Y"}) -> (tensor {onnx.dim_params = "0:M,1:N", onnx.name = "Z"}) { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + onnx.Return %0: tensor + +// CHECK-LABEL: func.func @test_dim_params_onnx_return +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor {onnx.dim_params = "0:M,1:N", onnx.name = "X"}, [[PARAM_1_:%.+]]: tensor {onnx.dim_params = "0:M,1:P", onnx.name = "Y"}) -> (tensor {onnx.dim_params = "0:M,1:N", onnx.name = "Z"}) { +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 4 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor, tensor) -> tensor +// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: onnx.Return [[VAR_0_]] : tensor +// CHECK: } +} + +// ----- + +// Check if dim_analysis takes into account the relationship between 
inputs via dim_params. +func.func @test_dim_params_std_return(%arg0: tensor {onnx.dim_params = "0:M,1:N", onnx.name = "X"}, %arg1: tensor {onnx.dim_params = "0:M,1:P", onnx.name = "Y"}) -> (tensor {onnx.dim_params = "0:M,1:N", onnx.name = "Z"}) { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0: tensor + +// CHECK-LABEL: func.func @test_dim_params_std_return +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor {onnx.dim_params = "0:M,1:N", onnx.name = "X"}, [[PARAM_1_:%.+]]: tensor {onnx.dim_params = "0:M,1:P", onnx.name = "Y"}) -> (tensor {onnx.dim_params = "0:M,1:N", onnx.name = "Z"}) { +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 4 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor, tensor) -> tensor +// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: return [[VAR_0_]] : tensor +// CHECK: } +} + // ----- // This test is an excerpt of BertSquad-12 model in the model zoo. 
@@ -112,15 +150,15 @@ func.func @test_binary_elementwise(%arg0 : tensor) -> tensor) -> tensor { -// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 2 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 2 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: [[VAR_0_:%.+]] = "onnx.Sigmoid"([[PARAM_0_]]) : (tensor) -> tensor -// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 2 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 2 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_0_]]) : (tensor, tensor) -> tensor -// CHECK-DAG: "onnx.DimGroup"([[VAR_1_]]) {axis = 2 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_1_]]) {axis = 2 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: onnx.Return [[VAR_1_]] : tensor // CHECK: } @@ -288,12 +326,12 @@ func.func @test_center_crop_pad_1(%arg0: tensor, %arg1: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { // CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () -// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_1_]]) {axis = 0 : si64} : (tensor) -> tensor<1xi64> // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_1_]]) {axis = 1 : si64} : (tensor) -> tensor<1xi64> // CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> // CHECK: 
[[VAR_3_:%.+]] = "onnx.CenterCropPad"([[PARAM_0_]], [[VAR_2_]]) {axes = [0, -2]} : (tensor, tensor<2xi64>) -> tensor -// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: return [[VAR_3_]] : tensor // CHECK: } @@ -310,13 +348,13 @@ func.func @test_center_crop_pad_2(%arg0: tensor, %arg1: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { -// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_1_]]) {axis = 0 : si64} : (tensor) -> tensor<1xi64> // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_1_]]) {axis = 1 : si64} : (tensor) -> tensor<1xi64> // CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> // CHECK: [[VAR_3_:%.+]] = "onnx.CenterCropPad"([[PARAM_0_]], [[VAR_2_]]) {axes = [-3, 2]} : (tensor, tensor<2xi64>) -> tensor -// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 2 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 2 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK: "onnx.DimGroup"([[VAR_3_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: return [[VAR_3_]] : tensor // CHECK: } @@ -336,7 +374,7 @@ func.func @test_max_unpool(%arg0: tensor<1x1x2x2xf32>, %arg1: tensor<1x1x2x2xi64 // CHECK-LABEL: func.func @test_max_unpool // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x2x2xf32>, [[PARAM_1_:%.+]]: tensor<1x1x2x2xi64>, [[PARAM_2_:%.+]]: tensor) -> tensor { // CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis 
= 0 : si64, group_id = 0 : si64} : (tensor) -> () -// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 1 : si64, group_id = 9 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 3 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_2_]]) {axis = 0 : si64} : (tensor) -> tensor<1xi64> @@ -346,7 +384,7 @@ func.func @test_max_unpool(%arg0: tensor<1x1x2x2xf32>, %arg1: tensor<1x1x2x2xi64 // CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> // CHECK: [[VAR_5_:%.+]] = "onnx.MaxUnpool"([[PARAM_0_]], [[PARAM_1_]], [[VAR_4_]]) {kernel_shape = [2, 2], strides = [2, 2]} : (tensor<1x1x2x2xf32>, tensor<1x1x2x2xi64>, tensor<4xi64>) -> tensor // CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 3 : si64, group_id = 7 : si64} : (tensor) -> () -// CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 1 : si64, group_id = 9 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: return [[VAR_5_]] : tensor @@ -432,13 +470,13 @@ func.func @test_concat_input_dims(%arg0: tensor, %arg1: tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 7 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 1 : si64, group_id = 10 : si64} : (tensor) -> () -// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () -// CHECK-DAG: 
"onnx.DimGroup"([[PARAM_1_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () -// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 2 : si64, group_id = 14 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 2 : si64, group_id = 14 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[PARAM_2_]]) {axis = 2 : si64, group_id = 14 : si64} : (tensor) -> () // CHECK: [[VAR_0_:%.+]] = "onnx.Concat"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor // CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () -// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () -// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 2 : si64, group_id = 2 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 13 : si64} : (tensor) -> () +// CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 2 : si64, group_id = 14 : si64} : (tensor) -> () // CHECK: return [[VAR_0_]] : tensor // CHECK: } }