// Mirror of https://github.com/llvm/torch-mlir (MLIR, 26 lines, 2.0 KiB).
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s

// Expected lowering of a basic forward onnx.LSTM: a torch.prim.Loop that
// carries three values (typed [15,2,3], [2,3] and [2,3] in its block
// arguments) and whose body contains select.int, linear, sigmoid, tanh and
// the loop condition. The three results must then reach the function's
// return with the onnx.LSTM result types.

// CHECK-LABEL: func.func @test_lstm_basic(
// CHECK-SAME:    %[[X:.*]]: !torch.vtensor<[15,2,4],f32>,
// CHECK-SAME:    %[[W:.*]]: !torch.vtensor<[1,12,4],f32>,
// CHECK-SAME:    %[[R:.*]]: !torch.vtensor<[1,12,3],f32>,
// CHECK-SAME:    %[[B:.*]]: !torch.vtensor<[1,24],f32>)
// CHECK:         %[[LOOP_RESULT:.*]]:3 = torch.prim.Loop %[[MAX_TRIPS:.*]], %[[ENTER_LOOP:.*]], init(%[[Y:.*]], %[[INITIAL_H:.*]], %[[INITIAL_C:.*]]) {
// CHECK:         ^bb0(%[[LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV:.*]]: !torch.vtensor<[15,2,3],f32>, %[[H_PREV:.*]]: !torch.vtensor<[2,3],f32>, %[[C_PREV:.*]]: !torch.vtensor<[2,3],f32>):
// CHECK-DAG:       torch.aten.select.int
// CHECK-DAG:       torch.aten.linear
// CHECK-DAG:       torch.aten.sigmoid
// CHECK-DAG:       torch.aten.tanh
// CHECK-DAG:       torch.prim.Loop.condition
// CHECK-DAG:     }
// CHECK:         return %{{.*}} : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
// CHECK:       }
// One onnx.LSTM with no `direction` attribute (so the ONNX default "forward"
// applies) and no `layout` attribute (ONNX default layout = 0). Per the ONNX
// LSTM specification the operand shapes therefore read as:
//   X: [seq_length = 15, batch_size = 2, input_size = 4]
//   W: [num_directions = 1, 4 * hidden_size = 12, input_size = 4]
//   R: [num_directions = 1, 4 * hidden_size = 12, hidden_size = 3]
//   B: [num_directions = 1, 8 * hidden_size = 24]
// and the results are Y: [15, 1, 2, 3], Y_h: [1, 2, 3], Y_c: [1, 2, 3].
// The optional sequence_lens / initial_h / initial_c / P inputs are omitted.
module {
  func.func @test_lstm_basic(%X: !torch.vtensor<[15,2,4],f32>, %W: !torch.vtensor<[1,12,4],f32>, %R: !torch.vtensor<[1,12,3],f32>, %B: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %lstm:3 = torch.operator "onnx.LSTM"(%X, %W, %R, %B) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>)
    return %lstm#0, %lstm#1, %lstm#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
  }
}