Fix onnx.GatherND and onnx.ScatterND issues with dynamic indices #2550

Merged
Changes from all commits (29 commits)
c56b620
Add test for onnx.GatherND, onnx.ScatterND and onnx.Gather ops with d…
negiyas Oct 4, 2023
5925bfe
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 4, 2023
943082e
Fix comments on the lit test
negiyas Oct 5, 2023
decbfb5
Fix a typo in lit ltest.
negiyas Oct 5, 2023
37867ee
Fix issues for onnx.GatherND and onnx.ScatterND with dynamic indices.
negiyas Oct 10, 2023
d948e11
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 10, 2023
245deea
Fix clang-format issues.
negiyas Oct 10, 2023
8592a35
Fix a compilation error.
negiyas Oct 10, 2023
327121a
Fix clang-format issues.
negiyas Oct 10, 2023
5e25925
use std::min instead of #define MIN in ScatterND.cpp
negiyas Oct 11, 2023
441981a
Add dynamic cases in backend tests for scatternd/gathernd
negiyas Oct 11, 2023
f957c49
Fixed issues with "test_gathernd_example_int32_cpu" in both static an…
negiyas Oct 13, 2023
ce76f99
Fix backend test issues for onnx.gatherND.
negiyas Oct 13, 2023
22bf4e1
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 17, 2023
259722c
Fixing lit and backend tests for onnx.GatherND.
negiyas Oct 17, 2023
51d036c
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 17, 2023
5863ccf
Fix backend test for onnx.GatherND with dynamic indices.
negiyas Oct 17, 2023
b447895
Fix bakend tests for onnx.gatherND with dynamic indices.
negiyas Oct 17, 2023
36b4c90
Fix lit tests and backend tests errors for onnx.GatherND with dynamic…
negiyas Oct 17, 2023
112c7e7
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 17, 2023
8ff9cc5
Fix python format of test/backend/inference_backend.py
negiyas Oct 17, 2023
e945fdb
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 17, 2023
21f0325
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 20, 2023
9dbfa26
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 23, 2023
f0c56b1
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 24, 2023
c54d323
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 25, 2023
42c2882
Change src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp to change how to…
negiyas Oct 25, 2023
cf837be
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 26, 2023
ffb0adf
Merge branch 'main' into issue_gather_scatter_nd_with_dynamic_indices
negiyas Oct 26, 2023
80 changes: 33 additions & 47 deletions src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp
@@ -64,68 +64,60 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern<ONNXGatherNDOp> {
Value data = adaptor.getData();
Value indices = adaptor.getIndices();
int64_t b = adaptor.getBatchDims();
auto indicesType = indices.getType().cast<ShapedType>();
DimsExpr dataDims, indicesDims;
create.krnlIE.getShapeAsDims(data, dataDims);
create.krnlIE.getShapeAsDims(indices, indicesDims);
auto dataType = data.getType().cast<ShapedType>();
ArrayRef<int64_t> indicesShape = indicesType.getShape();
ArrayRef<int64_t> dataShape = dataType.getShape();
int64_t dataRank = dataShape.size();
int64_t indicesRank = indicesShape.size();
int64_t dataRank = dataDims.size();
int64_t indicesRank = indicesDims.size();
auto indicesType = indices.getType().cast<ShapedType>();
ArrayRef<int64_t> indicesShape = indicesType.getShape();
int64_t indicesLastDim = indicesShape[indicesRank - 1];
assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) &&
"indices.shape[-1] must be in the range [1, dataRank - b]");

// Convert the output type to MemRefType.
Type convertedType = typeConverter->convertType(*op->result_type_begin());
assert(convertedType && convertedType.isa<MemRefType>() &&
"Failed to convert type to MemRefType");
MemRefType outputMemRefType = convertedType.cast<MemRefType>();
ArrayRef<int64_t> outputShape = outputMemRefType.getShape();
int64_t outputRank = outputShape.size();

// Ensure the operation constains are satisfied.
assert(dataRank >= 1 && "The rank of 'data' must be >= 1");
assert(indicesRank >= 1 && "The rank of 'indices' must be >= 1");
assert((outputRank == dataRank + indicesRank - indicesLastDim - 1 - b) &&
"Incorrect outut rank");
assert(b >= 0 && "batch_dim should not be negative");
assert(b < std::min(dataRank, indicesRank) &&
"batch_dims must be smaller than the min(dataRank, indicesRank)");
assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) &&
"indices.shape[-1] must be in the range [1, dataRank - b]");
DimsExpr outputDims = shapeHelper.getOutputDims();

// Reshape 'indices' to the 3D shape:
// [batchDimSize, indicesDimsSize, indices.shape[-1]].
const int64_t batchDimsSize = std::accumulate(indicesShape.begin(),
indicesShape.begin() + b, 1, std::multiplies<int64_t>());
const int64_t indicesDimsSize = std::accumulate(indicesShape.begin(),
indicesShape.end(), 1, std::multiplies<int64_t>());
assert(batchDimsSize >= 0 && "batchDimsSize must be non-negative");
assert(indicesDimsSize >= 0 && "indicesDimsSize must be non-negative");

LiteralIndexExpr BDS(batchDimsSize),
IDS(indicesDimsSize / (batchDimsSize * indicesLastDim)),
ILD(indicesLastDim);
LiteralIndexExpr oneIE(1);
IndexExpr batchDimsSize = oneIE;
for (int64_t i = 0; i < b; i++)
batchDimsSize = batchDimsSize * indicesDims[i];
IndexExpr indicesDimsSize = oneIE;
for (int64_t i = b; i < indicesRank - 1; i++)
indicesDimsSize = indicesDimsSize * indicesDims[i];
IndexExpr BDS(batchDimsSize), IDS(indicesDimsSize);
LiteralIndexExpr ILD(indicesLastDim);
DimsExpr newIndicesShape = {BDS, IDS, ILD};
Value reshapedIndices =
create.mem.reinterpretCast(indices, newIndicesShape);
LLVM_DEBUG(llvm::dbgs() << "reshapedIndices: " << reshapedIndices << "\n");

// Reshape 'data' to shape [batchDimSize, data.shape[b:]]
DimsExpr newDataShape = {BDS};
DimsExpr newDataDims = {BDS};
for (int64_t i = b; i < dataRank; ++i) {
assert(dataShape[i] != ShapedType::kDynamic &&
"Cannot support data with dynamic dimensions");
LiteralIndexExpr dataDim(dataShape[i]);
newDataShape.emplace_back(dataDim);
newDataDims.emplace_back(dataDims[i]);
}
int64_t reshapedDataRank = newDataShape.size();
Value reshapedData = create.mem.reinterpretCast(data, newDataShape);
int64_t reshapedDataRank = newDataDims.size();
Value reshapedData = create.mem.reinterpretCast(data, newDataDims);
LLVM_DEBUG(llvm::dbgs() << "reshapedData: " << reshapedData << "\n");

// Allocate a 1D output buffer.
const int64_t outputDimsSize = std::accumulate(
outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
Value outputDataBuffer = create.mem.alloc(
MemRefType::get({outputDimsSize}, outputMemRefType.getElementType()));

IndexExpr outputDimsSize = oneIE;
for (uint64_t i = 0; i < outputDims.size(); i++)
outputDimsSize = outputDimsSize * outputDims[i];
SmallVector<IndexExpr> outputIndexExpr = {outputDimsSize};
int64_t dim = outputDimsSize.isLiteral() ? outputDimsSize.getLiteral()
: ShapedType::kDynamic;
Type outputType = dataType.getElementType();
Value outputDataBuffer =
create.mem.alloc(MemRefType::get({dim}, outputType), outputIndexExpr);
// Initialize the index used to store the result values.
Value iZero = create.math.constantIndex(0);
Value iOne = create.math.constantIndex(1);
@@ -247,14 +239,8 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern<ONNXGatherNDOp> {
});

// Finally reshape 'outputDataBuffer' to the shape of the output.
DimsExpr newOutputShape;
for (int64_t dim : outputShape) {
LiteralIndexExpr outputDim(dim);
newOutputShape.emplace_back(outputDim);
}

Value reshapedOutput =
create.mem.reinterpretCast(outputDataBuffer, newOutputShape);
create.mem.reinterpretCast(outputDataBuffer, outputDims);
LLVM_DEBUG(llvm::dbgs() << "reshapedOutput: " << reshapedOutput << "\n");

rewriter.replaceOp(op, reshapedOutput);
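Note: the rewritten lowering keeps the same reshape-based strategy as before (flatten the batch and index dimensions, gather into a flat buffer, then reinterpret-cast to the output shape), but now computes the sizes with IndexExpr so that dynamic dimensions of 'indices' and 'data' are resolved at runtime instead of being rejected. The sketch below is a minimal NumPy model of that strategy, for orientation only; the function name and helpers are illustrative and not part of this repository.

```python
import numpy as np

def gather_nd_reference(data, indices, batch_dims=0):
    """Minimal NumPy model of the reshape-based GatherND lowering (illustrative only)."""
    idx_last = indices.shape[-1]
    assert 1 <= idx_last <= data.ndim - batch_dims  # same constraint the lowering asserts
    batch_size = int(np.prod(indices.shape[:batch_dims], dtype=np.int64))
    num_tuples = int(np.prod(indices.shape[batch_dims:-1], dtype=np.int64))

    # Reshape 'indices' to [batch_size, num_tuples, idx_last] and
    # 'data' to [batch_size, *data.shape[batch_dims:]], as the lowering does.
    idx3 = indices.reshape(batch_size, num_tuples, idx_last)
    data2 = data.reshape(batch_size, *data.shape[batch_dims:])

    # Gather one index tuple at a time into a flat buffer.
    gathered = [data2[(b, *idx3[b, t])]
                for b in range(batch_size) for t in range(num_tuples)]
    flat = np.stack(gathered)

    # Reinterpret the flat buffer as the ONNX output shape.
    out_shape = (*indices.shape[:-1], *data.shape[batch_dims + idx_last:])
    return flat.reshape(out_shape)

# ONNX spec example: data=[[0,1],[2,3]], indices=[[0,0],[1,1]] -> [0, 3]
print(gather_nd_reference(np.array([[0, 1], [2, 3]]), np.array([[0, 0], [1, 1]])))
```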
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp
@@ -97,7 +97,8 @@ struct ONNXScatterNDOpLowering : public OpConversionPattern<ONNXScatterNDOp> {
IndexExpr index = NonAffineIndexExpr(indexVal);
outputAccessFct.emplace_back(index);
} else {
IndexExpr index = SymbolIndexExpr(loopInd[i]);
IndexExpr index = SymbolIndexExpr(
loopInd[std::min<unsigned>(i, loopInd.size() - 1)]);
outputAccessFct.emplace_back(index);
}
}
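Note: the std::min clamp above keeps the access into loopInd in bounds when building the output access function, presumably guarding against indexing past the end of the induction-variable list once the indices tensor has a dynamic shape. For orientation, a minimal NumPy model of ONNX ScatterND with reduction="none" is sketched below, using the same shapes as the new lit test further down; names are illustrative only.

```python
import numpy as np

def scatter_nd_reference(data, indices, updates):
    """Minimal NumPy model of ONNX ScatterND with reduction="none" (illustrative only)."""
    k = indices.shape[-1]
    assert updates.shape == indices.shape[:-1] + data.shape[k:]
    output = np.copy(data)
    # Each index tuple selects a slice of 'output' of rank data.ndim - k and
    # overwrites it with the matching slice of 'updates'.
    for idx in np.ndindex(*indices.shape[:-1]):
        output[tuple(indices[idx])] = updates[idx]
    return output

# Same shapes as the new lit test: data (2,1), indices (?,2), updates (2,).
data = np.zeros((2, 1), dtype=np.int64)
indices = np.array([[0, 0], [1, 0]], dtype=np.int64)
updates = np.array([7, 9], dtype=np.int64)
print(scatter_nd_reference(data, indices, updates))  # [[7], [9]]
```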
13 changes: 11 additions & 2 deletions test/backend/inference_backend.py
@@ -1023,10 +1023,19 @@ def get_test_models():
},
# ==OP== GatherND
# ==MIN== 11
"test_gathernd_example_int32_cpu": {STATIC_SHAPE: {}, CONSTANT_INPUT: {-1}},
"test_gathernd_example_float32_cpu": {STATIC_SHAPE: {}, CONSTANT_INPUT: {-1}},
"test_gathernd_example_int32_cpu": {
STATIC_SHAPE: {},
DYNAMIC_SHAPE: {-1: {-1}},
CONSTANT_INPUT: {-1},
},
"test_gathernd_example_float32_cpu": {
STATIC_SHAPE: {},
DYNAMIC_SHAPE: {-1: {-1}},
CONSTANT_INPUT: {-1},
},
"test_gathernd_example_int32_batch_dim1_cpu": {
STATIC_SHAPE: {},
DYNAMIC_SHAPE: {-1: {-1}},
CONSTANT_INPUT: {-1},
},
# ==OP== Gemm
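Note on the test dictionaries: our reading (an assumption, not confirmed against the backend driver here) is that DYNAMIC_SHAPE: {-1: {-1}} asks the harness to make every dimension ({-1}) of every input (-1) dynamic before compiling the model, which is what exercises the new GatherND/ScatterND code paths. A minimal sketch of that interpretation, with a hypothetical helper:

```python
# Hypothetical helper (not from the test harness): expand a DYNAMIC_SHAPE spec
# such as {-1: {-1}} into per-input sets of dimensions to make dynamic.
def dynamic_dims_per_input(spec, input_shapes):
    result = []
    for i, shape in enumerate(input_shapes):
        dims = spec.get(i, spec.get(-1, set()))  # key -1 selects all inputs
        if -1 in dims:                           # value -1 selects all dimensions
            dims = set(range(len(shape)))
        result.append(dims)
    return result

# For test_gathernd_example_int32 (data (2,2), indices (2,2)),
# {-1: {-1}} would mark every dimension of both inputs dynamic.
print(dynamic_dims_per_input({-1: {-1}}, [(2, 2), (2, 2)]))  # [{0, 1}, {0, 1}]
```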
58 changes: 58 additions & 0 deletions test/mlir/conversion/onnx_to_krnl/Tensor/GatherND.mlir
@@ -73,3 +73,61 @@ func.func @test_gather_nd_2(%arg0 : tensor<2x2x2xf32>, %arg1 : tensor<2x1x2xi64>
// CHECK: [[RES:%.+]] = memref.reinterpret_cast [[RES_BUFFER]] to offset: [0], sizes: [2, 1, 2], strides: [2, 2, 1] : memref<4xf32> to memref<2x1x2xf32>
// CHECK: return [[RES]] : memref<2x1x2xf32>
}

// -----

// COM: Test GatherND with dynamic shape
func.func @test_gather_nd_with_dynamic_shape_int(%arg0 : tensor<2x2xi32>, %arg1 : tensor<?x2xi64>) -> tensor<?xi32> {
%0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xi32>, tensor<?x2xi64>) -> tensor<?xi32>
"func.return"(%0) : (tensor<?xi32>) -> ()
// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @test_gather_nd_with_dynamic_shape_int
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x2xi32>, [[PARAM_1_:%.+]]: memref<?x2xi64>) -> memref<?xi32> {
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_1_]], [[CST_0_]] : memref<?x2xi64>
// CHECK-DAG: [[CST_2_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_2_3_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_2_4_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_dim_5_:%.+]] = memref.dim [[PARAM_1_]], [[CST_0_1_]] : memref<?x2xi64>
// CHECK-DAG: [[CST_2_5_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_2_6_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index
// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]]([[VAR_dim_5_]])
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [1, [[VAR_dim_5_]], 2], strides: {{.}}[[VAR_0_]], 2, 1] : memref<?x2xi64> to memref<1x?x2xi64>
// CHECK-DAG: [[CST_1_2_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[VAR_reinterpret_cast_10_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1, 2, 2], strides: [4, 2, 1] : memref<2x2xi32> to memref<1x2x2xi32>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]]) : memref<?xi32>
// CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_1_3_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() : memref<index>
// CHECK: krnl.store [[CST_0_2_]], [[RES_1_]][] : memref<index>
// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
// CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_5_]])){
// CHECK-DAG: [[VAR_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK-DAG: [[CST_0_4_:%.+]] = arith.constant 0 : index
// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_:%.+]] = krnl.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[CST_0_4_]]{{.}} : memref<1x?x2xi64>
// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[LOAD_VAR_reinterpret_cast_MEM_]] : i64 to index
// CHECK-DAG: [[CST_1_4_:%.+]] = arith.constant 1 : index
// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = krnl.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[CST_1_4_]]{{.}} : memref<1x?x2xi64>
// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[LOAD_VAR_reinterpret_cast_MEM_1_]] : i64 to index
// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_10_MEM_:%.+]] = krnl.load [[VAR_reinterpret_cast_10_]]{{.}}[[VAR_2_]]#0, [[VAR_4_]], [[VAR_6_]]{{.}} : memref<1x2x2xi32>
// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]][] : memref<index>
// CHECK: krnl.store [[LOAD_VAR_reinterpret_cast_10_MEM_]], [[RES_]]{{.}}[[LOAD_RES_1_MEM_]]{{.}} : memref<?xi32>
// CHECK: [[VAR_9_:%.+]] = arith.addi [[LOAD_RES_1_MEM_]], [[CST_1_3_]] : index
// CHECK: krnl.store [[VAR_9_]], [[RES_1_]][] : memref<index>
// CHECK: }
// CHECK-DAG: [[CST_1_5_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_reinterpret_cast_15_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: [0], sizes: {{.}}[[VAR_dim_]]{{.}}, strides: [1] : memref<?xi32> to memref<?xi32>
// CHECK: return [[VAR_reinterpret_cast_15_]] : memref<?xi32>
}

29 changes: 29 additions & 0 deletions test/mlir/conversion/onnx_to_krnl/Tensor/ScatterND.mlir
@@ -24,3 +24,32 @@ func.func @test_scatter_nd1(%arg0: tensor<4x4x4xf32>, %arg1: tensor<2x1xi64>, %a
// CHECK: return [[RES]] : memref<4x4x4xf32>
}

// -----

// COM: Test ScatterND with dynamic shape
func.func @test_scatter_nd_with_dynamic_indices(%arg0: tensor<2x1xi64>, %arg1: tensor<?x2xi64>, %arg2: tensor<2xi64>) -> tensor<2x1xi64> {
%0 = "onnx.ScatterND"(%arg0, %arg1, %arg2) {reduction = "none"} : (tensor<2x1xi64>, tensor<?x2xi64>, tensor<2xi64>) -> tensor<2x1xi64>
return %0 : tensor<2x1xi64>
// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_scatter_nd_with_dynamic_indices
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x1xi64>, [[PARAM_1_:%.+]]: memref<?x2xi64>, [[PARAM_2_:%.+]]: memref<2xi64>) -> memref<2x1xi64> {
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2x1xi64>
// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : i64
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK: "krnl.memcpy"([[RES_]], [[PARAM_0_]], [[CST_2_1_]], [[CST_0_]], [[CST_0_]]) : (memref<2x1xi64>, memref<2x1xi64>, i64, index, index) -> ()
// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_2_2_:%.+]] = arith.constant 2 : index
// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 2){
// CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
// CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : index
// CHECK: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_1_]], [[CST_0_2_]]{{.}} : memref<?x2xi64>
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_]] : i64 to index
// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_1_]]{{.}} : memref<2xi64>
// CHECK: krnl.store [[LOAD_PARAM_2_MEM_]], [[RES_]]{{.}}[[VAR_3_]], [[VAR_1_]]{{.}} : memref<2x1xi64>
// CHECK: }
// CHECK: return [[RES_]] : memref<2x1xi64>
// CHECK: }
}