From f232c2a945cff7aeddad71d9e7b0c33479b34878 Mon Sep 17 00:00:00 2001 From: Lucas Camphausen Date: Tue, 18 Jul 2023 11:51:07 +0200 Subject: [PATCH] Add conversion for stablehlo transpose (#375) Co-authored by: Simon Camphausen --- docs/stablehlo-op-coverage.md | 1 + .../StablehloToEmitC/StablehloToEmitC.cpp | 36 ++++++++++++++++++- .../include/emitc/stablehlo.h | 21 +++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/docs/stablehlo-op-coverage.md b/docs/stablehlo-op-coverage.md index 215b5276..dc6d7c73 100644 --- a/docs/stablehlo-op-coverage.md +++ b/docs/stablehlo-op-coverage.md @@ -57,3 +57,4 @@ The table below shows the supported StableHLO ops. | reduce_window | :white_check_mark: | No support for dilation | | reshape | :heavy_check_mark: | | | select | :heavy_check_mark: | | +| transpose | :heavy_check_mark: | | diff --git a/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp b/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp index 0039a33d..b7886e17 100644 --- a/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp +++ b/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp @@ -423,6 +423,38 @@ class PadOpConversion : public OpConversionPattern { } }; +/// Convert `stablehlo.transpose` into an `emitc.call` operation. +class TransposeOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + TransposeOpConversion(MLIRContext *ctx) + : OpConversionPattern(ctx) {} + +private: + LogicalResult + matchAndRewrite(stablehlo::TransposeOp transposeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringAttr callee = rewriter.getStringAttr("emitc::stablehlo::transpose"); + + SmallVector arguments = + indexSequence(adaptor.getOperands().size(), transposeOp.getContext()); + + arguments.push_back(transposeOp.getPermutation()); + ArrayAttr args = rewriter.getArrayAttr(arguments); + + Type resultType = transposeOp.getResult().getType(); + ArrayAttr templateArgs = rewriter.getArrayAttr({TypeAttr::get(resultType)}); + + rewriter.replaceOpWithNewOp( + transposeOp, transposeOp.getType(), callee, args, templateArgs, + adaptor.getOperands()); + + return success(); + } +}; + /// Convert `stablehlo.rng` into an `emitc.call` operation. class RngOpConversion : public OpConversionPattern { @@ -552,6 +584,7 @@ void populateStablehloToEmitcPatterns(MLIRContext *ctx, /*explicitResultType=*/true); patterns.add>( ctx, "emitc::stablehlo::select"); + patterns.add(ctx); // Insert patterns for StableHLO RNG ops. patterns.add(ctx); @@ -630,7 +663,8 @@ struct ConvertStablehloToEmitCPass stablehlo::DotOp, stablehlo::PadOp, stablehlo::ReshapeOp, - stablehlo::SelectOp>(); + stablehlo::SelectOp, + stablehlo::TransposeOp>(); // StableHLO RNG ops. target.addIllegalOp(); diff --git a/reference-implementation/include/emitc/stablehlo.h b/reference-implementation/include/emitc/stablehlo.h index f3bf8acd..05ae3069 100644 --- a/reference-implementation/include/emitc/stablehlo.h +++ b/reference-implementation/include/emitc/stablehlo.h @@ -650,6 +650,27 @@ inline Src select(typename replace_element_type::type pred, return z; } +// TransposeOp +// Maps the perms dimension from Dest to Src. +template +inline Dest transpose(Src operand, Tensor1D perms) { + static_assert(is_tensor::value, "Expected tensor argument"); + static_assert(is_tensor::value, "Expected tensor result"); + + // Since emitc::broadcast_in_dim maps the dimensions (argument + // "broadcast_dimensions") from Src to Dest and stablehlo::transpose maps the + // dimensions (argument "perms") from Dest to Src, we have to invert the + // mapping. + Tensor1D broadcast_dimensions; + for (size_t i = 0; i < perms.size(); ++i) { + auto pos = std::find(perms.begin(), perms.end(), i); + assert(pos != std::end(perms)); + int64_t index = std::distance(perms.begin(), pos); + broadcast_dimensions[i] = index; + } + return emitc::broadcast_in_dim(operand, broadcast_dimensions); +} + // RngUniformOp template inline Dest rng_uniform(Tensor low, Tensor high,