Skip to content

Commit

Permalink
Add conversions for TOSA avg_pool2d and max_pool2d (#374)
Browse files Browse the repository at this point in the history
Co-authored by: Simon Camphausen <[email protected]>
  • Loading branch information
lucas-camp authored Jul 18, 2023
1 parent 4e5ede5 commit 055375f
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/tosa-op-coverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ The table below shows the supported TOSA ops.
| select | :heavy_check_mark: | |
| **Other ops**
| argmax | :heavy_check_mark: | |
| avg_pool2d | :white_check_mark: | Quantization and acc_type not supported |
| concat | :heavy_check_mark: | |
| conv2d | :white_check_mark: | Quantization and dilation not supported |
| depthwise_conv2d | :white_check_mark: | Quantization and dilation not supported |
| fully_connected | :white_check_mark: | Quantization not supported |
| gather | :heavy_check_mark: | |
| matmul | :white_check_mark: | Quantization not supported |
| max_pool2d | :white_check_mark: | Quantization not supported |
| reduce_all | :heavy_check_mark: | |
| reduce_any | :heavy_check_mark: | |
| reduce_max | :heavy_check_mark: | |
Expand Down
48 changes: 47 additions & 1 deletion lib/Conversion/TosaToEmitC/TosaToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,46 @@ class GenericConvOpConversion : public OpConversionPattern<SrcOp> {
StringRef funcName;
};

/// Conversion pattern that rewrites a `tosa` pooling operation of type
/// `SrcOp` into an `emitc.call` operation targeting the function `funcName`.
///
/// The pooling attributes (pad, stride, kernel) are forwarded as call
/// arguments and the op's result type is forwarded as the single template
/// argument of the call.
template <typename SrcOp, typename Adaptor = typename SrcOp::Adaptor>
class GenericPoolOpConversion : public OpConversionPattern<SrcOp> {
  using OpConversionPattern<SrcOp>::OpConversionPattern;

public:
  /// \param ctx      Context the pattern is registered in.
  /// \param funcName Name of the callee emitted for the converted op.
  GenericPoolOpConversion(MLIRContext *ctx, StringRef funcName)
      : OpConversionPattern<SrcOp>(ctx), funcName(funcName) {}

private:
  LogicalResult
  matchAndRewrite(SrcOp poolOp, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: average pool has an acc_type attribute.

    // Turn the pooling attributes into i64 elements attributes.
    auto *context = poolOp.getContext();
    auto padAttr = getI64ElementsAttr(poolOp.getPad(), context);
    auto strideAttr = getI64ElementsAttr(poolOp.getStride(), context);
    auto kernelAttr = getI64ElementsAttr(poolOp.getKernel(), context);

    // Call arguments: the index attribute 0 followed by pad, stride and
    // kernel. NOTE(review): the index attribute presumably selects the
    // call's first operand (the input tensor) — confirm against the EmitC
    // dialect documentation.
    ArrayAttr callArgs = rewriter.getArrayAttr(
        {rewriter.getIndexAttr(0), padAttr, strideAttr, kernelAttr});

    // The result type is passed as the only template argument.
    ArrayAttr callTemplateArgs =
        rewriter.getArrayAttr({TypeAttr::get(poolOp.getResult().getType())});

    StringAttr calleeName = rewriter.getStringAttr(funcName);

    // Swap the pooling op for the call.
    rewriter.replaceOpWithNewOp<emitc::CallOp>(
        poolOp, poolOp.getType(), calleeName, callArgs, callTemplateArgs,
        adaptor.getOperands());

    return success();
  }

  StringRef funcName;
};

/// Convert `tosa.fully_connected` into an `emitc.call` operation.
class FullyConnectedOpConversion
: public OpConversionPattern<tosa::FullyConnectedOp> {
Expand Down Expand Up @@ -830,6 +870,10 @@ void populateTosaToEmitcPatterns(MLIRContext *ctx,
"emitc::tosa::conv2d");
patterns.add<GenericConvOpConversion<tosa::DepthwiseConv2DOp>>(
ctx, "emitc::tosa::depthwise_conv2d");
patterns.add<GenericPoolOpConversion<tosa::AvgPool2dOp>>(
ctx, "emitc::tosa::avg_pool2d");
patterns.add<GenericPoolOpConversion<tosa::MaxPool2dOp>>(
ctx, "emitc::tosa::max_pool2d");
patterns.add<FullyConnectedOpConversion>(ctx, "emitc::tosa::fully_connected");
patterns.add<GenericOpConversion<tosa::GatherOp>>(
ctx, "emitc::tosa::gather",
Expand Down Expand Up @@ -907,7 +951,9 @@ struct ConvertTosaToEmitCPass
target.addIllegalOp<tosa::SelectOp>();

// Other ops.
target.addIllegalOp<tosa::ConcatOp,
target.addIllegalOp<tosa::AvgPool2dOp,
tosa::MaxPool2dOp,
tosa::ConcatOp,
tosa::Conv2DOp,
tosa::DepthwiseConv2DOp,
tosa::FullyConnectedOp,
Expand Down
115 changes: 115 additions & 0 deletions reference-implementation/include/emitc/tosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,121 @@ Dest depthwise_conv2d(Src input, Weights weights, Tensor1D<int64_t, 4> padding,
return output;
}

// MaxPool2d
template <typename Dest, typename Src>
Dest max_pool2d(Src input, std::array<int64_t, 4> padding,
std::array<int64_t, 2> stride, std::array<int64_t, 2> kernel) {
static_assert(is_tensor_of_dim<4, Src>::value,
"Expected 4 dimensional input");
static_assert(is_tensor_of_dim<4, Dest>::value,
"Expected 4 dimensional output");
using ET_Dest = typename get_element_type<Dest>::type;
assert(stride[0] > 0);
assert(stride[1] > 0);
const int N = input.dim(0);
const int H_IN = input.dim(1);
const int W_IN = input.dim(2);
const int C = input.dim(3);
Dest output;
const int K_H = kernel[0];
const int K_W = kernel[1];
const int S_H = stride[0];
const int S_W = stride[1];
const int pt = padding[0];
const int pb = padding[1];
const int pl = padding[2];
const int pr = padding[3];
const int H_PAD = pt + H_IN + pb;
const int W_PAD = pl + W_IN + pr;
// Pooling
for (int n = 0; n < N; n++) {
for (int h_pad = 0; h_pad < H_PAD - K_H + 1; h_pad += S_H) {
for (int w_pad = 0; w_pad < W_PAD - K_W + 1; w_pad += S_W) {
for (int c = 0; c < C; c++) {
const int h_out = h_pad / S_H;
const int w_out = w_pad / S_W;
output(n, h_out, w_out, c) = std::numeric_limits<ET_Dest>::min();
for (int kh = 0; kh < K_H; kh++) {
for (int kw = 0; kw < K_W; kw++) {
const int h_in = h_pad - pt + kh;
const int w_in = w_pad - pl + kw;
if (h_in < 0 || h_in >= H_IN || w_in < 0 || w_in >= W_IN)
continue;
output(n, h_out, w_out, c) =
std::max(output(n, h_out, w_out, c), input(n, h_in, w_in, c));
}
}
}
}
}
}
return output;
}

// AvgPool2d
/// Computes a 2D average pooling over a 4-dimensional float tensor.
///
/// The input is indexed as (n, h, w, c). For every output position the
/// elements of the input covered by the window are summed and divided by the
/// number of covered elements; positions that fall into the padding are
/// neither summed nor counted.
///
/// \param input   4-dimensional input tensor.
/// \param padding Padding amounts, read as {top, bottom, left, right}.
/// \param stride  Window step as {height, width}; both must be positive.
/// \param kernel  Window size as {height, width}.
/// \return A default-constructed `Dest` with every visited output position
///         set to the window average.
template <typename Dest, typename Src>
Dest avg_pool2d(Src input, std::array<int64_t, 4> padding,
                std::array<int64_t, 2> stride, std::array<int64_t, 2> kernel) {
  static_assert(is_tensor_of_dim<4, Src>::value,
                "Expected 4 dimensional input");
  static_assert(is_tensor_of_dim<4, Dest>::value,
                "Expected 4 dimensional output");

  using ET_Dest = typename get_element_type<Dest>::type;
  static_assert(std::is_same<ET_Dest, float>::value,
                "Only float data type supported");

  assert(stride[0] > 0);
  assert(stride[1] > 0);

  // Input extents.
  const int batch = input.dim(0);
  const int in_height = input.dim(1);
  const int in_width = input.dim(2);
  const int channels = input.dim(3);

  // Pooling parameters.
  const int kernel_h = kernel[0];
  const int kernel_w = kernel[1];
  const int stride_h = stride[0];
  const int stride_w = stride[1];
  const int pad_top = padding[0];
  const int pad_bottom = padding[1];
  const int pad_left = padding[2];
  const int pad_right = padding[3];

  // Exclusive upper bounds for the window origin in padded coordinates.
  const int h_origin_end = (pad_top + in_height + pad_bottom) - kernel_h + 1;
  const int w_origin_end = (pad_left + in_width + pad_right) - kernel_w + 1;

  Dest output;

  for (int n = 0; n < batch; ++n) {
    for (int oh = 0; oh * stride_h < h_origin_end; ++oh) {
      // Row of the window's first element in unpadded input coordinates.
      const int h_first = oh * stride_h - pad_top;

      for (int ow = 0; ow * stride_w < w_origin_end; ++ow) {
        // Column of the window's first element in unpadded coordinates.
        const int w_first = ow * stride_w - pad_left;

        for (int c = 0; c < channels; ++c) {
          auto sum = ET_Dest(0);
          size_t covered = 0;

          for (int kh = 0; kh < kernel_h; ++kh) {
            const int h_in = h_first + kh;
            // A row inside the padding contributes nothing.
            if (h_in < 0 || h_in >= in_height)
              continue;

            for (int kw = 0; kw < kernel_w; ++kw) {
              const int w_in = w_first + kw;
              const bool inside = (w_in >= 0) && (w_in < in_width);
              if (inside) {
                ++covered;
                sum += input(n, h_in, w_in, c);
              }
            }
          }

          output(n, oh, ow, c) = sum / static_cast<ET_Dest>(covered);
        }
      }
    }
  }

  return output;
}

// FullyConnectedOp
template <typename Dest, typename Src, typename Weights, typename Bias>
Dest fully_connected(Src input, Weights weights, Bias bias) {
Expand Down

0 comments on commit 055375f

Please sign in to comment.