Skip to content

Commit

Permalink
[TOSA] Add legalization for fill, flip, and round (#3768)
Browse files Browse the repository at this point in the history
- Add Torch to TOSA lowering for aten.fill.Scalar/Tensor, aten.flip, and
aten.round
- Fix torchScalarToTosaTensor function to correctly convert Torch scalar
input to TOSA tensor
- Update xfail_sets.py with new e2e results
- Update basic.mlir with LIT tests for new ops


Change-Id: If1e42c2e582710dd8ad0465eed29806fbcdbde41

Signed-off-by: Justin Ngo <[email protected]>
  • Loading branch information
justin-ngo-arm authored Oct 7, 2024
1 parent f4840ed commit b08d086
Show file tree
Hide file tree
Showing 3 changed files with 298 additions and 56 deletions.
211 changes: 188 additions & 23 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,17 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
return rewriter.notifyMatchFailure(op,
"Unable to extract the scalar constant");

int64_t numElem = 1;
for (int64_t dim : dshape)
numElem *= dim;

if (isa<mlir::FloatType>(dtype)) {
tosaTensor = tosa::getConstTensor<float>(rewriter, op,
(isFloat ? doubleValue : intValue),
dshape, dtype)
.value();
tosaTensor =
tosa::getConstTensor<float>(
rewriter, op,
SmallVector<float>(numElem, (isFloat ? doubleValue : intValue)),
dshape, dtype)
.value();
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
auto w = intType.getWidth();
if (w != 1 && w != 32 && w != 64)
Expand All @@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
}
bool d = isFloat ? static_cast<bool>(doubleValue)
: static_cast<bool>(intValue);
tosaTensor =
tosa::getConstTensor<bool>(rewriter, op, {d}, dshape).value();
tosaTensor = tosa::getConstTensor<bool>(
rewriter, op, SmallVector<bool>(numElem, d), dshape)
.value();
} else if (w == 32) {
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure(
Expand All @@ -183,17 +190,19 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
}
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
: static_cast<int32_t>(intValue);
tosaTensor =
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value();
tosaTensor = tosa::getConstTensor<int32_t>(
rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
.value();
} else if (w == 64) {
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure(
op, "Supplied value of scalar constant exceeds limits "
"of destination type");
}
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
tosaTensor =
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value();
tosaTensor = tosa::getConstTensor<int64_t>(
rewriter, op, SmallVector<int64_t>(numElem, d), dshape)
.value();
}
} else {
return rewriter.notifyMatchFailure(op, "Usupported element type");
Expand Down Expand Up @@ -5320,7 +5329,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
};

template <typename AtenOpT>
class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
Expand All @@ -5336,18 +5345,48 @@ class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
op, "Only Tensor types with static shapes are currently supported");

Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
if (!outElemTy.isIntOrFloat())
return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported");

Value fillValueTargetTensor;
if constexpr (std::is_same<AtenOpT, AtenFillTensorOp>()) {
// Reshape value tensor to have same rank and shape as input
auto inputRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();

auto fillValue = adaptor.getValue();
auto fillValueType = dyn_cast<TensorType>(fillValue.getType());
if (!fillValueType)
return rewriter.notifyMatchFailure(op, "Fill value is not a tensor");
auto fillValueElemTy = fillValueType.getElementType();

SmallVector<int64_t> fillValueMatchedInputRankShape(inputRank, 1);

auto fillValueMatchedInputRankType = RankedTensorType::get(
makeShapeTorchCompatible(fillValueMatchedInputRankShape),
fillValueElemTy);

auto fillValueMatchedInputRankTensor = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), fillValueMatchedInputRankType, fillValue,
rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));

fillValueTargetTensor = rewriter.create<tosa::TileOp>(
op->getLoc(),
RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
fillValueElemTy),
fillValueMatchedInputRankTensor.getResult(),
makeShapeTorchCompatible(outType.getShape()));
} else {
if (failed(torchScalarToTosaTensor(
rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
makeShapeTorchCompatible(outType.getShape()))))
return rewriter.notifyMatchFailure(
op, "Fill value must be a scalar constant");
}
Value constOp;
if (failed(torchScalarToTosaTensor(
rewriter, op, op.getValue(), constOp, outElemTy,
makeShapeTorchCompatible(outType.getShape()))))
return rewriter.notifyMatchFailure(
op, "Supplied value must be a Scalar constant");

rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
fillValueTargetTensor);

return success();
}
Expand Down Expand Up @@ -5869,6 +5908,127 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
return success();
}

