Fixed GRU quality issues exposed by e2e tests (#3753)
knwng authored Oct 2, 2024
1 parent f8e4a9a commit f0b7ca7
Showing 1 changed file with 42 additions and 42 deletions.
84 changes: 42 additions & 42 deletions lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp
@@ -1072,11 +1072,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
Value cstNone = b.create<ConstantNoneOp>();
Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
- Value cstTwo = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2));

// Binding arguments
ValueTensorType yTy, Y_hType;
- if (binder.tensorResultTypeAtIndex(yTy, 0) ||
+ if (binder.tensorResultTypeAtIndex(yTy, 0) &&
binder.tensorResultTypeAtIndex(Y_hType, 1)) {
return rewriter.notifyMatchFailure(binder.op,
"At least one output must be present");
@@ -1132,6 +1131,7 @@ LogicalResult OnnxGruExpander(OpBinder binder,
// Validations
auto XShape = xTy.getSizes();
int64_t batch_size = (layout == 0) ? XShape[1] : XShape[0];
+ int64_t seq_len = (layout == 0) ? XShape[0] : XShape[1];
int64_t input_size = XShape[2];

std::ostringstream oss;
@@ -1173,6 +1173,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
initial_h =
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
+ } else {
+ if (layout == 1) {
+ initial_h = StaticTranspose(b, initial_h, 0, 1);
+ }
}

if (binder.tensorOperandAtIndex(sequence_lens, 4))
@@ -1192,10 +1196,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
// fill in B
Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
if (B == nullptr) {
- SmallVector<int64_t> BShape = {num_directions, 2 * hidden_size};
+ SmallVector<int64_t> BShape = {num_directions, 6 * hidden_size};
SmallVector<Value> BShapeListContents = {
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions)),
- b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2 * hidden_size))};
+ b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(6 * hidden_size))};
Value BShapeList = b.create<PrimListConstructOp>(
b.getType<ListType>(intType), BShapeListContents);
auto BType = b.getType<ValueTensorType>(BShape, wTy.getDtype());
@@ -1256,51 +1260,47 @@ LogicalResult OnnxGruExpander(OpBinder binder,
B_slices[4], B_slices[5]);

// Process inputs based on layout
- Value X_processed, initial_h_processed;
- ValueTensorType yTy_processed, Y_hType_processed;
-
- if (layout == 0) {
- X_processed = X;
- initial_h_processed = initial_h_forward;
- yTy_processed = yTy;
- Y_hType_processed = Y_hType;
- } else {
- X_processed = b.create<AtenTransposeIntOp>(X.getType(), X, cstZero, cstOne);
- initial_h_processed = b.create<AtenTransposeIntOp>(
- initial_h.getType(), initial_h_forward, cstZero, cstOne);
-
- auto yTySizes = yTy.getSizes();
- auto Y_hTypeSizes = Y_hType.getSizes();
-
- yTy_processed = b.getType<ValueTensorType>(
- llvm::SmallVector<int64_t>{yTySizes[1], yTySizes[0], yTySizes[2],
- yTySizes[3]},
- yTy.getDtype());
-
- Y_hType_processed = b.getType<ValueTensorType>(
- llvm::SmallVector<int64_t>{Y_hTypeSizes[1], Y_hTypeSizes[0],
- Y_hTypeSizes[2]},
- Y_hType.getDtype());
+ if (layout == 1) {
+ X = StaticTranspose(b, X, 0, 1);
}

// Weights and biases ready. Calling GRU layer to insert the actual ops.
- GruLayerOutput gruLayerOutput =
- gru_layer(b, X_processed, initial_h_processed, weights, activations,
- linear_before_reset);
+ GruLayerOutput gruLayerOutput = gru_layer(b, X, initial_h_forward, weights,
+ activations, linear_before_reset);

// Process outputs based on layout
- Value Y_final, Y_h_final;
- if (layout == 0) {
- Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
- Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
+ Value Y_final;
+ if (binder.tensorResultTypeAtIndex(yTy, 0)) {
+ Y_final = cstNone;
} else {
- auto Y_transposed = b.create<AtenTransposeIntOp>(
- gruLayerOutput.Y.getType(), gruLayerOutput.Y, cstZero, cstOne);
- Y_final = b.create<AtenUnsqueezeOp>(yTy, Y_transposed, cstTwo);
+ if (layout == 0) {
+ Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
+ } else {
+ Type yTy_original = b.getType<ValueTensorType>(
+ llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
+ yTy.getDtype());
+ Y_final =
+ b.create<AtenUnsqueezeOp>(yTy_original, gruLayerOutput.Y, cstOne);
+ Y_final = StaticTranspose(b, Y_final, 1, 2);
+ Y_final = StaticTranspose(b, Y_final, 0, 1);
+ }
+ }

- auto Y_h_transposed = b.create<AtenTransposeIntOp>(
- gruLayerOutput.Y_h.getType(), gruLayerOutput.Y_h, cstZero, cstOne);
- Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, Y_h_transposed, cstZero);
+ Value Y_h_final;
+ if (binder.tensorResultTypeAtIndex(Y_hType, 1)) {
+ Y_h_final = cstNone;
+ } else {
+ if (layout == 0) {
+ Y_h_final =
+ b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
+ } else {
+ Type y_hTy_original = b.getType<ValueTensorType>(
+ llvm::SmallVector<int64_t>{1, batch_size, hidden_size},
+ Y_hType.getDtype());
+ Y_h_final = b.create<AtenUnsqueezeOp>(y_hTy_original, gruLayerOutput.Y_h,
+ cstZero);
+ Y_h_final = StaticTranspose(b, Y_h_final, 0, 1);
+ }
}

rewriter.replaceOp(binder.op, mlir::ValueRange{Y_final, Y_h_final});
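
For reference, the new layout == 1 handling reduces to a fixed unsqueeze/transpose sequence, plus the [num_directions, 6 * hidden_size] default bias (Wb[zrh] concatenated with Rb[zrh]). The standalone sketch below is not part of the patch: it re-derives the expected ONNX GRU output shapes at the shape level, using hypothetical unsqueeze/transpose helpers in place of AtenUnsqueezeOp and StaticTranspose, and assuming a single forward direction with gru_layer yielding Y as [seq_len, batch, hidden] and Y_h as [batch, hidden], as the unsqueezes above imply.

// Standalone shape-level sketch; helper names here are hypothetical.
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// Insert a size-1 dimension at `dim` (stand-in for AtenUnsqueezeOp).
static Shape unsqueeze(Shape s, size_t dim) {
  s.insert(s.begin() + dim, 1);
  return s;
}

// Swap two dimensions (stand-in for StaticTranspose).
static Shape transpose(Shape s, size_t d0, size_t d1) {
  std::swap(s[d0], s[d1]);
  return s;
}

int main() {
  const int64_t seq_len = 7, batch_size = 3, hidden_size = 4;

  // Default bias per direction packs Wb[zrh] and Rb[zrh]: 6 * hidden_size.
  const Shape B = {1, 6 * hidden_size};
  assert(B[1] == 6 * hidden_size);

  // Assumed gru_layer results: Y = [seq_len, batch, hidden], Y_h = [batch, hidden].
  const Shape Y = {seq_len, batch_size, hidden_size};
  const Shape Y_h = {batch_size, hidden_size};

  // layout == 1 path: unsqueeze, then transpose back into ONNX order.
  const Shape Y_final = transpose(transpose(unsqueeze(Y, 1), 1, 2), 0, 1);
  const Shape Y_h_final = transpose(unsqueeze(Y_h, 0), 0, 1);

  // ONNX GRU with layout == 1 expects:
  //   Y   : [batch_size, seq_length, num_directions, hidden_size]
  //   Y_h : [batch_size, num_directions, hidden_size]
  assert(Y_final == Shape({batch_size, seq_len, 1, hidden_size}));
  assert(Y_h_final == Shape({batch_size, 1, hidden_size}));
  return 0;
}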