Fixed GRU quality issues exposed by e2e tests (#3753)

Issue: https://github.com/nod-ai/SHARK-ModelDev/issues/856

Related tests:
![Screenshot 2024-10-01
175305](https://github.com/user-attachments/assets/0dc0901b-058f-427c-a596-9e806fd38836)
pull/3760/head
Kyle Wang 2024-10-02 14:00:19 -07:00 committed by GitHub
parent f8e4a9a3c2
commit f0b7ca72f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 42 additions and 42 deletions

View File

@ -1072,11 +1072,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
Value cstNone = b.create<ConstantNoneOp>();
Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
Value cstTwo = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2));
// Binding arguments
ValueTensorType yTy, Y_hType;
if (binder.tensorResultTypeAtIndex(yTy, 0) ||
if (binder.tensorResultTypeAtIndex(yTy, 0) &&
binder.tensorResultTypeAtIndex(Y_hType, 1)) {
return rewriter.notifyMatchFailure(binder.op,
"At least one output must be present");
@ -1132,6 +1131,7 @@ LogicalResult OnnxGruExpander(OpBinder binder,
// Validations
auto XShape = xTy.getSizes();
int64_t batch_size = (layout == 0) ? XShape[1] : XShape[0];
int64_t seq_len = (layout == 0) ? XShape[0] : XShape[1];
int64_t input_size = XShape[2];
std::ostringstream oss;
@ -1173,6 +1173,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
initial_h =
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
} else {
if (layout == 1) {
initial_h = StaticTranspose(b, initial_h, 0, 1);
}
}
if (binder.tensorOperandAtIndex(sequence_lens, 4))
@ -1192,10 +1196,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
// fill in B
Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
if (B == nullptr) {
SmallVector<int64_t> BShape = {num_directions, 2 * hidden_size};
SmallVector<int64_t> BShape = {num_directions, 6 * hidden_size};
SmallVector<Value> BShapeListContents = {
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions)),
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2 * hidden_size))};
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(6 * hidden_size))};
Value BShapeList = b.create<PrimListConstructOp>(
b.getType<ListType>(intType), BShapeListContents);
auto BType = b.getType<ValueTensorType>(BShape, wTy.getDtype());
@ -1256,51 +1260,47 @@ LogicalResult OnnxGruExpander(OpBinder binder,
B_slices[4], B_slices[5]);
// Process inputs based on layout
Value X_processed, initial_h_processed;
ValueTensorType yTy_processed, Y_hType_processed;
if (layout == 0) {
X_processed = X;
initial_h_processed = initial_h_forward;
yTy_processed = yTy;
Y_hType_processed = Y_hType;
} else {
X_processed = b.create<AtenTransposeIntOp>(X.getType(), X, cstZero, cstOne);
initial_h_processed = b.create<AtenTransposeIntOp>(
initial_h.getType(), initial_h_forward, cstZero, cstOne);
auto yTySizes = yTy.getSizes();
auto Y_hTypeSizes = Y_hType.getSizes();
yTy_processed = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{yTySizes[1], yTySizes[0], yTySizes[2],
yTySizes[3]},
yTy.getDtype());
Y_hType_processed = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{Y_hTypeSizes[1], Y_hTypeSizes[0],
Y_hTypeSizes[2]},
Y_hType.getDtype());
if (layout == 1) {
X = StaticTranspose(b, X, 0, 1);
}
// Weights and biases ready. Calling GRU layer to insert the actual ops.
GruLayerOutput gruLayerOutput =
gru_layer(b, X_processed, initial_h_processed, weights, activations,
linear_before_reset);
GruLayerOutput gruLayerOutput = gru_layer(b, X, initial_h_forward, weights,
activations, linear_before_reset);
// Process outputs based on layout
Value Y_final, Y_h_final;
if (layout == 0) {
Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
Value Y_final;
if (binder.tensorResultTypeAtIndex(yTy, 0)) {
Y_final = cstNone;
} else {
auto Y_transposed = b.create<AtenTransposeIntOp>(
gruLayerOutput.Y.getType(), gruLayerOutput.Y, cstZero, cstOne);
Y_final = b.create<AtenUnsqueezeOp>(yTy, Y_transposed, cstTwo);
if (layout == 0) {
Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
} else {
Type yTy_original = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
yTy.getDtype());
Y_final =
b.create<AtenUnsqueezeOp>(yTy_original, gruLayerOutput.Y, cstOne);
Y_final = StaticTranspose(b, Y_final, 1, 2);
Y_final = StaticTranspose(b, Y_final, 0, 1);
}
}
auto Y_h_transposed = b.create<AtenTransposeIntOp>(
gruLayerOutput.Y_h.getType(), gruLayerOutput.Y_h, cstZero, cstOne);
Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, Y_h_transposed, cstZero);
Value Y_h_final;
if (binder.tensorResultTypeAtIndex(Y_hType, 1)) {
Y_h_final = cstNone;
} else {
if (layout == 0) {
Y_h_final =
b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
} else {
Type y_hTy_original = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{1, batch_size, hidden_size},
Y_hType.getDtype());
Y_h_final = b.create<AtenUnsqueezeOp>(y_hTy_original, gruLayerOutput.Y_h,
cstZero);
Y_h_final = StaticTranspose(b, Y_h_final, 0, 1);
}
}
rewriter.replaceOp(binder.op, mlir::ValueRange{Y_final, Y_h_final});