[torch-mlir][sparse] preserve sparsity during lowering torch to linalg (#2809)

This preserves sparsity at the most obvious places when lowering Torch
tensors to MLIR RankedTensorType tensors. Other places are marked for
audit. Initial lowering tests are included.
aartbik authored Jan 26, 2024
1 parent da7c6d2 commit 46a25d7
Showing 6 changed files with 50 additions and 8 deletions.
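The recurring idiom throughout the diff is the three-argument overload of RankedTensorType::get, which carries the source type's encoding attribute (and with it any #sparse_tensor.encoding) into the newly built type. A minimal sketch of the pattern, using a hypothetical helper name that is not part of the commit:

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper (not from the diff): rebuild a tensor type with a new
// shape while keeping the element type and sparsity encoding of srcType.
static RankedTensorType withShape(RankedTensorType srcType,
                                  ArrayRef<int64_t> shape) {
  // The two-argument RankedTensorType::get used before this commit drops the
  // encoding; passing getEncoding() preserves it. For dense tensors,
  // getEncoding() is a null Attribute, so the resulting type is unchanged.
  return RankedTensorType::get(shape, srcType.getElementType(),
                               srcType.getEncoding());
}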
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -978,6 +978,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
return success();
}

+ // TODO: audit possibility of sparsity on these tensors
Type adjustedResultType = RankedTensorType::get(
makeShapeLLVMCompatible(outputShape), resultType.getElementType());
Type adjustedInputType = RankedTensorType::get(
@@ -1005,6 +1006,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
intermediateShape.push_back(sum);
}

+ // TODO: audit possibility of sparsity on these tensors
Type intermediateResultType =
RankedTensorType::get(makeShapeLLVMCompatible(intermediateShape),
resultType.getElementType());
@@ -1657,6 +1659,7 @@ class ConvertAtenSliceScatterOp
auto srcType = src.getType().cast<RankedTensorType>();
int64_t srcRank = srcType.getRank();
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
+ // TODO: audit possibility of sparsity on these tensors
auto abstractSrcType = RankedTensorType::get(
makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
Value abstractSrc =
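Should the audit conclude that sparse operands can legitimately reach these view and scatter paths, the fix would mirror the matmul change in Linear.cpp below. A hypothetical resolution for the first TODO site, reusing the names from the hunk above:

// Hypothetical fix (not in this commit): forward the result type's encoding
// instead of silently materializing a dense type.
Type adjustedResultType = RankedTensorType::get(
    makeShapeLLVMCompatible(outputShape), resultType.getElementType(),
    resultType.getEncoding());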
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -206,8 +206,8 @@ namespace {
//
// TODO: Find an optimal lowering.
// current lowering is not optimal for bags of large embeddings.
// Since it traverses the output tensor multiple times.
//
//

class ConvertAtenEmbeddingBagPaddingIdxOp
11 changes: 6 additions & 5 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -377,8 +377,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
// TODO: Improve usage of static shape information.
SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
ShapedType::kDynamic);
- auto lhsBroadcastType =
-     RankedTensorType::get(lhsTargetShape, lhsType.getElementType());
+ auto lhsBroadcastType = RankedTensorType::get(
+     lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
if (failed(torch_to_linalg::broadcastToGivenShape(
op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType,
broadcastedLhs))) {
@@ -387,8 +387,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
}
SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
ShapedType::kDynamic);
- auto rhsBroadcastType =
-     RankedTensorType::get(rhsTargetShape, rhsType.getElementType());
+ auto rhsBroadcastType = RankedTensorType::get(
+     rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
if (failed(torch_to_linalg::broadcastToGivenShape(
op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType,
broadcastedRhs))) {
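For dense operands these two hunks are behavior-preserving: getEncoding() returns a null Attribute, and RankedTensorType uniques a null encoding to exactly the type the old two-argument call produced. A quick sanity sketch, assuming an MLIRContext *ctx is in scope:

// Sketch: a null encoding yields the same uniqued type as the 2-arg form.
Type f32 = Float32Type::get(ctx);
auto dense2arg = RankedTensorType::get({8, 16}, f32);
auto dense3arg = RankedTensorType::get({8, 16}, f32, /*encoding=*/Attribute());
assert(dense2arg == dense3arg && "null encoding is the dense default");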
@@ -880,7 +880,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
if(numSpacialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");

// Special depthwise case
auto inShape = makeShapeTorchCompatible(
input.getType().cast<RankedTensorType>().getShape());
@@ -894,6 +894,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
(weightShape[0] == kUnknownSize ? kUnknownSize
: weightShape[0] * weightShape[1]),
weightShape[2], weightShape[3]};
+ // TODO: audit possibility of sparsity on this tensor
Type collapsedType = RankedTensorType::get(
makeShapeLLVMCompatible(collapsedShape), elementType);
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
1 change: 1 addition & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
@@ -87,6 +87,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
*pad = castIntToIndex(b, loc, *pad);

Type elementType = input.getType().cast<RankedTensorType>().getElementType();
+ // TODO: audit possibility of sparsity on this tensor
Type inputType =
RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
SmallVector<int64_t>(inRank, kUnknownSize))),
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -467,7 +467,8 @@ TensorType ValueTensorType::toBuiltinTensor() const {
Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype());
if (!elementType)
return nullptr;
- return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType);
+ return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType,
+                              getOptionalSparsity());
}

LogicalResult
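This one-line change in toBuiltinTensor() is the load-bearing piece: torch_c.to_builtin_tensor and the Torch-to-Linalg type conversions all funnel through it, so forwarding getOptionalSparsity() is what lets the tests below observe tensor<64x64xf32, #CSR> rather than a dense tensor<64x64xf32>. A sketch of the observable effect, assuming vt is the ValueTensorType for !torch.vtensor<[64,64],f32,#CSR>:

// Sketch (vt is assumed, not from the diff): the sparse encoding now
// survives the conversion to a builtin tensor type. A dense vtensor still
// maps to a plain tensor<64x64xf32>, since getOptionalSparsity() is null.
auto builtin = cast<RankedTensorType>(vt.toBuiltinTensor());
assert(builtin.getEncoding() && "sparsity preserved across toBuiltinTensor()");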
36 changes: 36 additions & 0 deletions test/Conversion/TorchToLinalg/sparse.mlir
@@ -0,0 +1,36 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// -----

#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
// CHECK-LABEL: func.func @sum(
// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32>
// CHECK: %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[64,64],f32,#[[$CSR]]> -> tensor<64x64xf32, #[[$CSR]]>
// CHECK: linalg.generic {{{.*}}} ins(%[[S]] : tensor<64x64xf32, #[[$CSR]]>)
func.func @sum(%arg0: !torch.vtensor<[64,64],f32,#CSR>) -> !torch.vtensor<[],f32> {
%none = torch.constant.none
%0 = torch.aten.sum %arg0, %none
: !torch.vtensor<[64,64],f32,#CSR>, !torch.none -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}

// -----

#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
// CHECK-LABEL: func.func @SpMM(
// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,16],f32,#[[$CSR]]>,
// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32>
// CHECK: %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[8,16],f32,#[[$CSR]]> -> tensor<8x16xf32, #[[$CSR]]>
// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[B]] : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32>
// CHECK: linalg.matmul ins(%[[S]], %[[T]] : tensor<8x16xf32, #[[$CSR]]>, tensor<16x8xf32>)
func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
%arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> {
%0 = torch.aten.matmul %arg0, %arg1
: !torch.vtensor<[8,16],f32,#CSR>,
!torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
return %0 : !torch.vtensor<[8,8],f32>
}
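To reproduce locally, lit executes the RUN line at the top of the file with %s expanded to this file's path: torch-mlir-opt reads the module on stdin and lowers it with -convert-torch-to-linalg, the // ----- separators let -split-input-file verify each function as an independent module, and FileCheck matches the lowered output against the CHECK lines above.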
