diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Expand.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Expand.cpp
index cfbed51a18..936a5167c1 100644
--- a/src/Conversion/ONNXToStablehlo/Tensor/Expand.cpp
+++ b/src/Conversion/ONNXToStablehlo/Tensor/Expand.cpp
@@ -51,7 +51,6 @@ struct ONNXExpandOpLoweringToStablehlo : public ConversionPattern {
     Type elementType = outputShapedType.getElementType();
     int64_t outputRank = outputShapedType.getRank();
 
-    Operation *shapeDefOp = shape.getDefiningOp();
     Value ones;
     if (elementType.isa<IntegerType>())
       ones = rewriter.create<stablehlo::ConstantOp>(
@@ -60,17 +59,8 @@ struct ONNXExpandOpLoweringToStablehlo : public ConversionPattern {
       ones = rewriter.create<stablehlo::ConstantOp>(
           loc, rewriter.getFloatAttr(elementType, 1.0));
     Value broadcastedOnes;
-    if (ONNXShapeOp shapeOp = dyn_cast_or_null<ONNXShapeOp>(shapeDefOp)) {
-      assert(shapeOp.getData().getType().isa<ShapedType>() &&
-             "ShapeOp's input data should be of ShapedType");
-      int64_t shapeRank =
-          shapeOp.getData().getType().cast<ShapedType>().getRank();
-      SmallVector<int64_t> onesShape(shapeRank, ShapedType::kDynamic);
-      RankedTensorType onesType = RankedTensorType::get(onesShape, elementType);
-      broadcastedOnes = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
-          loc, onesType, ones, shape, rewriter.getI64TensorAttr({}));
-    } else if (mlir::ElementsAttr constShape =
-                   getElementAttributeFromConstValue(shape)) {
+    if (mlir::ElementsAttr constShape =
+            getElementAttributeFromConstValue(shape)) {
       llvm::SmallVector<int64_t> shapeValues;
       for (mlir::IntegerAttr element : constShape.getValues<IntegerAttr>())
         shapeValues.push_back(element.getInt());
@@ -79,10 +69,14 @@ struct ONNXExpandOpLoweringToStablehlo : public ConversionPattern {
       broadcastedOnes = rewriter.create<stablehlo::BroadcastInDimOp>(
           loc, broadcastedType, ones, rewriter.getI64TensorAttr({}));
     } else {
-      assert(
-          false &&
-          "Shape argument of Expand is the output of an unexpected operation. "
" - "Supported operations are: onnx.Constant and onnx.Shape"); + ShapedType shapeType = shape.getType().cast(); + assert(shapeType.getRank() == 1 && shapeType.hasStaticShape() && + "expected 1D statically shaped shape tensor"); + int64_t shapeRank = shapeType.getShape()[0]; + SmallVector onesShape(shapeRank, ShapedType::kDynamic); + RankedTensorType onesType = RankedTensorType::get(onesShape, elementType); + broadcastedOnes = rewriter.create( + loc, onesType, ones, shape, rewriter.getI64TensorAttr({})); } llvm::SmallVector newOperands = {input, broadcastedOnes}; llvm::SmallVector broadcastedOperands = getBroadcastedOperands( diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Expand.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Expand.mlir index c33092cc01..8a75ebce9c 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Expand.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Expand.mlir @@ -61,3 +61,23 @@ func.func @test_expand_with_shape(%arg0 : tensor<2x1x6x1xf32>, %arg1: tensor<6x2 // CHECK: [[VAR_7_:%.+]] = stablehlo.multiply [[VAR_5_]], [[VAR_6_]] : tensor<2x1x6x2xf32> // CHECK: return [[VAR_7_]] : tensor<2x1x6x2xf32> // CHECK: } + +// ----- + + func.func @test_expand_with_arbitrary(%arg0: tensor<2x1x6x1xf32>, %arg1: tensor<2xi64>) -> tensor<2x1x6x2xf32> { + %1 = "onnx.Expand"(%arg0, %arg1) : (tensor<2x1x6x1xf32>, tensor<2xi64>) -> tensor<2x1x6x2xf32> + return %1 : tensor<2x1x6x2xf32> + } + +// CHECK-LABEL: func.func @test_expand_with_arbitrary +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x6x1xf32>, [[PARAM_1_:%.+]]: tensor<2xi64>) -> tensor<2x1x6x2xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [2, 1, 6, 1] : tensor<4xindex> +// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<1.000000e+00> : tensor +// CHECK: [[VAR_2_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_1_]], [[PARAM_1_]], dims = [] : (tensor, tensor<2xi64>) -> tensor +// CHECK: [[VAR_3_:%.+]] = shape.shape_of [[VAR_2_]] : tensor -> tensor<2xindex> +// CHECK: [[VAR_4_:%.+]] = shape.broadcast [[VAR_3_]], [[VAR_0_]] : tensor<2xindex>, tensor<4xindex> -> tensor<4xindex> +// CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_4_]], dims = [0, 1, 2, 3] : (tensor<2x1x6x1xf32>, tensor<4xindex>) -> tensor<2x1x6x2xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_2_]], [[VAR_4_]], dims = [2, 3] : (tensor, tensor<4xindex>) -> tensor<2x1x6x2xf32> +// CHECK: [[VAR_7_:%.+]] = stablehlo.multiply [[VAR_5_]], [[VAR_6_]] : tensor<2x1x6x2xf32> +// CHECK: return [[VAR_7_]] : tensor<2x1x6x2xf32> +// CHECK: }