mirror of https://github.com/llvm/torch-mlir
Fixed GRU quality issues exposed by e2e tests (#3753)
Issue: https://github.com/nod-ai/SHARK-ModelDev/issues/856
Related tests: see the attached screenshot of the passing e2e results (Screenshot 2024-10-01 175305, https://github.com/user-attachments/assets/0dc0901b-058f-427c-a596-9e806fd38836)
parent f8e4a9a3c2
commit f0b7ca72f5
@@ -1072,11 +1072,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
   Value cstNone = b.create<ConstantNoneOp>();
   Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
   Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
-  Value cstTwo = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2));

   // Binding arguments
   ValueTensorType yTy, Y_hType;
-  if (binder.tensorResultTypeAtIndex(yTy, 0) ||
+  if (binder.tensorResultTypeAtIndex(yTy, 0) &&
       binder.tensorResultTypeAtIndex(Y_hType, 1)) {
     return rewriter.notifyMatchFailure(binder.op,
                                        "At least one output must be present");
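ONNX GRU declares both Y and Y_h as optional results, so the expander should only report a match failure when neither result type can be bound; the previous `||` rejected any model that requested just one of the two outputs. A minimal standalone sketch of the corrected condition (plain C++, not from the patch; hasY/hasYh are hypothetical stand-ins for the binder calls succeeding):

    #include <cassert>

    // The lowering should fail only when neither output is requested.
    static bool shouldFail(bool hasY, bool hasYh) { return !hasY && !hasYh; }

    int main() {
      assert(!shouldFail(true, true));   // both Y and Y_h requested: proceed
      assert(!shouldFail(true, false));  // only Y requested: proceed
      assert(!shouldFail(false, true));  // only Y_h requested: proceed
      assert(shouldFail(false, false));  // no outputs requested: match failure
      return 0;
    }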
@@ -1132,6 +1131,7 @@ LogicalResult OnnxGruExpander(OpBinder binder,
   // Validations
   auto XShape = xTy.getSizes();
   int64_t batch_size = (layout == 0) ? XShape[1] : XShape[0];
+  int64_t seq_len = (layout == 0) ? XShape[0] : XShape[1];
   int64_t input_size = XShape[2];

   std::ostringstream oss;
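The layout attribute decides which axis of X carries the sequence: with layout == 0, X is [seq_length, batch_size, input_size]; with layout == 1 it is [batch_size, seq_length, input_size], so seq_len and batch_size have to be read from XShape conditionally. A standalone sketch of that selection (plain C++, not from the patch):

    #include <array>
    #include <cstdint>
    #include <iostream>

    // layout == 0: X is [seq_length, batch_size, input_size]
    // layout == 1: X is [batch_size, seq_length, input_size]
    struct GruDims {
      int64_t seq_len, batch_size, input_size;
    };

    static GruDims dimsFromXShape(const std::array<int64_t, 3> &xShape,
                                  int64_t layout) {
      GruDims d;
      d.seq_len = (layout == 0) ? xShape[0] : xShape[1];
      d.batch_size = (layout == 0) ? xShape[1] : xShape[0];
      d.input_size = xShape[2];
      return d;
    }

    int main() {
      GruDims a = dimsFromXShape({7, 2, 16}, /*layout=*/0);
      GruDims b = dimsFromXShape({2, 7, 16}, /*layout=*/1);
      // Both describe the same logical tensor: seq_len 7, batch_size 2.
      std::cout << a.seq_len << "x" << a.batch_size << " "
                << b.seq_len << "x" << b.batch_size << "\n";
      return 0;
    }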
@@ -1173,6 +1173,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
     Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
     initial_h =
         b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
+  } else {
+    if (layout == 1) {
+      initial_h = StaticTranspose(b, initial_h, 0, 1);
+    }
   }

   if (binder.tensorOperandAtIndex(sequence_lens, 4))
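When initial_h is supplied and layout == 1, it arrives as [batch_size, num_directions, hidden_size] and has to be brought into the [num_directions, batch_size, hidden_size] form the rest of the expander assumes, which is what the new StaticTranspose(b, initial_h, 0, 1) call does. A standalone sketch of that axis swap on a flat row-major buffer (plain C++, not from the patch):

    #include <cstdint>
    #include <vector>

    // Swap the first two axes of a [d0, d1, d2] tensor stored row-major,
    // e.g. [batch, num_directions, hidden] -> [num_directions, batch, hidden].
    static std::vector<float> swapFirstTwoAxes(const std::vector<float> &in,
                                               int64_t d0, int64_t d1,
                                               int64_t d2) {
      std::vector<float> out(in.size());
      for (int64_t i = 0; i < d0; ++i)
        for (int64_t j = 0; j < d1; ++j)
          for (int64_t k = 0; k < d2; ++k)
            out[(j * d0 + i) * d2 + k] = in[(i * d1 + j) * d2 + k];
      return out;
    }

    int main() {
      std::vector<float> h = {0, 1, 2, 3, 4, 5};            // [2, 1, 3]
      std::vector<float> t = swapFirstTwoAxes(h, 2, 1, 3);  // -> [1, 2, 3]
      return t.size() == h.size() ? 0 : 1;
    }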
@@ -1192,10 +1196,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
   // fill in B
   Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
   if (B == nullptr) {
-    SmallVector<int64_t> BShape = {num_directions, 2 * hidden_size};
+    SmallVector<int64_t> BShape = {num_directions, 6 * hidden_size};
     SmallVector<Value> BShapeListContents = {
         b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions)),
-        b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2 * hidden_size))};
+        b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(6 * hidden_size))};
     Value BShapeList = b.create<PrimListConstructOp>(
         b.getType<ListType>(intType), BShapeListContents);
     auto BType = b.getType<ValueTensorType>(BShape, wTy.getDtype());
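Per the ONNX GRU specification, B packs the input and recurrence biases for all three gates per direction, [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h], so a synthesized all-zero B must be 6 * hidden_size wide rather than 2 * hidden_size. A standalone sketch of that layout and its slicing (plain C++, not from the patch):

    #include <cstdint>
    #include <vector>

    // ONNX GRU bias layout per direction:
    // [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h], six chunks of hidden_size each.
    struct GruBiases {
      std::vector<float> Wbz, Wbr, Wbh, Rbz, Rbr, Rbh;
    };

    static GruBiases splitB(const std::vector<float> &B, int64_t hidden_size) {
      auto chunk = [&](int64_t i) {
        return std::vector<float>(B.begin() + i * hidden_size,
                                  B.begin() + (i + 1) * hidden_size);
      };
      return {chunk(0), chunk(1), chunk(2), chunk(3), chunk(4), chunk(5)};
    }

    int main() {
      const int64_t hidden_size = 4;
      // Default when the B operand is absent: zeros of length 6 * hidden_size.
      std::vector<float> B(6 * hidden_size, 0.0f);
      GruBiases biases = splitB(B, hidden_size);
      return biases.Rbh.size() == static_cast<size_t>(hidden_size) ? 0 : 1;
    }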
@@ -1256,51 +1260,47 @@ LogicalResult OnnxGruExpander(OpBinder binder,
                                      B_slices[4], B_slices[5]);

   // Process inputs based on layout
-  Value X_processed, initial_h_processed;
-  ValueTensorType yTy_processed, Y_hType_processed;
-
-  if (layout == 0) {
-    X_processed = X;
-    initial_h_processed = initial_h_forward;
-    yTy_processed = yTy;
-    Y_hType_processed = Y_hType;
-  } else {
-    X_processed = b.create<AtenTransposeIntOp>(X.getType(), X, cstZero, cstOne);
-    initial_h_processed = b.create<AtenTransposeIntOp>(
-        initial_h.getType(), initial_h_forward, cstZero, cstOne);
-
-    auto yTySizes = yTy.getSizes();
-    auto Y_hTypeSizes = Y_hType.getSizes();
-
-    yTy_processed = b.getType<ValueTensorType>(
-        llvm::SmallVector<int64_t>{yTySizes[1], yTySizes[0], yTySizes[2],
-                                   yTySizes[3]},
-        yTy.getDtype());
-
-    Y_hType_processed = b.getType<ValueTensorType>(
-        llvm::SmallVector<int64_t>{Y_hTypeSizes[1], Y_hTypeSizes[0],
-                                   Y_hTypeSizes[2]},
-        Y_hType.getDtype());
+  if (layout == 1) {
+    X = StaticTranspose(b, X, 0, 1);
   }

   // Weights and biases ready. Calling GRU layer to insert the actual ops.
-  GruLayerOutput gruLayerOutput =
-      gru_layer(b, X_processed, initial_h_processed, weights, activations,
-                linear_before_reset);
+  GruLayerOutput gruLayerOutput = gru_layer(b, X, initial_h_forward, weights,
+                                            activations, linear_before_reset);

   // Process outputs based on layout
-  Value Y_final, Y_h_final;
-  if (layout == 0) {
-    Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
-    Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
+  Value Y_final;
+  if (binder.tensorResultTypeAtIndex(yTy, 0)) {
+    Y_final = cstNone;
   } else {
-    auto Y_transposed = b.create<AtenTransposeIntOp>(
-        gruLayerOutput.Y.getType(), gruLayerOutput.Y, cstZero, cstOne);
-    Y_final = b.create<AtenUnsqueezeOp>(yTy, Y_transposed, cstTwo);
+    if (layout == 0) {
+      Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
+    } else {
+      Type yTy_original = b.getType<ValueTensorType>(
+          llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
+          yTy.getDtype());
+      Y_final =
+          b.create<AtenUnsqueezeOp>(yTy_original, gruLayerOutput.Y, cstOne);
+      Y_final = StaticTranspose(b, Y_final, 1, 2);
+      Y_final = StaticTranspose(b, Y_final, 0, 1);
+    }
+  }

-    auto Y_h_transposed = b.create<AtenTransposeIntOp>(
-        gruLayerOutput.Y_h.getType(), gruLayerOutput.Y_h, cstZero, cstOne);
-    Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, Y_h_transposed, cstZero);
+  Value Y_h_final;
+  if (binder.tensorResultTypeAtIndex(Y_hType, 1)) {
+    Y_h_final = cstNone;
+  } else {
+    if (layout == 0) {
+      Y_h_final =
+          b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
+    } else {
+      Type y_hTy_original = b.getType<ValueTensorType>(
+          llvm::SmallVector<int64_t>{1, batch_size, hidden_size},
+          Y_hType.getDtype());
+      Y_h_final = b.create<AtenUnsqueezeOp>(y_hTy_original, gruLayerOutput.Y_h,
+                                            cstZero);
+      Y_h_final = StaticTranspose(b, Y_h_final, 0, 1);
+    }
   }

   rewriter.replaceOp(binder.op, mlir::ValueRange{Y_final, Y_h_final});
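For this single-direction expander, gru_layer produces Y as [seq_length, batch_size, hidden_size] and Y_h as [batch_size, hidden_size]; the final unsqueeze/StaticTranspose steps only reshape them into the result shapes ONNX promises for each layout, and each output is materialized only when its result type is actually bound (otherwise cstNone is returned in its place). A hypothetical helper computing those expected result shapes (plain C++, not from the patch):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Expected ONNX GRU result shapes with num_directions == 1.
    // layout == 0: Y = [seq_length, 1, batch_size, hidden_size],
    //              Y_h = [1, batch_size, hidden_size]
    // layout == 1: Y = [batch_size, seq_length, 1, hidden_size],
    //              Y_h = [batch_size, 1, hidden_size]
    struct GruResultShapes {
      std::vector<int64_t> Y, Y_h;
    };

    static GruResultShapes expectedShapes(int64_t seq_len, int64_t batch_size,
                                          int64_t hidden_size, int64_t layout) {
      if (layout == 0)
        return {{seq_len, 1, batch_size, hidden_size},
                {1, batch_size, hidden_size}};
      return {{batch_size, seq_len, 1, hidden_size},
              {batch_size, 1, hidden_size}};
    }

    int main() {
      GruResultShapes s = expectedShapes(7, 2, 3, /*layout=*/1);
      for (int64_t d : s.Y)
        std::cout << d << " "; // prints: 2 7 1 3
      std::cout << "\n";
      return 0;
    }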