Skip to content

Commit

Permalink
[Torch] Fold no-op reshape (#3769)
Browse files Browse the repository at this point in the history
This was preventing dynamic dims in an ONNX model from being reified (causing the generation of `tensor.cast`s and preventing fusion in iree):

```mlir
%2 = torch.vtensor.literal(dense<[4, 256]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
%7 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%8 = torch.aten.reshape %2, %7 : !torch.vtensor<[2],si64>, !torch.list<int> -> !torch.vtensor<[2],si64>
//... chain of foldable ops linking %2 to the `shape` operand of a `torch.aten.broadcast_to ... -> !torch.vtensor<[?,?],si64>`
```
  • Loading branch information
IanWood1 authored Oct 11, 2024
1 parent 2665ed3 commit 8787970
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11455,6 +11455,7 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenReshapeAsOp : Torch_Op<"aten.reshape_as", [
Expand Down
13 changes: 13 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,19 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
});
}

//===----------------------------------------------------------------------===//
// AtenReshapeOp
//===----------------------------------------------------------------------===//

// Fold a reshape that provably cannot change the tensor: the operand and
// result value-tensor types are identical and fully static, so the reshape
// is a no-op and the operand can be used directly.
OpFoldResult AtenReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = dyn_cast<ValueTensorType>(getSelf().getType());
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!inputTy || inputTy != resultTy)
    return nullptr;
  // With dynamic dims, identical types do not guarantee identical runtime
  // shapes (e.g. [?,?] -> [?,?] may transpose extents), so only fold when
  // the shape is known and fully static.
  if (!inputTy.hasSizes() || !inputTy.toBuiltinTensor().hasStaticShape())
    return nullptr;
  return getSelf();
}

//===----------------------------------------------------------------------===//
// AtenSelectIntOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)")
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
emit("aten::resize : (Tensor, int[], int?) -> (Tensor)")
Expand Down

0 comments on commit 8787970

Please sign in to comment.