From 4e5e34d215fa00912a8205a1d0406ee5719003a7 Mon Sep 17 00:00:00 2001
From: John Wu
Date: Wed, 3 Jan 2024 19:41:10 -0800
Subject: [PATCH] [MLIR][ONNX] Add OnnxToTorch support for Slice Op (#2696)

---
 .../Conversion/TorchOnnxToTorch/Patterns.h    |   4 +-
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 164 +++++++++++++++++-
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   | 129 ++++++++++++++
 3 files changed, 294 insertions(+), 3 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
index 06bbb1ac526f..1ce381005fcc 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -33,6 +33,8 @@ struct OpBinder {
 
   Location getLoc() { return op->getLoc(); }
 
+  int getNumOperands() { return op->getNumOperands(); }
+
   // Operand matches of different arities.
   ParseResult tensorOperand(Value &value0) {
     if (op->getNumOperands() != 1)
@@ -189,7 +191,7 @@ struct OpBinder {
   }
 
   ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
-                                       std::string defaultValue = "") {
+                                     std::string defaultValue = "") {
     SmallString<64> name("torch.onnx.");
     name.append(nameSuffix);
     auto attr = op->getAttr(name);
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 4756315c800e..f0e11ad1cd13 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -643,8 +643,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         llvm::SmallVector<int64_t> axes;
         int64_t keepDims;
         int64_t noop_with_empty_axes;
-        if (binder.tensorOperand(data) ||
-            binder.tensorResultType(resultType) ||
+        if (binder.tensorOperand(data) || binder.tensorResultType(resultType) ||
            binder.s64IntegerArrayAttr(axes, "axes", 0) ||
            binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
            binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
@@ -1092,7 +1091,168 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         rewriter.replaceOp(binder.op, operand);
         return success();
       });
+  patterns.onOp(
+      "Slice", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultTorchType;
+        Value operand, starts, ends;
+        // Handle the case where `axes` is not provided.
+
+        if (binder.tensorOperandAtIndex(operand, 0) ||
+            binder.tensorOperandAtIndex(starts, 1) ||
+            binder.tensorOperandAtIndex(ends, 2) ||
+            binder.tensorResultType(resultTorchType)) {
+          return failure();
+        }
+
+        auto context = rewriter.getContext();
+        auto operandTorchTy = operand.getType().cast<Torch::ValueTensorType>();
+        auto operandTy =
+            operandTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+
+        if (!operandTy)
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "Expected tensor operand to be a ranked tensor type");
+
+        auto startsTorchTy = starts.getType().cast<Torch::ValueTensorType>();
+        auto startsTy =
+            startsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+        int startSize = startsTy.getDimSize(0);
+
+        auto endsTorchTy = ends.getType().cast<Torch::ValueTensorType>();
+        auto endsTy =
+            endsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+        int endSize = endsTy.getDimSize(0);
+        auto resultTy =
+            resultTorchType.toBuiltinTensor().dyn_cast<RankedTensorType>();
+        if (!resultTy)
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected result type to be a ranked tensor type");
+
+        Location loc = binder.getLoc();
+
+        // Binding `axes` from its arguments or through a default value
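+        // An omitted `axes` input means every dimension is sliced in
+        // order, so the `else` branch below materializes
+        // torch.aten.arange(rank); e.g. for a rank-3 operand that is the
+        // tensor [0, 1, 2].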
+        Value axes;
+        if (binder.getNumOperands() >= 4) {
+          if (binder.tensorOperandAtIndex(axes, 3)) {
+            return failure();
+          }
+        } else {
+          // The default axes value is the range from 0 to the number of
+          // dimensions.
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          auto defaultAxesType = Torch::ValueTensorType::get(
+              context, ArrayRef<int64_t>{operandTy.getRank()},
+              rewriter.getIntegerType(64, /*signed*/ 1));
+          Value arangeLength = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                      operandTy.getRank()));
+          axes = rewriter.create<Torch::AtenArangeOp>(
+              loc, defaultAxesType, arangeLength, none, none, none, none);
+        }
+
+        // Binding `steps` from its arguments or through a default value
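+        // An omitted `steps` input means a unit step along each sliced
+        // axis, so the `else` branch below emits torch.aten.ones with one
+        // element per operand dimension.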
+        Value steps;
+        if (binder.getNumOperands() >= 5) {
+          if (binder.tensorOperandAtIndex(steps, 4)) {
+            return failure();
+          }
+        } else {
+          // The default `steps` value is a 1-d tensor of ones whose size
+          // equals the rank of the operand.
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          auto defaultStepsType = Torch::ValueTensorType::get(
+              context, ArrayRef<int64_t>{operandTy.getRank()},
+              rewriter.getIntegerType(64, /*signed*/ 1));
+          Value sizeStepInput = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                      operandTy.getRank()));
+          Value sizeStepsInput = rewriter.create<Torch::PrimListConstructOp>(
+              loc,
+              Torch::ListType::get(
+                  Torch::IntType::get(binder.op->getContext())),
+              sizeStepInput);
+          steps = rewriter.create<Torch::AtenOnesOp>(
+              loc, defaultStepsType, sizeStepsInput, none, none, none, none);
+        }
+
+        if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 &&
+              startSize == endSize))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected the rank of starts and ends tensors to be 1 "
+                         "and their dimensions to match");
+
+        auto axesTorchTy = axes.getType().cast<Torch::ValueTensorType>();
+        auto axesTy =
+            axesTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+        int64_t numAxes = axesTy.getDimSize(0);
+
+        if (!(axesTy && numAxes == endSize))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Axes should be the same size as starts and ends");
+
+        auto stepsTy = steps.getType()
+                           .cast<Torch::ValueTensorType>()
+                           .toBuiltinTensor()
+                           .dyn_cast<RankedTensorType>();
+
+        if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0)))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Steps should be the same size as starts and ends");
+
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+
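+        // `select` extracts element k of a 1-d tensor:
+        // torch.aten.index_select produces a [1] tensor and torch.aten.item
+        // converts it into the !torch.int scalar that
+        // torch.aten.slice.Tensor expects.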
+        auto select = [&](Value v, Value k) -> Value {
+          auto ty = v.getType().cast<Torch::ValueTensorType>();
+          auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
+              loc,
+              Torch::ValueTensorType::get(ty.getContext(),
+                                          ArrayRef<int64_t>{1},
+                                          ty.getOptionalDtype()),
+              v, zero, k);
+          Value item = rewriter.create<Torch::AtenItemOp>(
+              loc, rewriter.getType<Torch::IntType>(), sel);
+          return item;
+        };
+
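+        // Dimensions that change relative to the operand become dynamic in
+        // the intermediate slice types (e.g. slicing [20,10,5] down to
+        // [3,10,5] goes through [?,10,5]); only the final slice in the loop
+        // below carries the static result type.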
+        llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape());
+        for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
+          if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) {
+            intermediateShape[i] = -1;
+          }
+        }
+        auto intermediateType = Torch::ValueTensorType::get(
+            context, intermediateShape, resultTorchType.getOptionalDtype());
+        for (int i = 0; i < numAxes; ++i) {
+          Value k = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+          Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+              loc,
+              Torch::ValueTensorType::get(
+                  context, ArrayRef<int64_t>{1},
+                  rewriter.getIntegerType(64, /*signed*/ 1)),
+              k);
+
+          Value start = select(starts, kTensor);
+          Value end = select(ends, kTensor);
+          Value axis = select(axes, kTensor);
+          Value step = select(steps, kTensor);
+
+          auto sliceType = intermediateType;
+          if (i == numAxes - 1)
+            sliceType = resultTorchType;
+          operand = rewriter.create<Torch::AtenSliceTensorOp>(
+              loc, sliceType, operand, axis, start, end, step);
+        }
+
+        rewriter.replaceOp(binder.op, operand);
+        return success();
+      });
   patterns.onOp(
       "Reshape", 5, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index b2a19334ab09..91421d944129 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -863,6 +863,135 @@ func.func @test_transpose_all_permutations_4(%arg0: !torch.vtensor<[2,3,4],f32>)
 
 // -----
 
+// CHECK-LABEL: func.func @test_slice
+func.func @test_slice(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>, %arg3: !torch.vtensor<[2],si64>, %arg4: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0
+
+  //CHECK: %[[CONST_0:.*]] = torch.constant.int 0
+  //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_ELEMENT_0:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[ENDS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[ENDS_ELEMENT_0:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[AXES_INDEX_VEC_0:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[AXES_ELEMENT_0:.*]] = torch.aten.item %[[AXES_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg4, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,10,5],f32>
+
+  //CHECK: %[[CONST_1:.*]] = torch.constant.int 1
+  //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_ELEMENT_1:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[ENDS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[ENDS_ELEMENT_1:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[AXES_INDEX_VEC_1:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[AXES_ELEMENT_1:.*]] = torch.aten.item %[[AXES_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg4, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[?,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,10,5],f32>
+  %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,5],f32>
+  return %0 : !torch.vtensor<[3,10,5],f32>
+}
+
+// -----
+
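+// When `axes` and `steps` are omitted, the conversion first materializes
+// their defaults (torch.aten.arange for axes, torch.aten.ones for steps)
+// and then slices along every dimension.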
+// CHECK-LABEL: func.func @test_slice_default_axes_and_slices
+func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  //CHECK: %[[NONE_1:.*]] = torch.constant.none
+  //CHECK: %[[AXES_DEFAULT_SIZE:.*]] = torch.constant.int 3
+  //CHECK: %[[DEFAULT_AXES:.*]] = torch.aten.arange %[[AXES_DEFAULT_SIZE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
+  //CHECK: %[[NONE_2:.*]] = torch.constant.none
+  //CHECK: %[[DEFAULT_SIZE_AMOUNT:.*]] = torch.constant.int 3
+  //CHECK: %[[DEFAULT_SIZE_INPUT:.*]] = torch.prim.ListConstruct %[[DEFAULT_SIZE_AMOUNT:.*]] : (!torch.int) -> !torch.list<int>
+  //CHECK: %[[DEFAULT_SIZES:.*]] = torch.aten.ones %[[DEFAULT_SIZE_INPUT:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
+  //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0
+
+  //CHECK: %[[CONST_0:.*]] = torch.constant.int 0
+  //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_ELEMENT_0:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[ENDS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[ENDS_ELEMENT_0:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[AXES_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[AXES_ELEMENT_0:.*]] = torch.aten.item %[[AXES_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32>
+
+  //CHECK: %[[CONST_1:.*]] = torch.constant.int 1
+  //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_ELEMENT_1:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[ENDS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[ENDS_ELEMENT_1:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[AXES_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[AXES_ELEMENT_1:.*]] = torch.aten.item %[[AXES_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[SLICE_1:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32>
+
+  //CHECK: %[[CONST_2:.*]] = torch.constant.int 2
+  //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_2:.*]] : !torch.int -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_ELEMENT_2:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[ENDS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[ENDS_ELEMENT_2:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[AXES_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[AXES_ELEMENT_2:.*]] = torch.aten.item %[[AXES_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[STEPS_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STEPS_ELEMENT_2:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: torch.aten.slice.Tensor %[[SLICE_1:.*]], %[[AXES_ELEMENT_2:.*]], %[[STARTS_ELEMENT_2:.*]], %[[ENDS_ELEMENT_2:.*]], %[[STEPS_ELEMENT_2:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32>
+  %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32>
+  return %0 : !torch.vtensor<[20,10,1],f32>
+}
+
+// -----
+
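+// When only `steps` is omitted, just the torch.aten.ones default is
+// materialized; `axes` is taken from %arg3.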
+// CHECK-LABEL: func.func @test_slice_default_steps
+func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>, %arg3: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  //CHECK: %[[NONE:.*]] = torch.constant.none
+  //CHECK: %[[DEFAULT_SIZE_AMOUNT:.*]] = torch.constant.int 3
+  //CHECK: %[[DEFAULT_SIZE_INPUT:.*]] = torch.prim.ListConstruct %[[DEFAULT_SIZE_AMOUNT:.*]] : (!torch.int) -> !torch.list<int>
+  //CHECK: %[[DEFAULT_SIZES:.*]] = torch.aten.ones %[[DEFAULT_SIZE_INPUT:.*]], %[[NONE:.*]], %[[NONE:.*]], %[[NONE:.*]], %[[NONE:.*]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
+  //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0
+
+  //CHECK: %[[CONST_0:.*]] = torch.constant.int 0
+  //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_ELEMENT_0:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[ENDS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[ENDS_ELEMENT_0:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[AXES_INDEX_VEC_0:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[AXES_ELEMENT_0:.*]] = torch.aten.item %[[AXES_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32>
+
+  //CHECK: %[[CONST_1:.*]] = torch.constant.int 1
+  //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_ELEMENT_1:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[ENDS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[ENDS_ELEMENT_1:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[AXES_INDEX_VEC_1:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[AXES_ELEMENT_1:.*]] = torch.aten.item %[[AXES_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[SLICE_1:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32>
+
+  //CHECK: %[[CONST_2:.*]] = torch.constant.int 2
+  //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_2:.*]] : !torch.int -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STARTS_ELEMENT_2:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[ENDS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[ENDS_ELEMENT_2:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[AXES_INDEX_VEC_2:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[AXES_ELEMENT_2:.*]] = torch.aten.item %[[AXES_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: %[[STEPS_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  //CHECK: %[[STEPS_ELEMENT_2:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
+  //CHECK: torch.aten.slice.Tensor %[[SLICE_1:.*]], %[[AXES_ELEMENT_2:.*]], %[[STARTS_ELEMENT_2:.*]], %[[ENDS_ELEMENT_2:.*]], %[[STEPS_ELEMENT_2:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32>
+  %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32>
+  return %0 : !torch.vtensor<[20,10,1],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_reshape_negative_dim
 func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
 // CHECK: %[[INT0:.+]] = torch.constant.int 0