Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix onnx.GatherND and onnx.ScatterND issues with dynamic indices #2550

Merged

Conversation

negiyas
Copy link
Collaborator

@negiyas negiyas commented Oct 4, 2023

This PR fixes the following onnx.GatherND and onnx.ScatterND issues with dynamic indices.
With the following two lit tests, latest onnx-mlir causes internal errors while onnx-to-kernel conversion, this PR fixes them.

TODOs
[ X] Enable backend tests

// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-krnl %s -split-input-file | FileCheck %s

// COM: Test GatherND with dynamic shape
func.func @test_gather_nd_with_dynamic_shape(%arg0 : tensor<2x2xf32>, %arg1 : tensor<?x2xi64>) -> tensor<?xf32> {
  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xf32>, tensor<?x2xi64>) -> tensor<?xf32>
  "func.return"(%0) : (tensor<?xf32>) -> ()
}

// COM: Test GatherND with dynamic shape
func.func @test_scatter_nd_with_dynamic_indices(%arg0: tensor<2x1xi64>, %arg1: tensor<?x2xi64>, %arg2: tensor<2xi64>) -> tensor<2x1xi64> {
  %0 = "onnx.ScatterND"(%arg0, %arg1, %arg2) {reduction = "none"} : (tensor<2x1xi64>, tensor<?x2xi64>, tensor<2xi64>) -> tensor<2x1xi64>]
  return %0 : tensor<2x1xi64>
}

@negiyas negiyas marked this pull request as draft October 4, 2023 06:44
Signed-off-by: Yasushi Negishi <[email protected]>
@negiyas negiyas changed the title [WIP] Lit tests to check onnx.GatherND and onnx.ScatterND issues with dynamic indices [WIP] Fix onnx.GatherND and onnx.ScatterND issues with dynamic indices Oct 10, 2023
Signed-off-by: Yasushi Negishi <[email protected]>
Signed-off-by: Yasushi Negishi <[email protected]>
Signed-off-by: Yasushi Negishi <[email protected]>
@negiyas negiyas changed the title [WIP] Fix onnx.GatherND and onnx.ScatterND issues with dynamic indices Fix onnx.GatherND and onnx.ScatterND issues with dynamic indices Oct 10, 2023
@negiyas negiyas marked this pull request as ready for review October 10, 2023 04:22
Copy link
Collaborator

@AlexandreEichenberger AlexandreEichenberger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there also a backend test that we could activate to test this feature through execution of an operation?

src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp Outdated Show resolved Hide resolved
@negiyas
Copy link
Collaborator Author

negiyas commented Oct 11, 2023

Is there also a backend test that we could activate to test this feature through execution of an operation?
Thanks for the comments. I am preparing a bachend test.

@negiyas
Copy link
Collaborator Author

negiyas commented Oct 17, 2023

Is there also a backend test that we could activate to test this feature through execution of an operation?
Thanks for the comments. I am preparing a bachend test.

I added backend tests for onnx.GatherND with dynamic indices cases. Thanks.

IndexExpr indicesDimsSize = oneIE;
for (int64_t i = 0; i < indicesRank; i++)
indicesDimsSize = indicesDimsSize * indicesDims[i];
IndexExpr BDS(batchDimsSize),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick question: is this reshape valid even when indicesDimSizes is not divisible by batchDimSize * indicesLastDim ? Or is it guaranteed? Maybe you can write a comment explaining why this works here in the code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AlexandreEichenberger Thanks for the comments.

Alex> Quick question: is this reshape valid even when indicesDimSizes is not divisible by batchDimSize * indicesLastDim ?
Alex> Or is it guaranteed? Maybe you can write a comment explaining why this works here in the code.

It is guaranteed, because IDS can be calculated by product of a part of the indices dimensions.
I changed the code to calculate IDS by using "product" instead of "floorDiv" to clarify the meaning.

Copy link
Collaborator

@AlexandreEichenberger AlexandreEichenberger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks much clearer without the DIV operation

@negiyas negiyas merged commit ee36a16 into onnx:main Oct 26, 2023
8 checks passed
@jenkins-droid
Copy link
Collaborator

Jenkins Linux s390x Build #13132 [push] Fix onnx.GatherND and on... started at 00:02

@jenkins-droid
Copy link
Collaborator

Jenkins Linux ppc64le Build #12125 [push] Fix onnx.GatherND and on... started at 00:09

@jenkins-droid
Copy link
Collaborator

Jenkins Linux amd64 Build #13108 [push] Fix onnx.GatherND and on... started at 23:02

@jenkins-droid
Copy link
Collaborator

Jenkins Linux amd64 Build #13108 [push] Fix onnx.GatherND and on... passed after 1 hr 17 min

@jenkins-droid
Copy link
Collaborator

Jenkins Linux s390x Build #13132 [push] Fix onnx.GatherND and on... passed after 1 hr 27 min

@jenkins-droid
Copy link
Collaborator

Jenkins Linux ppc64le Build #12125 [push] Fix onnx.GatherND and on... passed after 1 hr 36 min

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants