Update llvm-project to b44b3494f60296db6aca38a14cab061d9b747a0a (#2511)

The main purpose is to bring in the new mesh dialect change.
https://github.com/llvm/llvm-project/pull/68007
pull/2517/head snapshot-20231017.994
Chi_Liu 2023-10-16 19:29:48 -07:00 committed by GitHub
parent f2c53b8ca5
commit 14a4da923b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 91 deletions

@ -1 +1 @@
Subproject commit d13da154a7c7eff77df8686b2de1cfdfa7cc7029
Subproject commit b44b3494f60296db6aca38a14cab061d9b747a0a

View File

@ -15,6 +15,7 @@ add_mlir_library(TorchMLIRRefBackend
MLIRIR
MLIRTransforms
MLIRMathTransforms
MLIRLinalgTransforms
)
mlir_check_all_link_libraries(TorchMLIRRefBackend)

View File

@ -41,16 +41,16 @@ func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
// -----
// CHECK-LABEL: func.func @torch.aten.leaky_relu$basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e-01
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e-01> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_0]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_0]], %[[VAL_2]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_4]], %[[VAL_0]], %[[VAL_5]] : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e-01
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e-01> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_5]], %[[VAL_1]], %[[VAL_6]] : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func.func @torch.aten.leaky_relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%fp0 = torch.constant.float 1.000000e-01
@ -155,13 +155,13 @@ func.func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
// -----
// CHECK-LABEL: func.func @torch.aten.add$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
@ -175,13 +175,13 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
// -----
// CHECK-LABEL: func.func @torch.aten.sub$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
@ -195,13 +195,14 @@ func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
// -----
// CHECK-LABEL: func.func @torch.aten.mul$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.mul %[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
@ -210,14 +211,15 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
// -----
// CHECK-LABEL: func.func @torch.aten.div$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RCP:.*]] = tosa.reciprocal %[[ARG1_BUILTIN]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.mul %[[ARG0_BUILTIN]], %[[RCP]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func.func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
@ -394,13 +396,13 @@ func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>)
// -----
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
// CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<6.432100e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
@ -415,13 +417,13 @@ func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to
// -----
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
@ -502,7 +504,7 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to
// -----
// CHECK-LABEL: func.func @torch.aten.native_batch_norm$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,4,3],f32> -> tensor<10x4x3xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>}> : () -> tensor<4xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>}> : () -> tensor<4xf32>
@ -518,8 +520,8 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to
// CHECK: %[[VAL_13:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32>
// CHECK: %[[VAL_14:.*]] = tosa.add %[[VAL_9]], %[[VAL_12]] : (tensor<4x1xf32>, tensor<f32>) -> tensor<4x1xf32>
// CHECK: %[[VAL_15:.*]] = tosa.rsqrt %[[VAL_14]] : (tensor<4x1xf32>) -> tensor<4x1xf32>
// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i32} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32>
// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_10]] {shift = 0 : i32} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32>
// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32>
// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_10]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32>
// CHECK: %[[VAL_18:.*]] = tosa.add %[[VAL_17]], %[[VAL_11]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32>
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32>
// CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32>
@ -538,7 +540,7 @@ func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32
// -----
// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 4
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
@ -579,9 +581,9 @@ func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3
// -----
// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> {
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,2,3],f32> -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32>
@ -595,22 +597,22 @@ func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3
// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32>
// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 5, 1, 1, 1>} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_11]] {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_11]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_17:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_17]], %[[VAL_17]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_17]], %[[VAL_17]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_19:.*]] = tosa.reduce_sum %[[VAL_18]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
// CHECK: %[[VAL_20:.*]] = tosa.reduce_sum %[[VAL_19]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32>
// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array<i64: 5, 1, 1, 1>} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_22]], %[[VAL_11]] {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_22]], %[[VAL_11]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1, 2, 2, 3>} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
// CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 2, 2, 3>} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
// CHECK: %[[VAL_26:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_27:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_28:.*]] = tosa.add %[[VAL_23]], %[[VAL_26]] : (tensor<5x1x1x1xf32>, tensor<f32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_29:.*]] = tosa.rsqrt %[[VAL_28]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_30:.*]] = tosa.mul %[[VAL_27]], %[[VAL_29]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_31:.*]] = tosa.mul %[[VAL_30]], %[[VAL_24]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_30:.*]] = tosa.mul %[[VAL_27]], %[[VAL_29]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_31:.*]] = tosa.mul %[[VAL_30]], %[[VAL_24]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_32:.*]] = tosa.add %[[VAL_31]], %[[VAL_25]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32>
// CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32>
@ -683,12 +685,12 @@ func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>
// -----
// CHECK-LABEL: func.func @torch.aten.log2$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.693147182> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<1x1xf32>) -> tensor<1x1xf32>
// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], %[[VAL_3]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], %[[VAL_3]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
// CHECK: }
@ -996,7 +998,7 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vten
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 12, 1>} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array<i64: 8, 3>} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32>
// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32>
// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32>
// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32>
// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 1, 8>} : (tensor<8x1xi32>) -> tensor<1x8xi32>
// CHECK: %[[VAL_17:.*]] = tosa.gather %[[VAL_11]], %[[VAL_16]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32>
@ -1019,7 +1021,7 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor<2x2xi32>, tensor<i32>) -> tensor<2x2xi32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<2x2xi32>, tensor<i32>) -> tensor<2x2xi32>
// CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<2x2xi32>) -> tensor<2x2xi64>
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64>
@ -1039,7 +1041,7 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torc
// CHECK: %[[VAL_3:.*]] = torch.constant.int 256
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor<i32>}> : () -> tensor<i32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32>
// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<1x1x128x128xi32>, tensor<i32>) -> tensor<1x1x128x128xi32>
// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64>
@ -1138,17 +1140,17 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to
// -----
// CHECK-LABEL: func.func @torch.aten.remainder.Scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5:.*]] : (tensor<f32>) -> tensor<f32>
// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3:.*]], %[[VAL_6:.*]] {shift = 0 : i32} : (tensor<2x4xf32>, tensor<f32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_5]], %[[VAL_8]] {shift = 0 : i32} : (tensor<f32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_10:.*]] = tosa.sub %[[VAL_3]], %[[VAL_9]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[2,4],f32>
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<f32>) -> tensor<f32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<f32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3]], %[[VAL_6]] {shift = 0 : i8} : (tensor<f32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_1]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32>
// CHECK: }
func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
%int2 = torch.constant.int 2
@ -1159,23 +1161,23 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to
// -----
// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[5,5],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[ATOL:.*]] = torch.constant.float 1.000000e-08
// CHECK: %[[RTOL:.*]] = torch.constant.float 1.000000e-05
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[VAL_2:.*]] = tosa.sub %[[VAL_0]], %[[VAL_1]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_3:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_4:.*]] = tosa.abs %[[VAL_1]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : i32} : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_8]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1>
// CHECK: return %[[VAL_10]] : !torch.vtensor<[5,5],i1>
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.float 1.000000e-08
// CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05
// CHECK: %[[VAL_6:.*]] = torch.constant.bool false
// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_8:.*]] = tosa.abs %[[VAL_7]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_9:.*]] = tosa.abs %[[VAL_3]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_9]] {shift = 0 : i8} : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_13:.*]] = tosa.add %[[VAL_12]], %[[VAL_11]] : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_14:.*]] = tosa.greater_equal %[[VAL_13]], %[[VAL_8]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1>
// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1>
// CHECK: return %[[VAL_15]] : !torch.vtensor<[5,5],i1>
// CHECK: }
func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
%float1.000000e-08 = torch.constant.float 1.000000e-08

View File

@ -1,9 +1,11 @@
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' -split-input-file -verify-diagnostics %s | FileCheck %s
// CHECK-LABEL: torch.aten.mul.Scalar$mixed_type
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16>
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16>
// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i32} : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16>
// CHECK-LABEL: func.func @torch.aten.mul.Scalar$mixed_type(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16>) -> tensor<5xbf16> {
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16>
// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16>
// CHECK: return %[[VAL_2]] : tensor<5xbf16>
// CHECK: }
func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],bf16> {
%float2.000000e00 = torch.constant.float 2.000000e+00
%0 = torch.aten.mul.Scalar %arg0, %float2.000000e00 : !torch.vtensor<[5],bf16>, !torch.float -> !torch.vtensor<[5],bf16>
@ -88,12 +90,14 @@ func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],
// -----
// CHECK-LABEL: torch.aten.div.Tensor$mixed_type_fp
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<?x?xi32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_3]] {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-LABEL: func.func @torch.aten.div.Tensor$mixed_type_fp(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>) -> tensor<?x?xf32> {
// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<?x?xi32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_3]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: return %[[VAL_4]] : tensor<?x?xf32>
// CHECK: }
func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>