Skip to content

Commit

Permalink
[onert] Check Gather operand data type (#14054)
Browse files Browse the repository at this point in the history
This commit adds operation validation for Gather to check operand data type.

ONE-DCO-1.0-Signed-off-by: Hyeongseok Oh <[email protected]>
  • Loading branch information
hseok-oh authored Sep 24, 2024
1 parent 102163e commit 646c7f1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
15 changes: 15 additions & 0 deletions runtime/onert/core/src/ir/OperationValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,21 @@ void OperationValidator::visit(const operation::Fill &node)
{DataType::FLOAT32, DataType::INT32, DataType::INT64, DataType::BOOL8}));
}

void OperationValidator::visit(const operation::Gather &node)
{
const auto output_index{node.getOutputs().at(0)};
const auto input_index{node.getInputs().at(operation::Gather::INPUT)};
const auto indices_index{node.getInputs().at(operation::Gather::INDICES)};

const auto input_type = operandType(input_index);
if (input_type == DataType::QUANT_GGML_Q4_0 || input_type == DataType::QUANT_GGML_Q8_0)
OP_REQUIRES(isValidType(output_index, {DataType::FLOAT32}));
else
OP_REQUIRES(isSameType(output_index, input_index));

OP_REQUIRES(isValidType(indices_index, {DataType::INT32, DataType::INT64}));
}

void OperationValidator::visit(const operation::HashtableLookup &node)
{
const auto hits_index{node.getOutputs().at(operation::HashtableLookup::Output::HITS)};
Expand Down
1 change: 1 addition & 0 deletions runtime/onert/core/src/ir/OperationValidator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class OperationValidator : public OperationVisitor
void visit(const operation::EmbeddingLookup &node) override;
void visit(const operation::ExpandDims &node) override;
void visit(const operation::Fill &node) override;
void visit(const operation::Gather &node) override;
void visit(const operation::HashtableLookup &node) override;
void visit(const operation::Pack &node) override;
void visit(const operation::Pad &node) override;
Expand Down

0 comments on commit 646c7f1

Please sign in to comment.