onnx.LSTM - bidirectional, layout attr (#3771)

- Support bidirectional LSTM (reusing the forward LSTM layer with
flipped inputs and outputs)
- Support layout 1 (batch-first input/output layout)
- Support the default cases for the `clip` and `input_forget` attributes
- Support returning a partial set of outputs (any 1-3 of Y, Y_h, Y_c)
- Fixes for the alt_e2e_tests LSTM tests (1, 2, 3)
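For reference, the idea behind the bidirectional support in a minimal, self-contained C++ sketch (a hypothetical scalar recurrence, not the torch-mlir lowering itself): the reverse direction reuses the forward recurrence by flipping the input along the time axis and flipping the produced outputs back so they stay aligned with the original time steps.

#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

// One generic recurrence step: h_next = step(x_t, h_prev).
using Step = std::function<float(float, float)>;

// Run the forward recurrence over the whole sequence, collecting outputs.
std::vector<float> runForward(const std::vector<float> &x, float h0, const Step &step) {
  std::vector<float> ys;
  float h = h0;
  for (float xt : x) {
    h = step(xt, h);
    ys.push_back(h);
  }
  return ys;
}

// Reverse direction: flip the input in time, reuse the forward recurrence,
// then flip the outputs back so ys[t] again corresponds to time step t.
std::vector<float> runReverse(const std::vector<float> &x, float h0, const Step &step) {
  std::vector<float> flipped(x.rbegin(), x.rend());
  std::vector<float> ys = runForward(flipped, h0, step);
  std::reverse(ys.begin(), ys.end());
  return ys;
}

int main() {
  Step step = [](float xt, float h) { return 0.5f * h + xt; };
  std::vector<float> x = {1.0f, 2.0f, 3.0f};
  for (float y : runForward(x, 0.0f, step)) std::cout << y << ' ';
  std::cout << '\n';
  for (float y : runReverse(x, 0.0f, step)) std::cout << y << ' ';
  std::cout << '\n';
}

In the lowering below, the flip is an AtenFlipOp on the sequence dimension, and the forward and reverse results are concatenated along the num_directions dimension.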
Phaneesh Barwaria 2024-10-08 23:59:49 +05:30 committed by GitHub
parent 58489faf7f
commit 7830c00ca2
3 changed files with 332 additions and 69 deletions


@@ -34,6 +34,7 @@ struct OpBinder {
Location getLoc() { return op->getLoc(); }
int getNumOperands() { return op->getNumOperands(); }
int getNumResults() { return op->getNumResults(); }
// Operand matches of different arities.
ParseResult tensorOperand(Value &value0) {


@@ -661,8 +661,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
std::string direction;
ValueTensorType yTy, Y_hType, Y_cType;
if (binder.tensorResultTypeAtIndex(yTy, 0) &&
binder.tensorResultTypeAtIndex(Y_hType, 1) &&
binder.tensorResultTypeAtIndex(Y_cType, 2)) {
return rewriter.notifyMatchFailure(binder.op,
"At least one outputs must be present");
@@ -686,51 +686,110 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
auto xTy = cast<ValueTensorType>(X.getType());
auto wTy = cast<ValueTensorType>(W.getType());
// TODO: add defaults for the activation_alpha and activation_beta attributes
llvm::SmallVector<std::string> activationsList;
if (binder.stringArrayAttr(activationsList, "activations"))
return rewriter.notifyMatchFailure(
binder.op, "Missing required attribute; activations");
if (!binder.customOpNameStringAttr(direction, "direction", "forward") &&
direction != "forward" && direction != "bidirectional")
return rewriter.notifyMatchFailure(
binder.op, "activations must be empty have 3 elements, but " +
binder.op, "Unsupported direction attribute value. "
"Only 'forward' / 'bidrectional' are supported but '" +
direction + "' is provided.");
int64_t num_directions = 1 + (direction == "bidirectional");
bool isBidirectional = direction == "bidirectional";
// Bidirectional LSTMs carry backward activations as well:
// expect 6 activations (3 forward + 3 backward); if only 3 are given, the
// reverse direction keeps the defaults.
int64_t num_activations = activationsList.size();
if (num_activations != 0 && num_activations != 3 && num_activations != 6) {
return rewriter.notifyMatchFailure(
binder.op, "activations must either be empty (default), have 3 elements"
" (forward) or, have 6 elements (bidirectional), but " +
std::to_string(activationsList.size()) +
" are provided.");
}
// TODO: Add checks and defaults for the sequence_lens and P inputs, and full
// support for non-default clip and input_forget attributes.
Value B;
if (binder.tensorOperandAtIndex(B, 3)) {
Value none = b.create<ConstantNoneOp>();
Value cstHiddenx8 = b.create<ConstantIntOp>(
b.getType<IntType>(), b.getI64IntegerAttr(8 * hidden_size));
Value cstNumDir = b.create<ConstantIntOp>(
b.getType<IntType>(), b.getI64IntegerAttr(num_directions));
auto BType = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{num_directions, 8 * hidden_size},
cast<ValueTensorType>(W.getType()).getDtype());
Value zerosShapeList = b.create<PrimListConstructOp>(
b.getType<ListType>(b.getType<IntType>()),
SmallVector<Value>{cstNumDir, cstHiddenx8});
B = b.create<AtenZerosOp>(BType, zerosShapeList, none, none, none, none);
}
LstmActivations activations, activationsRev;
// Default case (both forward and reverse)
activations.f = "Sigmoid";
activations.g = "Tanh";
activations.h = "Tanh";
activationsRev.f = "Sigmoid";
activationsRev.g = "Tanh";
activationsRev.h = "Tanh";
// forward activations: taken from the first three entries (also used in the bidirectional case)
if (num_activations >= 3) {
activations.f = activationsList[0];
activations.g = activationsList[1];
activations.h = activationsList[2];
}
// reverse activations (bidirectional only): last three entries
if (num_activations == 6) {
activationsRev.f = activationsList[3];
activationsRev.g = activationsList[4];
activationsRev.h = activationsList[5];
}
float clip;
if (!binder.f32FloatAttr(clip, "clip", 0.0) && clip != 0.0)
return rewriter.notifyMatchFailure(binder.op,
"Unsupported direction attribute value. "
"Only 'forward' is supported but '" +
direction + "' is provided.");
int64_t num_directions = 1 + (direction == "bidirectional");
"clip attribute not supported");
int64_t input_forget;
if (!binder.s64IntegerAttr(input_forget, "input_forget", 0) &&
input_forget != 0)
return rewriter.notifyMatchFailure(
binder.op, "only input_forget = 0 supported. Got input_forgt = " +
std::to_string(input_forget));
int64_t layout;
if (!binder.s64IntegerAttr(layout, "layout", 0) && layout != 0 && layout != 1)
return rewriter.notifyMatchFailure(
binder.op, "invalid value of layout attribute, expecting 0 / 1 got " +
std::to_string(layout));
auto XShape = xTy.getSizes();
int64_t seq_len, batch_size;
if (layout == 0) {
seq_len = XShape[0];
batch_size = XShape[1];
} else {
seq_len = XShape[1];
batch_size = XShape[0];
}
int64_t input_size = XShape[2];
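On the input side the only difference between the two layouts is which dimension of X is time and which is batch (layout 0: [seq_length, batch_size, input_size]; layout 1: [batch_size, seq_length, input_size]). A trivial sketch of the same selection on a plain shape array:

#include <array>
#include <cstdint>
#include <utility>

// layout 0: X is [seq_length, batch_size, input_size]
// layout 1: X is [batch_size, seq_length, input_size]
std::pair<int64_t, int64_t> seqAndBatch(const std::array<int64_t, 3> &xShape,
                                        int64_t layout) {
  return layout == 0 ? std::make_pair(xShape[0], xShape[1])
                     : std::make_pair(xShape[1], xShape[0]);
}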
if (num_directions != wTy.getSizes()[0])
return rewriter.notifyMatchFailure(
binder.op, "num_directions (" + std::to_string(num_directions) +
") does not match the first dimension of wTy (" +
std::to_string(wTy.getSizes()[0]) + ")");
if (4 * hidden_size != wTy.getSizes()[1])
return rewriter.notifyMatchFailure(
binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) +
@@ -746,6 +805,13 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
Value R_forward = getDirection(b, 0, R);
Value B_forward = getDirection(b, 0, B);
Value W_reverse, R_reverse, B_reverse;
if (isBidirectional) {
W_reverse = getDirection(b, 1, W);
R_reverse = getDirection(b, 1, R);
B_reverse = getDirection(b, 1, B);
}
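getDirection pulls a single direction (index 0 = forward, 1 = reverse) out of the leading num_directions dimension of W, R and B, so the per-direction computation works on direction-free tensors. A rough standalone analogue of that select-and-squeeze on a flat, row-major buffer (an assumption about the helper's shape behaviour, not the helper itself):

#include <cstdint>
#include <vector>

// Select direction `d` from a tensor stored row-major as
// [num_directions, rows, cols], squeezing away the leading dimension.
std::vector<float> selectDirection(const std::vector<float> &t, int64_t d,
                                   int64_t rows, int64_t cols) {
  const int64_t sliceSize = rows * cols;
  return std::vector<float>(t.begin() + d * sliceSize,
                            t.begin() + (d + 1) * sliceSize);
}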
auto hTy = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
xTy.getDtype());
@@ -770,29 +836,44 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
Value initial_h;
if (binder.tensorOperandAtIndex(initial_h, 5)) {
// default created for layout 0
initial_h =
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
} else {
if (layout == 1)
initial_h = StaticTranspose(b, initial_h, 0, 1);
}
Value initial_c;
if (binder.tensorOperandAtIndex(initial_c, 6)) {
// default created for layout 0
initial_c =
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
} else {
if (layout == 1)
initial_c = StaticTranspose(b, initial_c, 0, 1);
}
// convert X from layout 1 to layout 0
if (layout == 1)
X = StaticTranspose(b, X, 0, 1);
// X, initial_h, initial_c are now in layout 0
Value initial_h_forward = getDirection(b, 0, initial_h);
Value initial_c_forward = getDirection(b, 0, initial_c);
Value initial_h_reverse, initial_c_reverse;
if (isBidirectional) {
initial_h_reverse = getDirection(b, 1, initial_h);
initial_c_reverse = getDirection(b, 1, initial_c);
}
// Everything from here on handles the forward direction (except inside the
// isBidirectional blocks), with the direction dimension squeezed out and all
// inputs in layout 0 format
LstmWeights weights, weightsRev; // weights and biases
auto intConst = [&](int64_t val) {
return b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(val));
@@ -804,6 +885,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
Value recurrentWeightsEndIdx = intConst(8 * hidden_size);
auto biasType = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{hidden_size * 4}, wTy.getDtype());
// forward
Value Wb = b.create<AtenSliceTensorOp>(biasType,
/*input=*/B_forward,
/*dim=*/cstZero,
@@ -816,6 +898,22 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
/*start=*/recurrentWeightsStartIdx,
/*end=*/recurrentWeightsEndIdx,
/*step=*/cstOne);
Value Wb_reverse, Rb_reverse;
if (isBidirectional) {
// reverse
Wb_reverse = b.create<AtenSliceTensorOp>(biasType,
/*input=*/B_reverse,
/*dim=*/cstZero,
/*start=*/cstZero,
/*end=*/inputWeightsEndIdx,
/*step=*/cstOne);
Rb_reverse = b.create<AtenSliceTensorOp>(biasType,
/*input=*/B_reverse,
/*dim=*/cstZero,
/*start=*/recurrentWeightsStartIdx,
/*end=*/recurrentWeightsEndIdx,
/*step=*/cstOne);
}
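This slicing relies on the ONNX bias layout: each direction of B is one vector of length 8*hidden_size, laid out as the input-weight biases Wb (4*hidden_size) followed by the recurrence biases Rb (4*hidden_size), which is also why the default B created earlier is a zero tensor of width 8*hidden_size. A small sketch of those offsets:

#include <cstdint>

// Offsets of Wb and Rb inside one direction of B (length 8 * hidden_size):
// Wb occupies [0, 4*hidden_size) and Rb occupies [4*hidden_size, 8*hidden_size).
struct BiasOffsets {
  int64_t wbStart, wbEnd, rbStart, rbEnd;
};

BiasOffsets biasOffsets(int64_t hiddenSize) {
  return {0, 4 * hiddenSize, 4 * hiddenSize, 8 * hiddenSize};
}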
// gate splitting
auto gateBiasType = b.getType<ValueTensorType>(
@@ -833,61 +931,164 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
Value forgetGateWeightsEndIdx = intConst(3 * hidden_size);
Value cellGateWeightsEndIdx = intConst(4 * hidden_size);
auto sliceIOFC = [&](std::function<Value(Value, Value, Value)> slicerFunction,
Value WoB) {
// slice into 4 components and return tuple
return std::make_tuple(
slicerFunction(cstZero, inputGateWeightsEndIdx, WoB),
slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx, WoB),
slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx, WoB),
slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx, WoB));
};
auto sliceGateBias = [&](Value startIdx, Value endIdx, Value WoB) {
return b.create<AtenSliceTensorOp>(gateBiasType, WoB, cstZero, startIdx,
endIdx, cstOne);
};
std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) =
sliceIOFC(sliceGateBias, Wb);
if (isBidirectional)
std::tie(weightsRev.Wb_i, weightsRev.Wb_o, weightsRev.Wb_f,
weightsRev.Wb_c) = sliceIOFC(sliceGateBias, Wb_reverse);
auto sliceGateBiasR = [&](Value startIdx, Value endIdx, Value WoB) {
return b.create<AtenSliceTensorOp>(gateBiasType, WoB, cstZero, startIdx,
endIdx, cstOne);
};
std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) =
sliceIOFC(sliceGateBiasR, Rb);
if (isBidirectional)
std::tie(weightsRev.Rb_i, weightsRev.Rb_o, weightsRev.Rb_f,
weightsRev.Rb_c) = sliceIOFC(sliceGateBiasR, Rb_reverse);
auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx, Value WoB) {
return b.create<AtenSliceTensorOp>(gateWeightsTypeIH, WoB, cstZero,
startIdx, endIdx, cstOne);
};
std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) =
sliceIOFC(sliceGateWeightsIH, W_forward);
if (isBidirectional)
std::tie(weightsRev.W_i, weightsRev.W_o, weightsRev.W_f, weightsRev.W_c) =
sliceIOFC(sliceGateWeightsIH, W_reverse);
auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx, Value WoB) {
return b.create<AtenSliceTensorOp>(gateWeightsTypeHH, WoB, cstZero,
startIdx, endIdx, cstOne);
};
std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) =
sliceIOFC(sliceGateWeightsHH, R_forward);
if (isBidirectional)
std::tie(weightsRev.R_i, weightsRev.R_o, weightsRev.R_f, weightsRev.R_c) =
sliceIOFC(sliceGateWeightsHH, R_reverse);
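sliceIOFC relies on the ONNX gate packing: each 4*hidden_size block of the weights and biases stores the gates in i, o, f, c order, so the four slices start at 0, hidden_size, 2*hidden_size and 3*hidden_size, and the same split is applied to W, R, Wb and Rb for the forward and (when bidirectional) reverse direction. A sketch of that split on a flat vector:

#include <array>
#include <cstdint>
#include <vector>

// Split one 4*hidden_size-wide block into its i, o, f, c gate pieces
// (ONNX packs LSTM gates as [i | o | f | c]).
std::array<std::vector<float>, 4> splitIOFC(const std::vector<float> &block,
                                            int64_t hiddenSize) {
  std::array<std::vector<float>, 4> gates;
  for (int64_t g = 0; g < 4; ++g)
    gates[g].assign(block.begin() + g * hiddenSize,
                    block.begin() + (g + 1) * hiddenSize);
  return gates;
}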
LstmLayerOutput lstmLayerOutput = lstm_layer(
b, X, initial_h_forward, initial_c_forward, weights, activations);
Value Y_h_result, Y_c_result, Y_result;
// For the forward (unidirectional) case: unsqueeze the direction dim and output
auto YallDtype =
cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype();
auto Y_h_Y_c_uni_type = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{1, batch_size, hidden_size}, YallDtype);
auto Y_uni_type = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
YallDtype);
auto Y_h_Y_c_res_type = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
YallDtype);
auto Y_res_type = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{seq_len, num_directions, batch_size,
hidden_size},
YallDtype);
Value Y_h_forward =
b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_h, cstZero);
Value Y_c_forward =
b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_c, cstZero);
// unsqueeze num_directions dim1 of Y
// to create the onnx.LSTM output shape [seq_length, num_directions,
// batch_size, hidden_size]
Value Y_forward =
b.create<AtenUnsqueezeOp>(Y_uni_type, lstmLayerOutput.Y, cstOne);
Y_result = Y_forward;
Y_h_result = Y_h_forward;
Y_c_result = Y_c_forward;
// add the bidirectional reverse layer
// this is just: flip X, run the forward lstm layer, flip the results, stack
// flip X
Value dim0, X_reverse, Y_h_reverse, Y_c_reverse, Y_reverse_unflipped,
Y_reverse, Y_output_list, Y_h_output_list, Y_c_output_list;
LstmLayerOutput revLstmLayerOutput;
if (isBidirectional) {
dim0 = b.create<PrimListConstructOp>(b.getType<ListType>(intType),
SmallVector<Value>{cstZero});
X_reverse = b.create<AtenFlipOp>(xTy, X, dim0); // flip along seq_len dim
revLstmLayerOutput =
lstm_layer(b, X_reverse, initial_h_reverse, initial_c_reverse,
weightsRev, activationsRev);
// unsqueeze Y_rev, Y_h_rev, Y_c_rev
Y_h_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
revLstmLayerOutput.Y_h, cstZero);
Y_c_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
revLstmLayerOutput.Y_c, cstZero);
Y_reverse_unflipped =
b.create<AtenUnsqueezeOp>(Y_uni_type, revLstmLayerOutput.Y, cstOne);
// flip Y_rev on dim 0 [seq_len]
Y_reverse = b.create<AtenFlipOp>(Y_uni_type, Y_reverse_unflipped, dim0);
// Concat forward and reverse results on dim 1
Y_output_list =
b.create<PrimListConstructOp>(b.getType<ListType>(Y_uni_type),
SmallVector<Value>{Y_forward, Y_reverse});
Y_result = b.create<AtenCatOp>(Y_res_type, Y_output_list, cstOne);
// Concat forward and reverse results on dim 0
Y_h_output_list = b.create<PrimListConstructOp>(
b.getType<ListType>(Y_h_Y_c_uni_type),
SmallVector<Value>{Y_h_forward, Y_h_reverse});
Y_h_result =
b.create<AtenCatOp>(Y_h_Y_c_res_type, Y_h_output_list, cstZero);
Y_c_output_list = b.create<PrimListConstructOp>(
b.getType<ListType>(Y_h_Y_c_uni_type),
SmallVector<Value>{Y_c_forward, Y_c_reverse});
Y_c_result =
b.create<AtenCatOp>(Y_h_Y_c_res_type, Y_c_output_list, cstZero);
}
if (layout == 1) {
// Update Y, Y_h, Y_c results to layout 1
Y_result = StaticTranspose(b, Y_result, 1, 2);
Y_result = StaticTranspose(b, Y_result, 0, 1);
Y_h_result = StaticTranspose(b, Y_h_result, 0, 1);
Y_c_result = StaticTranspose(b, Y_c_result, 0, 1);
}
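Because the recurrence itself always runs in layout 0, the layout 1 outputs are produced purely by transposing: Y goes from [seq_length, num_directions, batch_size, hidden_size] to [batch_size, seq_length, num_directions, hidden_size] via the two transposes above, and Y_h / Y_c go from [num_directions, batch_size, hidden_size] to [batch_size, num_directions, hidden_size]. A shape-only sketch of those permutations:

#include <array>
#include <cstdint>
#include <utility>

// Shape effect of the two transposes applied to Y when layout == 1:
// [seq, dir, batch, hid] -> swap dims 1,2 -> [seq, batch, dir, hid]
//                        -> swap dims 0,1 -> [batch, seq, dir, hid]
std::array<int64_t, 4> yShapeToLayout1(std::array<int64_t, 4> y) {
  std::swap(y[1], y[2]);
  std::swap(y[0], y[1]);
  return y;
}

// Shape effect for Y_h and Y_c: [dir, batch, hid] -> [batch, dir, hid].
std::array<int64_t, 3> stateShapeToLayout1(std::array<int64_t, 3> s) {
  std::swap(s[0], s[1]);
  return s;
}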
// Only emit the outputs that the onnx.LSTM node actually declares
SmallVector<Value> actualOutputs = {Y_result, Y_h_result, Y_c_result},
outputs;
ValueTensorType resTy;
for (int i = 0; i < binder.getNumResults(); ++i) {
if (!binder.tensorResultTypeAtIndex(resTy, i) && !resTy) {
outputs.push_back(cstNone);
} else {
outputs.push_back(actualOutputs[i]);
}
}
rewriter.replaceOp(binder.op, outputs);
return success();
}


@@ -16,10 +16,71 @@
// CHECK-DAG: torch.prim.Loop.condition
// CHECK-DAG: }
// CHECK: }
func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>)
return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
}
// -----
// CHECK-LABEL: func.func @test_lstm_bidirectional_with_initial_bias(
// CHECK-SAME: %[[X:.*]]: !torch.vtensor<[32,32,192],f32>,
// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[2,192,192],f32>,
// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[2,192,48],f32>,
// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[2,384],f32>)
// CHECK: %[[FORWARD_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP_FWD:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y_FWD:.*]], %[[INITIAL_H_FWD:.*]], %[[INITIAL_C_FWD:.*]]) {
// CHECK: ^bb0(%[[FORWARD_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_FWD:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>):
// CHECK-DAG: torch.aten.select.int
// CHECK-DAG: torch.aten.linear
// CHECK-DAG: torch.aten.sigmoid
// CHECK-DAG: torch.aten.tanh
// CHECK-DAG: torch.prim.Loop.condition
// CHECK: }
// CHECK: torch.aten.flip
// CHECK: %[[REVERSE_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIPS_REV:.*]], %[[LOOP_COND_REV:.*]], init(%[[Y_REV:.*]], %[[INITIAL_H_REV:.*]], %[[INITIAL_C_REV:.*]]) {
// CHECK: ^bb0(%[[REVERSE_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_REV:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>):
// CHECK-DAG: torch.aten.select.int
// CHECK-DAG: torch.aten.linear
// CHECK-DAG: torch.aten.sigmoid
// CHECK-DAG: torch.aten.tanh
// CHECK-DAG: torch.prim.Loop.condition
// CHECK: }
// CHECK: torch.aten.flip
// CHECK: return %[[Y:.*]], %[[Y_H:.*]], %[[Y_C:.*]] : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>
// CHECK: }
func.func @test_lstm_bidirectional_with_initial_bias(%arg0: !torch.vtensor<[32,32,192],f32>, %arg1: !torch.vtensor<[2,192,192],f32>, %arg2: !torch.vtensor<[2,192,48],f32>, %arg3: !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.direction = "bidirectional", torch.onnx.hidden_size = 48 : si64, torch.onnx.layout = 0 : si64} : (!torch.vtensor<[32,32,192],f32>, !torch.vtensor<[2,192,192],f32>, !torch.vtensor<[2,192,48],f32>, !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>)
return %0#0, %0#1, %0#2 : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>
}
// -----
// CHECK-LABEL: func.func @test_lstm_batchwise_two_outputs(
// CHECK-SAME: %[[X_LAYOUT_1:.*]]: !torch.vtensor<[3,1,2],f32>,
// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[1,28,2],f32>,
// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[1,28,7],f32>)
// CHECK: torch.aten.transpose.int
// CHECK: %[[LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y:.*]], %[[INITIAL_H:.*]], %[[INITIAL_C:.*]]) {
// CHECK: ^bb0(%[[LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV:.*]]: !torch.vtensor<[1,3,7],f32>, %[[H_PREV:.*]]: !torch.vtensor<[3,7],f32>, %[[C_PREV:.*]]: !torch.vtensor<[3,7],f32>):
// CHECK-DAG: torch.aten.select.int
// CHECK-DAG: torch.aten.linear
// CHECK-DAG: torch.aten.sigmoid
// CHECK-DAG: torch.aten.tanh
// CHECK-DAG: torch.prim.Loop.condition
// CHECK: }
// CHECK-DAG: torch.aten.transpose.int
// CHECK-DAG: torch.aten.transpose.int
// CHECK-DAG: torch.aten.transpose.int
// CHECK-DAG: torch.aten.transpose.int
// CHECK: return %[[Y:.*]], %[[Y_H:.*]] : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>
// CHECK: }
func.func @test_lstm_batchwise_two_outputs(%arg0: !torch.vtensor<[3,1,2],f32>, %arg1: !torch.vtensor<[1,28,2],f32>, %arg2: !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0:2 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2) {torch.onnx.hidden_size = 7 : si64, torch.onnx.layout = 1 : si64} : (!torch.vtensor<[3,1,2],f32>, !torch.vtensor<[1,28,2],f32>, !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>)
return %0#0, %0#1 : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>
}