From 7830c00ca2fc110a534f23b55faf435baf03a2bc Mon Sep 17 00:00:00 2001 From: Phaneesh Barwaria Date: Tue, 8 Oct 2024 23:59:49 +0530 Subject: [PATCH] onnx.LSTM - bidirectional, layout attr (#3771) - Support Bidirectional LSTM (utilising the forward LSTM layer with flipped Inputs and Outputs) - Support layout 1 - Support default cases for attr `clip` and `input_forget` - Support returning partial outputs (1-3) - fixes for alt_e2e_tests lstm tests (1,2,3) --- .../Conversion/TorchOnnxToTorch/Patterns.h | 1 + .../OnnxRecurrentLayerOpExpanders.cpp | 327 ++++++++++++++---- .../Conversion/TorchOnnxToTorch/ops/lstm.mlir | 73 +++- 3 files changed, 332 insertions(+), 69 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index f71deaff2..431d014ad 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -34,6 +34,7 @@ struct OpBinder { Location getLoc() { return op->getLoc(); } int getNumOperands() { return op->getNumOperands(); } + int getNumResults() { return op->getNumResults(); } // Operand matches of different arities. ParseResult tensorOperand(Value &value0) { diff --git a/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp b/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp index e7ab690e0..317a5459e 100644 --- a/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp +++ b/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp @@ -661,8 +661,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder, std::string direction; ValueTensorType yTy, Y_hType, Y_cType; - if (binder.tensorResultTypeAtIndex(yTy, 0) || - binder.tensorResultTypeAtIndex(Y_hType, 1) || + if (binder.tensorResultTypeAtIndex(yTy, 0) && + binder.tensorResultTypeAtIndex(Y_hType, 1) && binder.tensorResultTypeAtIndex(Y_cType, 2)) { return rewriter.notifyMatchFailure(binder.op, "At least one outputs must be present"); @@ -686,51 +686,110 @@ LogicalResult OnnxLstmExpander(OpBinder binder, auto xTy = cast(X.getType()); auto wTy = cast(W.getType()); - Value B; - if (binder.tensorOperandAtIndex(B, 3)) { - B = b.create(W.getType(), W); - } + + // TODO: add defaults for activation_alpha acticvation_beta attributes llvm::SmallVector activationsList; if (binder.stringArrayAttr(activationsList, "activations")) return rewriter.notifyMatchFailure( binder.op, "Missing required attribute; activations"); - LstmActivations activations; - activations.f = "Sigmoid"; - activations.g = "Tanh"; - activations.h = "Tanh"; - if (activationsList.size() == 3) { - activations.f = activationsList[0]; - activations.g = activationsList[1]; - activations.h = activationsList[2]; - } else if (activationsList.size() != 0) { + if (!binder.customOpNameStringAttr(direction, "direction", "forward") && + direction != "forward" && direction != "bidirectional") return rewriter.notifyMatchFailure( - binder.op, "activations must be empty have 3 elements, but " + + binder.op, "Unsupported direction attribute value. " + "Only 'forward' / 'bidrectional' are supported but '" + + direction + "' is provided."); + int64_t num_directions = 1 + (direction == "bidirectional"); + bool isBidirectional = direction == "bidirectional"; + // There can be backward activations too + // if backward -> look for 6 atcivations (what happens when only three?) + + int64_t num_activations = activationsList.size(); + if (num_activations != 0 && num_activations != 3 && num_activations != 6) { + return rewriter.notifyMatchFailure( + binder.op, "activations must either be empty (default), have 3 elements" + " (forward) or, have 6 elements (bidirectional), but " + std::to_string(activationsList.size()) + " are provided."); } + // TODO : Add checks, defaults and fails for inputs - sequence_lens, P and + // attrs- clip, input_forget, layout - if (!binder.customOpNameStringAttr(direction, "direction", "forward") && - direction != "forward") + Value B; + if (binder.tensorOperandAtIndex(B, 3)) { + Value none = b.create(); + Value cstHiddenx8 = b.create( + b.getType(), b.getI64IntegerAttr(8 * hidden_size)); + Value cstNumDir = b.create( + b.getType(), b.getI64IntegerAttr(num_directions)); + auto BType = b.getType( + llvm::SmallVector{num_directions, 8 * hidden_size}, + cast(W.getType()).getDtype()); + Value zerosShapeList = b.create( + b.getType(b.getType()), + SmallVector{cstNumDir, cstHiddenx8}); + B = b.create(BType, zerosShapeList, none, none, none, none); + } + + LstmActivations activations, activationsRev; + // Default case (both forward and reverse) + activations.f = "Sigmoid"; + activations.g = "Tanh"; + activations.h = "Tanh"; + activationsRev.f = "Sigmoid"; + activationsRev.g = "Tanh"; + activationsRev.h = "Tanh"; + + // forward only (also to be added for bidirectional case) + if (num_activations >= 3) { + activations.f = activationsList[0]; + activations.g = activationsList[1]; + activations.h = activationsList[2]; + } + + // bidirectional + if (num_activations == 6) { + activationsRev.f = activationsList[3]; + activationsRev.g = activationsList[4]; + activationsRev.h = activationsList[5]; + } + + float clip; + if (!binder.f32FloatAttr(clip, "clip", 0.0) && clip != 0.0) return rewriter.notifyMatchFailure(binder.op, - "Unsupported direction attribute value. " - "Only 'forward' is supported but '" + - direction + "' is provided."); - int64_t num_directions = 1 + (direction == "bidirectional"); + "clip attribute not supported"); + + int64_t input_forget; + if (!binder.s64IntegerAttr(input_forget, "input_forget", 0) && + input_forget != 0) + return rewriter.notifyMatchFailure( + binder.op, "only input_forget = 0 supported. Got input_forgt = " + + std::to_string(input_forget)); + + int64_t layout; + if (!binder.s64IntegerAttr(layout, "layout", 0) && layout != 0 && layout != 1) + return rewriter.notifyMatchFailure( + binder.op, "invalid value of layout attribute, expecting 0 / 1 got " + + std::to_string(layout)); auto XShape = xTy.getSizes(); - int64_t batch_size = XShape[1]; + int64_t seq_len, batch_size; + if (layout == 0) { + seq_len = XShape[0]; + batch_size = XShape[1]; + } else { + seq_len = XShape[1]; + batch_size = XShape[0]; + } + int64_t input_size = XShape[2]; if (num_directions != wTy.getSizes()[0]) return rewriter.notifyMatchFailure( binder.op, "num_directions (" + std::to_string(num_directions) + ") does not match the first dimension of wTy (" + std::to_string(wTy.getSizes()[0]) + ")"); - if (num_directions != 1) - return rewriter.notifyMatchFailure( - binder.op, "num_directions (" + std::to_string(num_directions) + - ") is not equal to 1"); + if (4 * hidden_size != wTy.getSizes()[1]) return rewriter.notifyMatchFailure( binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) + @@ -746,6 +805,13 @@ LogicalResult OnnxLstmExpander(OpBinder binder, Value R_forward = getDirection(b, 0, R); Value B_forward = getDirection(b, 0, B); + Value W_reverse, R_reverse, B_reverse; + if (isBidirectional) { + W_reverse = getDirection(b, 1, W); + R_reverse = getDirection(b, 1, R); + B_reverse = getDirection(b, 1, B); + } + auto hTy = b.getType( llvm::SmallVector{num_directions, batch_size, hidden_size}, xTy.getDtype()); @@ -770,29 +836,44 @@ LogicalResult OnnxLstmExpander(OpBinder binder, Value initial_h; if (binder.tensorOperandAtIndex(initial_h, 5)) { + // default created for layout 0 initial_h = b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); + } else { + if (layout == 1) + initial_h = StaticTranspose(b, initial_h, 0, 1); } + Value initial_c; if (binder.tensorOperandAtIndex(initial_c, 6)) { + // default created for layout 0 initial_c = b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); + } else { + if (layout == 1) + initial_c = StaticTranspose(b, initial_c, 0, 1); } + // convert X from layout 1 to layout 0 + if (layout == 1) + X = StaticTranspose(b, X, 0, 1); + + // X, initial_h, initial_c are now in layout 0 + Value initial_h_forward = getDirection(b, 0, initial_h); Value initial_c_forward = getDirection(b, 0, initial_c); - if (num_directions != 1) { - return rewriter.notifyMatchFailure( - binder.op, "Unsupported num_directions. Only 1 is supported but " + - std::to_string(num_directions) + " is provided."); - // TODO: support bidirectional LSTM by doing both directions and replacing - // Unsqueeze with Stack + Value initial_h_reverse, initial_c_reverse; + if (isBidirectional) { + initial_h_reverse = getDirection(b, 1, initial_h); + initial_c_reverse = getDirection(b, 1, initial_c); } - // Everything hereon is for the forward direction, with the direction - // dimention squeezed out. - LstmWeights weights; // weights and biases + // Everything hereon is for the forward direction (unless in bidirectional if + // block), with the direction dimention squeezed out and all inputs in layout + // 0 format + + LstmWeights weights, weightsRev; // weights and biases auto intConst = [&](int64_t val) { return b.create(intType, b.getI64IntegerAttr(val)); @@ -804,6 +885,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder, Value recurrentWeightsEndIdx = intConst(8 * hidden_size); auto biasType = b.getType( llvm::SmallVector{hidden_size * 4}, wTy.getDtype()); + // forward Value Wb = b.create(biasType, /*input=*/B_forward, /*dim=*/cstZero, @@ -816,6 +898,22 @@ LogicalResult OnnxLstmExpander(OpBinder binder, /*start=*/recurrentWeightsStartIdx, /*end=*/recurrentWeightsEndIdx, /*step=*/cstOne); + Value Wb_reverse, Rb_reverse; + if (isBidirectional) { + // reverse + Wb_reverse = b.create(biasType, + /*input=*/B_reverse, + /*dim=*/cstZero, + /*start=*/cstZero, + /*end=*/inputWeightsEndIdx, + /*step=*/cstOne); + Rb_reverse = b.create(biasType, + /*input=*/B_reverse, + /*dim=*/cstZero, + /*start=*/recurrentWeightsStartIdx, + /*end=*/recurrentWeightsEndIdx, + /*step=*/cstOne); + } // gate splitting auto gateBiasType = b.getType( @@ -833,61 +931,164 @@ LogicalResult OnnxLstmExpander(OpBinder binder, Value forgetGateWeightsEndIdx = intConst(3 * hidden_size); Value cellGateWeightsEndIdx = intConst(4 * hidden_size); - auto sliceIOFC = [&](std::function slicerFunction) { + auto sliceIOFC = [&](std::function slicerFunction, + Value WoB) { // slice into 4 components and return tuple return std::make_tuple( - slicerFunction(cstZero, inputGateWeightsEndIdx), - slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx), - slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx), - slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx)); + slicerFunction(cstZero, inputGateWeightsEndIdx, WoB), + slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx, WoB), + slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx, WoB), + slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx, WoB)); }; - auto sliceGateBias = [&](Value startIdx, Value endIdx) { - return b.create(gateBiasType, Wb, cstZero, startIdx, + auto sliceGateBias = [&](Value startIdx, Value endIdx, Value WoB) { + return b.create(gateBiasType, WoB, cstZero, startIdx, endIdx, cstOne); }; std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) = - sliceIOFC(sliceGateBias); + sliceIOFC(sliceGateBias, Wb); - auto sliceGateBiasR = [&](Value startIdx, Value endIdx) { - return b.create(gateBiasType, Rb, cstZero, startIdx, + if (isBidirectional) + std::tie(weightsRev.Wb_i, weightsRev.Wb_o, weightsRev.Wb_f, + weightsRev.Wb_c) = sliceIOFC(sliceGateBias, Wb_reverse); + + auto sliceGateBiasR = [&](Value startIdx, Value endIdx, Value WoB) { + return b.create(gateBiasType, WoB, cstZero, startIdx, endIdx, cstOne); }; std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) = - sliceIOFC(sliceGateBiasR); + sliceIOFC(sliceGateBiasR, Rb); - auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx) { - return b.create(gateWeightsTypeIH, W_forward, cstZero, + if (isBidirectional) + std::tie(weightsRev.Rb_i, weightsRev.Rb_o, weightsRev.Rb_f, + weightsRev.Rb_c) = sliceIOFC(sliceGateBiasR, Rb_reverse); + + auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx, Value WoB) { + return b.create(gateWeightsTypeIH, WoB, cstZero, startIdx, endIdx, cstOne); }; std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) = - sliceIOFC(sliceGateWeightsIH); + sliceIOFC(sliceGateWeightsIH, W_forward); - auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx) { - return b.create(gateWeightsTypeHH, R_forward, cstZero, + if (isBidirectional) + std::tie(weightsRev.W_i, weightsRev.W_o, weightsRev.W_f, weightsRev.W_c) = + sliceIOFC(sliceGateWeightsIH, W_reverse); + + auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx, Value WoB) { + return b.create(gateWeightsTypeHH, WoB, cstZero, startIdx, endIdx, cstOne); }; + std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) = - sliceIOFC(sliceGateWeightsHH); + sliceIOFC(sliceGateWeightsHH, R_forward); + + if (isBidirectional) + std::tie(weightsRev.R_i, weightsRev.R_o, weightsRev.R_f, weightsRev.R_c) = + sliceIOFC(sliceGateWeightsHH, R_reverse); + LstmLayerOutput lstmLayerOutput = lstm_layer( b, X, initial_h_forward, initial_c_forward, weights, activations); - auto Y_h_Y_c_unsqueezed_type = b.getType( + Value Y_h_result, Y_c_result, Y_result; + + // if forward (unidirectional) unsqueeze and output + auto YallDtype = + cast(lstmLayerOutput.Y_h.getType()).getDtype(); + auto Y_h_Y_c_uni_type = b.getType( + llvm::SmallVector{1, batch_size, hidden_size}, YallDtype); + auto Y_uni_type = b.getType( + llvm::SmallVector{seq_len, 1, batch_size, hidden_size}, + YallDtype); + auto Y_h_Y_c_res_type = b.getType( llvm::SmallVector{num_directions, batch_size, hidden_size}, - cast(lstmLayerOutput.Y_h.getType()).getDtype()); - Value Y_h_unsqueezed = b.create( - Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_h, cstZero); - Value Y_c_unsqueezed = b.create( - Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_c, cstZero); + YallDtype); + auto Y_res_type = b.getType( + llvm::SmallVector{seq_len, num_directions, batch_size, + hidden_size}, + YallDtype); + + Value Y_h_forward = + b.create(Y_h_Y_c_uni_type, lstmLayerOutput.Y_h, cstZero); + + Value Y_c_forward = + b.create(Y_h_Y_c_uni_type, lstmLayerOutput.Y_c, cstZero); // unsqueeze num_directions dim1 of Y // to create the onnx.LSTM output shape [seq_length, num_directions, // batch_size, hidden_size] - Value Y_unsqueezed = - b.create(yTy, lstmLayerOutput.Y, cstOne); + Value Y_forward = + b.create(Y_uni_type, lstmLayerOutput.Y, cstOne); - rewriter.replaceOp(binder.op, mlir::ValueRange{Y_unsqueezed, Y_h_unsqueezed, - Y_c_unsqueezed}); + Y_result = Y_forward; + Y_h_result = Y_h_forward; + Y_c_result = Y_c_forward; + + // add bidrectional reverse layer + // this is just flip X, lstm layer, flip results, stack + // flip X + Value dim0, X_reverse, Y_h_reverse, Y_c_reverse, Y_reverse_unflipped, + Y_reverse, Y_output_list, Y_h_output_list, Y_c_output_list; + LstmLayerOutput revLstmLayerOutput; + if (isBidirectional) { + dim0 = b.create(b.getType(intType), + SmallVector{cstZero}); + X_reverse = b.create(xTy, X, dim0); // flip along seq_len dim + revLstmLayerOutput = + lstm_layer(b, X_reverse, initial_h_reverse, initial_c_reverse, + weightsRev, activationsRev); + + // unsqueeze Y_rev, Y_h_rev, Y_c_rev + Y_h_reverse = b.create(Y_h_Y_c_uni_type, + revLstmLayerOutput.Y_h, cstZero); + Y_c_reverse = b.create(Y_h_Y_c_uni_type, + revLstmLayerOutput.Y_c, cstZero); + Y_reverse_unflipped = + b.create(Y_uni_type, revLstmLayerOutput.Y, cstOne); + + // flip Y_rev on dim 0 [seq_len] + Y_reverse = b.create(Y_uni_type, Y_reverse_unflipped, dim0); + + // Concat forward and reverse results on dim 1 + Y_output_list = + b.create(b.getType(Y_uni_type), + SmallVector{Y_forward, Y_reverse}); + Y_result = b.create(Y_res_type, Y_output_list, cstOne); + + // Concat forward and reverse results on dim 0 + Y_h_output_list = b.create( + b.getType(Y_h_Y_c_uni_type), + SmallVector{Y_h_forward, Y_h_reverse}); + Y_h_result = + b.create(Y_h_Y_c_res_type, Y_h_output_list, cstZero); + + Y_c_output_list = b.create( + b.getType(Y_h_Y_c_uni_type), + SmallVector{Y_c_forward, Y_c_reverse}); + Y_c_result = + b.create(Y_h_Y_c_res_type, Y_c_output_list, cstZero); + } + + if (layout == 1) { + // Update Y, Y_h, Y_c results to layout 1 + Y_result = StaticTranspose(b, Y_result, 1, 2); + Y_result = StaticTranspose(b, Y_result, 0, 1); + Y_h_result = StaticTranspose(b, Y_h_result, 0, 1); + Y_c_result = StaticTranspose(b, Y_c_result, 0, 1); + } + + // Only add outputs specified in onnx output node + SmallVector actualOutputs = {Y_result, Y_h_result, Y_c_result}, + outputs; + ValueTensorType resTy; + for (int i = 0; i < binder.getNumResults(); ++i) { + if (!binder.tensorResultTypeAtIndex(resTy, i) && !resTy) { + outputs.push_back(cstNone); + } else { + outputs.push_back(actualOutputs[i]); + } + } + + rewriter.replaceOp(binder.op, outputs); return success(); } diff --git a/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir b/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir index bb1821088..1d230e79e 100644 --- a/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir +++ b/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir @@ -16,10 +16,71 @@ // CHECK-DAG: torch.prim.Loop.condition // CHECK-DAG: } // CHECK: } -module { - func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - %none = torch.constant.none - %0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) - return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32> - } + +func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + %0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_lstm_bidirectional_with_initial_bias( +// CHECK-SAME: %[[X:.*]]: !torch.vtensor<[32,32,192],f32>, +// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[2,192,192],f32>, +// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[2,192,48],f32>, +// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[2,384],f32>) +// CHECK: %[[FORWARD_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP_FWD:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y_FWD:.*]], %[[INITIAL_H_FWD:.*]], %[[INITIAL_C_FWD:.*]]) { +// CHECK: ^bb0(%[[FORWARD_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_FWD:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>): +// CHECK-DAG: torch.aten.select.int +// CHECK-DAG: torch.aten.linear +// CHECK-DAG: torch.aten.sigmoid +// CHECK-DAG: torch.aten.tanh +// CHECK-DAG: torch.prim.Loop.condition +// CHECK: } +// CHECK: torch.aten.flip +// CHECK: %[[REVERSE_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIPS_REV:.*]], %[[LOOP_COND_REV:.*]], init(%[[Y_REV:.*]], %[[INITIAL_H_REV:.*]], %[[INITIAL_C_REV:.*]]) { +// CHECK: ^bb0(%[[REVERSE_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_REV:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>): +// CHECK-DAG: torch.aten.select.int +// CHECK-DAG: torch.aten.linear +// CHECK-DAG: torch.aten.sigmoid +// CHECK-DAG: torch.aten.tanh +// CHECK-DAG: torch.prim.Loop.condition +// CHECK: } +// CHECK: torch.aten.flip +// CHECK: return %[[Y:.*]], %[[Y_H:.*]], %[[Y_C:.*]] : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32> +// CHECK: } + +func.func @test_lstm_bidirectional_with_initial_bias(%arg0: !torch.vtensor<[32,32,192],f32>, %arg1: !torch.vtensor<[2,192,192],f32>, %arg2: !torch.vtensor<[2,192,48],f32>, %arg3: !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + %0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.direction = "bidirectional", torch.onnx.hidden_size = 48 : si64, torch.onnx.layout = 0 : si64} : (!torch.vtensor<[32,32,192],f32>, !torch.vtensor<[2,192,192],f32>, !torch.vtensor<[2,192,48],f32>, !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_lstm_batchwise_two_outputs( +// CHECK-SAME: %[[X_LAYOUT_1:.*]]: !torch.vtensor<[3,1,2],f32>, +// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[1,28,2],f32>, +// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[1,28,7],f32>) +// CHECK: torch.aten.transpose.int +// CHECK: %[[LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y:.*]], %[[INITIAL_H:.*]], %[[INITIAL_C:.*]]) { +// CHECK: ^bb0(%[[LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV:.*]]: !torch.vtensor<[1,3,7],f32>, %[[H_PREV:.*]]: !torch.vtensor<[3,7],f32>, %[[C_PREV:.*]]: !torch.vtensor<[3,7],f32>): +// CHECK-DAG: torch.aten.select.int +// CHECK-DAG: torch.aten.linear +// CHECK-DAG: torch.aten.sigmoid +// CHECK-DAG: torch.aten.tanh +// CHECK-DAG: torch.prim.Loop.condition +// CHECK: } +// CHECK-DAG: torch.aten.transpose.int +// CHECK-DAG: torch.aten.transpose.int +// CHECK-DAG: torch.aten.transpose.int +// CHECK-DAG: torch.aten.transpose.int +// CHECK: return %[[Y:.*]], %[[Y_H:.*]] : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32> +// CHECK: } + +func.func @test_lstm_batchwise_two_outputs(%arg0: !torch.vtensor<[3,1,2],f32>, %arg1: !torch.vtensor<[1,28,2],f32>, %arg2: !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + %0:2 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2) {torch.onnx.hidden_size = 7 : si64, torch.onnx.layout = 1 : si64} : (!torch.vtensor<[3,1,2],f32>, !torch.vtensor<[1,28,2],f32>, !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>) + return %0#0, %0#1 : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32> }