From 8787970afed3c4e1497fb24c4fdeec179fcb61f6 Mon Sep 17 00:00:00 2001
From: Ian Wood <75152913+IanWood1@users.noreply.github.com>
Date: Thu, 10 Oct 2024 18:54:27 -0700
Subject: [PATCH] [Torch] Fold no-op reshape (#3769)

This was preventing dynamic dims in an ONNX model from being reified (causing
the generation of `tensor.cast`s and preventing fusion in IREE):

```mlir
%2 = torch.vtensor.literal(dense<[4, 256]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
%7 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%8 = torch.aten.reshape %2, %7 : !torch.vtensor<[2],si64>, !torch.list<int> -> !torch.vtensor<[2],si64>
// ... chain of foldable ops linking %2 to the `shape` operand of a
// `torch.aten.broadcast_to ... -> !torch.vtensor<[?,?],si64>`
```
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td             |  1 +
 lib/Dialect/Torch/IR/TorchOps.cpp                     | 13 +++++++++++++
 .../jit_ir_importer/build_tools/torch_ods_gen.py      |  2 +-
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 44bf8ab2e0d4..b1a670b6d48b 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11455,6 +11455,7 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenReshapeAsOp : Torch_Op<"aten.reshape_as", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index e10564bbe26b..47e77c11f17c 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2261,6 +2261,19 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenReshapeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenReshapeOp::fold(FoldAdaptor adaptor) {
+  auto selfTy = dyn_cast<ValueTensorType>(getSelf().getType());
+  auto opTy = dyn_cast<ValueTensorType>(getType());
+  if (selfTy && selfTy == opTy && selfTy.hasSizes() &&
+      selfTy.toBuiltinTensor().hasStaticShape())
+    return getSelf();
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSelectIntOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index ea5070a8c0bb..ba56f10fbd06 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -856,7 +856,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
     emit("aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)")
     emit("aten::tile : (Tensor, int[]) -> (Tensor)")
-    emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
+    emit("aten::reshape : (Tensor, int[]) -> (Tensor)", has_folder=True)
     emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)")
     emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
     emit("aten::resize : (Tensor, int[], int?) -> (Tensor)")
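
For reference, a minimal FileCheck test exercising the new fold could look like the sketch below. It is not part of this patch; the function name, RUN line, and placement (torch-mlir's `test/Dialect/Torch/canonicalize.mlir` would be the natural home) are assumptions. Because the reshape's operand and result types are the same fully static `!torch.vtensor<[2],si64>`, the fold replaces the op with its input, and canonicalization then removes the now-dead list construction:

```mlir
// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s

// Hypothetical test, not from the patch: the no-op reshape folds to %arg0.
// CHECK-LABEL:   func.func @torch.aten.reshape$fold_noop
// CHECK-NEXT:      return %arg0 : !torch.vtensor<[2],si64>
func.func @torch.aten.reshape$fold_noop(%arg0: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2],si64> {
  %int2 = torch.constant.int 2
  %0 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
  %1 = torch.aten.reshape %arg0, %0 : !torch.vtensor<[2],si64>, !torch.list<int> -> !torch.vtensor<[2],si64>
  return %1 : !torch.vtensor<[2],si64>
}
```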