Handle more general cases in ONNX.Expand lowering to stablehlo #2747

Merged: 10 commits, merged on Mar 18, 2024
26 changes: 10 additions & 16 deletions src/Conversion/ONNXToStablehlo/Tensor/Expand.cpp
@@ -51,7 +51,6 @@ struct ONNXExpandOpLoweringToStablehlo : public ConversionPattern {
     Type elementType = outputShapedType.getElementType();
     int64_t outputRank = outputShapedType.getRank();
 
-    Operation *shapeDefOp = shape.getDefiningOp();
     Value ones;
     if (elementType.isa<IntegerType>())
       ones = rewriter.create<stablehlo::ConstantOp>(
@@ -60,17 +59,8 @@ struct ONNXExpandOpLoweringToStablehlo : public ConversionPattern {
       ones = rewriter.create<stablehlo::ConstantOp>(
           loc, rewriter.getFloatAttr(elementType, 1.0));
     Value broadcastedOnes;
-    if (ONNXShapeOp shapeOp = dyn_cast_or_null<ONNXShapeOp>(shapeDefOp)) {
-      assert(shapeOp.getData().getType().isa<ShapedType>() &&
-             "ShapeOp's input data should be of ShapedType");
-      int64_t shapeRank =
-          shapeOp.getData().getType().cast<ShapedType>().getRank();
-      SmallVector<int64_t, 4> onesShape(shapeRank, ShapedType::kDynamic);
-      RankedTensorType onesType = RankedTensorType::get(onesShape, elementType);
-      broadcastedOnes = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
-          loc, onesType, ones, shape, rewriter.getI64TensorAttr({}));
-    } else if (mlir::ElementsAttr constShape =
-                   getElementAttributeFromConstValue(shape)) {
+    if (mlir::ElementsAttr constShape =
+            getElementAttributeFromConstValue(shape)) {
       llvm::SmallVector<int64_t, 4> shapeValues;
       for (mlir::IntegerAttr element : constShape.getValues<IntegerAttr>())
         shapeValues.push_back(element.getInt());
@@ -79,10 +69,14 @@ struct ONNXExpandOpLoweringToStablehlo : public ConversionPattern {
       broadcastedOnes = rewriter.create<stablehlo::BroadcastInDimOp>(
           loc, broadcastedType, ones, rewriter.getI64TensorAttr({}));
     } else {
-      assert(
-          false &&
-          "Shape argument of Expand is the output of an unexpected operation. "
-          "Supported operations are: onnx.Constant and onnx.Shape");
+      ShapedType shapeType = shape.getType().cast<ShapedType>();
+      assert(shapeType.getRank() == 1 && shapeType.hasStaticShape() &&
+             "expected 1D statically shaped shape tensor");
+      int64_t shapeRank = shapeType.getShape()[0];
+      SmallVector<int64_t, 4> onesShape(shapeRank, ShapedType::kDynamic);
+      RankedTensorType onesType = RankedTensorType::get(onesShape, elementType);
+      broadcastedOnes = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
+          loc, onesType, ones, shape, rewriter.getI64TensorAttr({}));
     }
     llvm::SmallVector<Value, 4> newOperands = {input, broadcastedOnes};
     llvm::SmallVector<Value, 4> broadcastedOperands = getBroadcastedOperands(
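For contrast with the dynamic-shape test added below, here is a rough sketch of what the remaining constant-shape branch (the `if` on getElementAttributeFromConstValue) consumes and roughly what it produces. The function name, the concrete constant [2, 1, 6, 2], and the lowered lines in the trailing comments are illustrative assumptions, not lines taken from the repository:

// Illustrative sketch (not from the test file): with a constant shape operand,
// the lowering can broadcast the splat of ones to a fully static result type.
func.func @test_expand_with_constant_sketch(%arg0 : tensor<2x1x6x1xf32>) -> tensor<2x1x6x2xf32> {
  %shape = "onnx.Constant"() {value = dense<[2, 1, 6, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
  %0 = "onnx.Expand"(%arg0, %shape) : (tensor<2x1x6x1xf32>, tensor<4xi64>) -> tensor<2x1x6x2xf32>
  return %0 : tensor<2x1x6x2xf32>
}
// Expected flavor of the lowering (approximate, for orientation only):
//   %ones  = stablehlo.constant dense<1.000000e+00> : tensor<f32>
//   %bones = stablehlo.broadcast_in_dim %ones, dims = [] : (tensor<f32>) -> tensor<2x1x6x2xf32>
//   ... followed by broadcasting %arg0 and multiplying, as in the checks below.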
20 changes: 20 additions & 0 deletions test/mlir/conversion/onnx_to_stablehlo/Tensor/Expand.mlir
@@ -61,3 +61,23 @@ func.func @test_expand_with_shape(%arg0 : tensor<2x1x6x1xf32>, %arg1: tensor<6x2
 // CHECK: [[VAR_7_:%.+]] = stablehlo.multiply [[VAR_5_]], [[VAR_6_]] : tensor<2x1x6x2xf32>
 // CHECK: return [[VAR_7_]] : tensor<2x1x6x2xf32>
 // CHECK: }
+
+// -----
+
+func.func @test_expand_with_arbitrary(%arg0: tensor<2x1x6x1xf32>, %arg1: tensor<2xi64>) -> tensor<2x1x6x2xf32> {
+  %1 = "onnx.Expand"(%arg0, %arg1) : (tensor<2x1x6x1xf32>, tensor<2xi64>) -> tensor<2x1x6x2xf32>
+  return %1 : tensor<2x1x6x2xf32>
+}
+
+// CHECK-LABEL: func.func @test_expand_with_arbitrary
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x6x1xf32>, [[PARAM_1_:%.+]]: tensor<2xi64>) -> tensor<2x1x6x2xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [2, 1, 6, 1] : tensor<4xindex>
+// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: [[VAR_2_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_1_]], [[PARAM_1_]], dims = [] : (tensor<f32>, tensor<2xi64>) -> tensor<?x?xf32>
+// CHECK: [[VAR_3_:%.+]] = shape.shape_of [[VAR_2_]] : tensor<?x?xf32> -> tensor<2xindex>
+// CHECK: [[VAR_4_:%.+]] = shape.broadcast [[VAR_3_]], [[VAR_0_]] : tensor<2xindex>, tensor<4xindex> -> tensor<4xindex>
+// CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_4_]], dims = [0, 1, 2, 3] : (tensor<2x1x6x1xf32>, tensor<4xindex>) -> tensor<2x1x6x2xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_2_]], [[VAR_4_]], dims = [2, 3] : (tensor<?x?xf32>, tensor<4xindex>) -> tensor<2x1x6x2xf32>
+// CHECK: [[VAR_7_:%.+]] = stablehlo.multiply [[VAR_5_]], [[VAR_6_]] : tensor<2x1x6x2xf32>
+// CHECK: return [[VAR_7_]] : tensor<2x1x6x2xf32>
+// CHECK: }
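A note on the guard in the new else branch: the assert requires the shape operand to be a 1-D tensor with a statically known length, because that length fixes the rank of the broadcasted ones tensor. The snippet below is a hypothetical input, not part of this test file, that would still be rejected, since tensor<?xi64> leaves that length unknown at conversion time:

// Hypothetical, for illustration only: the shape operand has dynamic length,
// so shapeType.hasStaticShape() is false and the new assert fires.
func.func @expand_with_dynamic_length_shape(%arg0 : tensor<2x1x6x1xf32>, %arg1 : tensor<?xi64>) -> tensor<*xf32> {
  %0 = "onnx.Expand"(%arg0, %arg1) : (tensor<2x1x6x1xf32>, tensor<?xi64>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}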