diff --git a/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp b/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp
index b18cd09f030a..e7ab690e0ff3 100644
--- a/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp
@@ -1072,11 +1072,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
   Value cstNone = b.create<ConstantNoneOp>();
   Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
   Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
-  Value cstTwo = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2));
 
   // Binding arguments
   ValueTensorType yTy, Y_hType;
-  if (binder.tensorResultTypeAtIndex(yTy, 0) ||
+  if (binder.tensorResultTypeAtIndex(yTy, 0) &&
       binder.tensorResultTypeAtIndex(Y_hType, 1)) {
     return rewriter.notifyMatchFailure(binder.op,
                                        "At least one output must be present");
@@ -1132,6 +1131,7 @@ LogicalResult OnnxGruExpander(OpBinder binder,
   // Validations
   auto XShape = xTy.getSizes();
   int64_t batch_size = (layout == 0) ? XShape[1] : XShape[0];
+  int64_t seq_len = (layout == 0) ? XShape[0] : XShape[1];
   int64_t input_size = XShape[2];
 
   std::ostringstream oss;
@@ -1173,6 +1173,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
     Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
     initial_h =
         b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
+  } else {
+    if (layout == 1) {
+      initial_h = StaticTranspose(b, initial_h, 0, 1);
+    }
   }
 
   if (binder.tensorOperandAtIndex(sequence_lens, 4))
@@ -1192,10 +1196,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
   // fill in B
   Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
   if (B == nullptr) {
-    SmallVector<int64_t> BShape = {num_directions, 2 * hidden_size};
+    SmallVector<int64_t> BShape = {num_directions, 6 * hidden_size};
     SmallVector<Value> BShapeListContents = {
         b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions)),
-        b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2 * hidden_size))};
+        b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(6 * hidden_size))};
     Value BShapeList = b.create<PrimListConstructOp>(
         b.getType<ListType>(intType), BShapeListContents);
     auto BType = b.getType<ValueTensorType>(BShape, wTy.getDtype());
@@ -1256,51 +1260,47 @@ LogicalResult OnnxGruExpander(OpBinder binder,
                             B_slices[4], B_slices[5]);
 
   // Process inputs based on layout
-  Value X_processed, initial_h_processed;
-  ValueTensorType yTy_processed, Y_hType_processed;
-
-  if (layout == 0) {
-    X_processed = X;
-    initial_h_processed = initial_h_forward;
-    yTy_processed = yTy;
-    Y_hType_processed = Y_hType;
-  } else {
-    X_processed = b.create<AtenTransposeIntOp>(X.getType(), X, cstZero, cstOne);
-    initial_h_processed = b.create<AtenTransposeIntOp>(
-        initial_h.getType(), initial_h_forward, cstZero, cstOne);
-
-    auto yTySizes = yTy.getSizes();
-    auto Y_hTypeSizes = Y_hType.getSizes();
-
-    yTy_processed = b.getType<ValueTensorType>(
-        llvm::SmallVector<int64_t>{yTySizes[1], yTySizes[0], yTySizes[2],
-                                   yTySizes[3]},
-        yTy.getDtype());
-
-    Y_hType_processed = b.getType<ValueTensorType>(
-        llvm::SmallVector<int64_t>{Y_hTypeSizes[1], Y_hTypeSizes[0],
-                                   Y_hTypeSizes[2]},
-        Y_hType.getDtype());
+  if (layout == 1) {
+    X = StaticTranspose(b, X, 0, 1);
   }
 
   // Weights and biases ready. Calling GRU layer to insert the actual ops.
-  GruLayerOutput gruLayerOutput =
-      gru_layer(b, X_processed, initial_h_processed, weights, activations,
-                linear_before_reset);
+  GruLayerOutput gruLayerOutput = gru_layer(b, X, initial_h_forward, weights,
+                                            activations, linear_before_reset);
 
   // Process outputs based on layout
-  Value Y_final, Y_h_final;
-  if (layout == 0) {
-    Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
-    Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
+  Value Y_final;
+  if (binder.tensorResultTypeAtIndex(yTy, 0)) {
+    Y_final = cstNone;
   } else {
-    auto Y_transposed = b.create<AtenTransposeIntOp>(
-        gruLayerOutput.Y.getType(), gruLayerOutput.Y, cstZero, cstOne);
-    Y_final = b.create<AtenUnsqueezeOp>(yTy, Y_transposed, cstTwo);
+    if (layout == 0) {
+      Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
+    } else {
+      Type yTy_original = b.getType<ValueTensorType>(
+          llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
+          yTy.getDtype());
+      Y_final =
+          b.create<AtenUnsqueezeOp>(yTy_original, gruLayerOutput.Y, cstOne);
+      Y_final = StaticTranspose(b, Y_final, 1, 2);
+      Y_final = StaticTranspose(b, Y_final, 0, 1);
+    }
+  }
 
-    auto Y_h_transposed = b.create<AtenTransposeIntOp>(
-        gruLayerOutput.Y_h.getType(), gruLayerOutput.Y_h, cstZero, cstOne);
-    Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, Y_h_transposed, cstZero);
+  Value Y_h_final;
+  if (binder.tensorResultTypeAtIndex(Y_hType, 1)) {
+    Y_h_final = cstNone;
+  } else {
+    if (layout == 0) {
+      Y_h_final =
+          b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
+    } else {
+      Type y_hTy_original = b.getType<ValueTensorType>(
+          llvm::SmallVector<int64_t>{1, batch_size, hidden_size},
+          Y_hType.getDtype());
+      Y_h_final = b.create<AtenUnsqueezeOp>(y_hTy_original, gruLayerOutput.Y_h,
+                                            cstZero);
+      Y_h_final = StaticTranspose(b, Y_h_final, 0, 1);
+    }
   }
 
   rewriter.replaceOp(binder.op, mlir::ValueRange{Y_final, Y_h_final});
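
Note on the shape bookkeeping above, per the ONNX GRU spec: the default bias tensor B concatenates the six per-gate biases Wb[zrh] and Rb[zrh], so its trailing dimension is 6 * hidden_size rather than 2 * hidden_size. With layout == 1 the operands and results are batch-major: X is [batch_size, seq_length, input_size], Y is [batch_size, seq_length, num_directions, hidden_size], and Y_h is [batch_size, num_directions, hidden_size]; that is why the expander transposes X and initial_h into time-major form before calling gru_layer and transposes the results back afterwards. The ||-to-&& change makes the pattern fail only when both outputs are absent, since Y and Y_h are each individually optional.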
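
The patch leans on a StaticTranspose helper defined elsewhere in this file but not shown in the hunks. For reviewers, here is a minimal sketch of what such a helper could look like, assuming it wraps AtenTransposeIntOp and recomputes the statically known result type; the in-tree definition is authoritative:

// Sketch only: assumed shape of the StaticTranspose helper used above.
static Value StaticTranspose(ImplicitLocOpBuilder &b, Value value, int64_t dim0,
                             int64_t dim1) {
  auto valueTy = cast<ValueTensorType>(value.getType());

  // Swap the two dimensions in the statically known shape.
  SmallVector<int64_t> valueShape(valueTy.getSizes());
  std::swap(valueShape[dim0], valueShape[dim1]);
  valueTy = b.getType<ValueTensorType>(valueShape, valueTy.getDtype());

  // Materialize the dims as torch.constant.int and emit aten.transpose.int.
  auto intType = b.getType<IntType>();
  Value cstDim0 = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(dim0));
  Value cstDim1 = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(dim1));
  return b.create<AtenTransposeIntOp>(valueTy, value, cstDim0, cstDim1);
}

Under that assumption, the layout == 1 path for Y traces as: gru_layer output [seq_len, batch, hidden] -> unsqueeze at dim 1 -> [seq_len, 1, batch, hidden] -> transpose(1, 2) -> [seq_len, batch, 1, hidden] -> transpose(0, 1) -> [batch, seq_len, 1, hidden], which matches the spec's batch-major Y layout.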