// Legalization for aten.flip
template <>
LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
    AtenFlipOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  auto input = adaptor.getSelf();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  if (!inputType)
    return rewriter.notifyMatchFailure(
        op, "Only ranked tensor types are currently supported");

  // The list of axes to flip must be known at compile time.
  SmallVector<int64_t> flipDims;
  if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(flipDims)))
    return rewriter.notifyMatchFailure(
        op, "Only constant dims are currently supported");

  auto inputRank = inputType.getRank();
  auto outType = getTypeConverter()->convertType(op.getType());

  // Flip along each requested axis by chaining one tosa.reverse per dim.
  // tosa.reverse preserves the shape, so every intermediate uses the
  // converted result type.
  Value reversed = input;
  for (auto &axis : flipDims) {
    axis = toPositiveDim(axis, inputRank);
    if (!isValidDim(axis, inputRank))
      return rewriter.notifyMatchFailure(op, "Not all dims are valid");

    reversed = rewriter.create<tosa::ReverseOp>(op->getLoc(), outType, reversed,
                                                static_cast<int32_t>(axis));
  }

  rewriter.replaceOp(op, reversed);
  return success();
}

// Legalization for aten.round:
// Rounds elements of input to the nearest integer, applying "round half to
// even" to break ties when a value is equidistant from two integers.
template <>
LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
    AtenRoundOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Let frac = input - floor(input). Then:
  //   if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
  //     result = floor(input)
  //   else:
  //     result = ceil(input)

  auto input = adaptor.getSelf();

  auto inputTy = dyn_cast<TensorType>(input.getType());
  if (!inputTy)
    return rewriter.notifyMatchFailure(op, "Only tensor types supported");

  auto outTy =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
  auto outElemTy = outTy.getElementType();

  // Element-wise comparisons below yield i1 tensors of the result shape.
  auto condTy =
      RankedTensorType::get(outTy.getShape(), rewriter.getIntegerType(1));

  auto half =
      tosa::getConstTensor<float>(rewriter, op, 0.5, {}, outElemTy).value();
  auto twoConst =
      tosa::getConstTensor<float>(rewriter, op, 2, {}, outElemTy).value();

  auto floored = rewriter.create<tosa::FloorOp>(op->getLoc(), outTy, input);
  auto ceiled = rewriter.create<tosa::CeilOp>(op->getLoc(), outTy, input);

  // frac = input - floor(input)
  auto frac = rewriter.create<tosa::SubOp>(op->getLoc(), outTy, input,
                                           floored.getResult());

  // floor(input) / 2, computed as a multiply by 0.5.
  auto halfFloor = rewriter.create<tosa::MulOp>(
      op->getLoc(), outTy, floored.getResult(), half, /*shift=*/0);
  auto halfFloorFloored = rewriter.create<tosa::FloorOp>(
      op->getLoc(), outTy, halfFloor.getResult());

  // (floor(input) // 2) * 2
  auto twiceHalved = rewriter.create<tosa::MulOp>(
      op->getLoc(), outTy, halfFloorFloored.getResult(), twoConst,
      /*shift=*/0);

  // (floor(input) // 2) * 2 == floor(input)  <=>  floor(input) % 2 == 0
  auto isEven = rewriter.create<tosa::EqualOp>(
      op->getLoc(), condTy, floored.getResult(), twiceHalved.getResult());

  auto isHalf = rewriter.create<tosa::EqualOp>(op->getLoc(), condTy,
                                               frac.getResult(), half);

  auto isBelowHalf = rewriter.create<tosa::GreaterOp>(op->getLoc(), condTy,
                                                      half, frac.getResult());

  // (frac == 0.5) && (floor(input) % 2 == 0)
  auto tieRoundsDown = rewriter.create<tosa::LogicalAndOp>(
      op->getLoc(), condTy, isHalf.getResult(), isEven.getResult());

  // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
  auto useFloor = rewriter.create<tosa::LogicalOrOp>(
      op->getLoc(), condTy, isBelowHalf.getResult(), tieRoundsDown.getResult());

  rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outTy, useFloor.getResult(),
                                              floored.getResult(),
                                              ceiled.getResult());

  return success();
}

// Template to create supporting diagonal mask tensor for aten.diagonal
template <typename T>
Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
Expand Down Expand Up @@ -6052,6 +6212,7 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(

return success();
}

} // namespace

// -----------------------------------------------------------------------------
Expand Down Expand Up @@ -6283,11 +6444,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN

#define INSERT_FILL_SCALAR_PATTERN(AtenOp) \
#define INSERT_FILL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context);
INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
#undef INSERT_FILL_SCALAR_PATTERN
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
INSERT_FILL_PATTERN(AtenFillScalarOp);
INSERT_FILL_PATTERN(AtenFillTensorOp);
#undef INSERT_FILL_PATTERN

#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
Expand Down Expand Up @@ -6359,6 +6522,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenTrilOp);
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
INSERT_ATENOP_PATTERN(AtenFlipOp);
INSERT_ATENOP_PATTERN(AtenRoundOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
Expand Down
Loading

0 comments on commit b08d086

Please sign in to comment.