Skip to content

Commit

Permalink
Add conversion for stablehlo transpose (#375)
Browse files Browse the repository at this point in the history
Co-authored by: Simon Camphausen <[email protected]>
  • Loading branch information
lucas-camp authored Jul 18, 2023
1 parent 7297909 commit f232c2a
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/stablehlo-op-coverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ The table below shows the supported StableHLO ops.
| reduce_window | :white_check_mark: | No support for dilation |
| reshape | :heavy_check_mark: | |
| select | :heavy_check_mark: | |
| transpose | :heavy_check_mark: | |
36 changes: 35 additions & 1 deletion lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,38 @@ class PadOpConversion : public OpConversionPattern<stablehlo::PadOp> {
}
};

/// Convert `stablehlo.transpose` into an `emitc.call` operation.
class TransposeOpConversion
: public OpConversionPattern<stablehlo::TransposeOp> {
using OpConversionPattern<stablehlo::TransposeOp>::OpConversionPattern;

public:
TransposeOpConversion(MLIRContext *ctx)
: OpConversionPattern<stablehlo::TransposeOp>(ctx) {}

private:
LogicalResult
matchAndRewrite(stablehlo::TransposeOp transposeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringAttr callee = rewriter.getStringAttr("emitc::stablehlo::transpose");

SmallVector<Attribute> arguments =
indexSequence(adaptor.getOperands().size(), transposeOp.getContext());

arguments.push_back(transposeOp.getPermutation());
ArrayAttr args = rewriter.getArrayAttr(arguments);

Type resultType = transposeOp.getResult().getType();
ArrayAttr templateArgs = rewriter.getArrayAttr({TypeAttr::get(resultType)});

rewriter.replaceOpWithNewOp<emitc::CallOp>(
transposeOp, transposeOp.getType(), callee, args, templateArgs,
adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.rng` into an `emitc.call` operation.
class RngOpConversion : public OpConversionPattern<stablehlo::RngOp> {

Expand Down Expand Up @@ -552,6 +584,7 @@ void populateStablehloToEmitcPatterns(MLIRContext *ctx,
/*explicitResultType=*/true);
patterns.add<GenericOpConversion<stablehlo::SelectOp>>(
ctx, "emitc::stablehlo::select");
patterns.add<TransposeOpConversion>(ctx);

// Insert patterns for StableHLO RNG ops.
patterns.add<RngOpConversion>(ctx);
Expand Down Expand Up @@ -630,7 +663,8 @@ struct ConvertStablehloToEmitCPass
stablehlo::DotOp,
stablehlo::PadOp,
stablehlo::ReshapeOp,
stablehlo::SelectOp>();
stablehlo::SelectOp,
stablehlo::TransposeOp>();

// StableHLO RNG ops.
target.addIllegalOp<stablehlo::RngOp>();
Expand Down
21 changes: 21 additions & 0 deletions reference-implementation/include/emitc/stablehlo.h
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,27 @@ inline Src select(typename replace_element_type<bool, Src>::type pred,
return z;
}

// TransposeOp
// Maps the perms dimension from Dest to Src.
template <typename Dest, typename Src>
inline Dest transpose(Src operand, Tensor1D<int64_t, Src::rank()> perms) {
static_assert(is_tensor<Src>::value, "Expected tensor argument");
static_assert(is_tensor<Dest>::value, "Expected tensor result");

// Since emitc::broadcast_in_dim maps the dimensions (argument
// "broadcast_dimensions") from Src to Dest and stablehlo::transpose maps the
// dimensions (argument "perms") from Dest to Src, we have to invert the
// mapping.
Tensor1D<int64_t, Src::rank()> broadcast_dimensions;
for (size_t i = 0; i < perms.size(); ++i) {
auto pos = std::find(perms.begin(), perms.end(), i);
assert(pos != std::end(perms));
int64_t index = std::distance(perms.begin(), pos);
broadcast_dimensions[i] = index;
}
return emitc::broadcast_in_dim<Dest>(operand, broadcast_dimensions);
}

// RngUniformOp
template <typename Dest, typename T, size_t N>
inline Dest rng_uniform(Tensor<T> low, Tensor<T> high,
Expand Down

0 comments on commit f232c2a

Please sign in to comment.