Skip to content

Commit

Permalink
[MLIR][TORCH] Add support for 1-d group convolution (#3770)
Browse files Browse the repository at this point in the history
This commit adds the support for the 1-d depthwise convolution as a
special case of 1-d group convolution.

Signed-Off By: Vivek Khandelwal <[email protected]>
  • Loading branch information
vivekkhandelwal1 authored Oct 8, 2024
1 parent f6721e5 commit 614fcdd
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 13 deletions.
50 changes: 37 additions & 13 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1184,10 +1184,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");

// Special depthwise case: Cin = Cout = groups.
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
// of groups) to be depthwise in their documentation, but the linalg ops
Expand All @@ -1199,21 +1195,45 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
if (inShape[1] == numGroups && weightShape[0] == numGroups &&
weightShape[1] == 1) {
// Collapse weight shape (C/G == 1)
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
weightShape[2], weightShape[3]};
SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
for (unsigned i = 0; i < numSpatialDims; i++) {
collapsedDims.push_back({i + 2});
collapsedShape.push_back(weightShape[i + 2]);
}
Type collapsedType = RankedTensorType::get(
makeShapeLLVMCompatible(collapsedShape), weightDTy);
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, weight, collapsedDims);
if (!inputZp) {
conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
switch (numSpatialDims) {
case 1:
conv = rewriter
.create<linalg::DepthwiseConv1DNcwCwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 2:
conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
default:
return rewriter.notifyMatchFailure(
op, "unimplemented: only 1D and 2D depthwise convolution "
"supported for special case of group convolution");
};
} else {
if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D depthwise quantized convolution "
"supported for special case of group convolution");

// currently, the only named depthwise qconv op is nhwc_hwc
// input: nchw -> nhwc; weight (collapsed): chw -> hwc
// linalg conv result nhwc -> nchw
Expand Down Expand Up @@ -1260,6 +1280,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");

// Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) {
auto inType = cast<RankedTensorType>(tensor.getType());
Expand Down
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,7 @@
"ContainsIntList_False",
"ContainsIntList_True",
"ContiguousModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
Expand Down Expand Up @@ -3395,6 +3396,7 @@
"ContainsIntList_False",
"ContainsIntList_True",
"Conv1dModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
Expand Down Expand Up @@ -4087,6 +4089,7 @@
"ContainsIntList_False",
"ContainsIntList_True",
"Conv1dModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dBiasNoPaddingModule_basic",
"Conv2dModule_basic",
"Conv2dNoPaddingModule_basic",
Expand Down
27 changes: 27 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,33 @@ def Conv1dModule_basic(module, tu: TestUtils):
module.forward(inputVec, weight, bias)


class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 4, 6], torch.float32, True),
([4, 1, 3], torch.float32, True),
]
)
def forward(self, inputVec, weight):
return torch.ops.aten.conv1d(
inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
)


@register_test_case(
module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule()
)
def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
inputVec = tu.rand(2, 4, 6)
weight = torch.randn(4, 1, 3)
module.forward(inputVec, weight)


class Conv2dModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 614fcdd

Please sign in to comment.