[MLIR][ONNX] Add OnnxToTorch support for Slice Op (#2696)
wu-s-john authored Jan 4, 2024
1 parent 3e9bacd commit 4e5e34d
Showing 3 changed files with 294 additions and 3 deletions.
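The new pattern reads the rank-1 `starts`/`ends`/`axes`/`steps` operands one scalar at a time and emits one torch.aten.slice.Tensor per sliced axis. A rough sketch of the conversion, with hypothetical names and types for illustration only (the exact IR and lit-test syntax may differ):

  // Before: onnx.Slice taking rows 0..2 of a 4x4 input along axis 0.
  %0 = torch.operator "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps)
       : (!torch.vtensor<[4,4],f32>, !torch.vtensor<[1],si64>,
          !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>,
          !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,4],f32>

  // After (conceptually): each scalar is read out of its 1-D operand via
  // aten.index_select + aten.item, then passed to aten.slice.Tensor.
  %start = torch.aten.item %starts_sel : !torch.vtensor<[1],si64> -> !torch.int
  %end = torch.aten.item %ends_sel : !torch.vtensor<[1],si64> -> !torch.int
  %dim = torch.aten.item %axes_sel : !torch.vtensor<[1],si64> -> !torch.int
  %step = torch.aten.item %steps_sel : !torch.vtensor<[1],si64> -> !torch.int
  %1 = torch.aten.slice.Tensor %arg0, %dim, %start, %end, %step
       : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int,
         !torch.int -> !torch.vtensor<[2,4],f32>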
4 changes: 3 additions & 1 deletion include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -33,6 +33,8 @@ struct OpBinder {

Location getLoc() { return op->getLoc(); }

// Returns the number of operands on the op being matched; lowerings use
// this to detect trailing optional ONNX inputs.
int getNumOperands() { return op->getNumOperands(); }

// Operand matches of different arities.
ParseResult tensorOperand(Value &value0) {
if (op->getNumOperands() != 1)
@@ -189,7 +191,7 @@
}

ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
                                   std::string defaultValue = "") {
SmallString<64> name("torch.onnx.");
name.append(nameSuffix);
auto attr = op->getAttr(name);
164 changes: 162 additions & 2 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -643,8 +643,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
llvm::SmallVector<int64_t> axes;
int64_t keepDims;
int64_t noop_with_empty_axes;
if (binder.tensorOperand(data) || binder.tensorResultType(resultType) ||
binder.s64IntegerArrayAttr(axes, "axes", 0) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
@@ -1092,7 +1091,168 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
rewriter.replaceOp(binder.op, operand);
return success();
});
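// ONNX Slice(data, starts, ends[, axes[, steps]]) carries its slicing
// parameters as rank-1 tensors. The pattern below reads one scalar per
// requested axis out of those tensors and lowers the op to a chain of
// torch.aten.slice.Tensor ops, one slice per axis.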
patterns.onOp(
"Slice", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultTorchType;
Value operand, starts, ends;
// `data`, `starts`, and `ends` are required operands; the optional
// `axes` and `steps` are given default values further below when absent.

if (binder.tensorOperandAtIndex(operand, 0) ||
binder.tensorOperandAtIndex(starts, 1) ||
binder.tensorOperandAtIndex(ends, 2) ||
binder.tensorResultType(resultTorchType)) {
return failure();
}

auto context = rewriter.getContext();
auto operandTorchTy = operand.getType().cast<Torch::ValueTensorType>();
auto operandTy =
operandTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();

if (!operandTy)
return rewriter.notifyMatchFailure(
binder.op,
"Expected tensor operator argument to be a ranked tensor type");

auto startsTorchTy = starts.getType().cast<Torch::ValueTensorType>();
auto startsTy =
startsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();

auto endsTorchTy = ends.getType().cast<Torch::ValueTensorType>();
auto endsTy =
endsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();

// Guard against unranked starts/ends before querying their sizes.
if (!startsTy || !endsTy)
return rewriter.notifyMatchFailure(
binder.op, "Expected starts and ends to be ranked tensor types");

int startSize = startsTy.getDimSize(0);
int endSize = endsTy.getDimSize(0);
auto resultTy =
resultTorchType.toBuiltinTensor().dyn_cast<RankedTensorType>();
if (!resultTy)
return rewriter.notifyMatchFailure(
binder.op, "Expected result type to be a ranked tensor type");

Location loc = binder.getLoc();

// Bind `axes` from the operands if provided, otherwise build the default.
Value axes;
if (binder.getNumOperands() >= 4) {
if (binder.tensorOperandAtIndex(axes, 3)) {
return failure();
}
} else {
// The default `axes` covers every dimension: [0, 1, ..., rank - 1].
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
auto defaultAxesType = Torch::ValueTensorType::get(
context, ArrayRef<int64_t>{operandTy.getRank()},
rewriter.getIntegerType(64, /*signed*/ 1));
Value arangeLength = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
operandTy.getRank()));
axes = rewriter.create<Torch::AtenArangeOp>(
loc, defaultAxesType, arangeLength, none, none, none, none);
}

// Bind `steps` from the operands if provided, otherwise build the default.
Value steps;
if (binder.getNumOperands() >= 5) {
if (binder.tensorOperandAtIndex(steps, 4)) {
return failure();
}
} else {
// The default `steps` is a rank-1 tensor of ones whose length equals
// the operand's rank.
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
auto defaultStepsType = Torch::ValueTensorType::get(
context, ArrayRef<int64_t>{operandTy.getRank()},
rewriter.getIntegerType(64, /*signed*/ 1));
Value sizeStepInput = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
operandTy.getRank()));
Value sizeStepsInput = rewriter.create<Torch::PrimListConstructOp>(
loc,
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
sizeStepInput);
steps = rewriter.create<Torch::AtenOnesOp>(
loc, defaultStepsType, sizeStepsInput, none, none, none, none);
}

if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 &&
startSize == endSize))
return rewriter.notifyMatchFailure(
binder.op, "Expected the rank of starts and ends tensors to be 1 "
"and their dimensions to match");

auto axesTorchTy = axes.getType().cast<Torch::ValueTensorType>();
auto axesTy =
axesTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();

// Check for a ranked axes tensor before reading its size.
if (!axesTy)
return rewriter.notifyMatchFailure(
binder.op, "Expected axes to be a ranked tensor type");

int64_t numAxes = axesTy.getDimSize(0);
if (numAxes != endSize)
return rewriter.notifyMatchFailure(
binder.op, "Axes should be the same size as starts and ends");

auto stepsTy = steps.getType()
.cast<Torch::ValueTensorType>()
.toBuiltinTensor()
.dyn_cast<RankedTensorType>();

if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0)))
return rewriter.notifyMatchFailure(
binder.op, "Steps should be the same size of starts and ends");

Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

auto select = [&](Value v, Value k) -> Value {
auto ty = v.getType().cast<Torch::ValueTensorType>();
auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
loc,
Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
ty.getOptionalDtype()),
v, zero, k);
Value item = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), sel);
return item;
};

// Dimensions altered by slicing are unknown until the final slice is
// applied, so mark them dynamic in the intermediate result types.
llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape());
for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) {
intermediateShape[i] = -1;
}
}
auto intermediateType = Torch::ValueTensorType::get(
context, intermediateShape, resultTorchType.getOptionalDtype());
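// Emit one aten.slice.Tensor per requested axis, threading each partial
// result back in as the next operand; only the final slice uses the
// precise declared result type.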
for (int i = 0; i < numAxes; ++i) {
Value k = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc,
Torch::ValueTensorType::get(
context, ArrayRef<int64_t>{1},
rewriter.getIntegerType(64, /*signed*/ 1)),
k);

Value start = select(starts, kTensor);
Value end = select(ends, kTensor);
Value axis = select(axes, kTensor);
Value step = select(steps, kTensor);

auto sliceType = intermediateType;
if (i == numAxes - 1)
sliceType = resultTorchType;
operand = rewriter.create<Torch::AtenSliceTensorOp>(
loc, sliceType, operand, axis, start, end, step);
}

rewriter.replaceOp(binder.op, operand);
return success();
});
patterns.onOp(
"Reshape", 5, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
