// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @elementwise$unary( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { // CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor // CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor) -> tensor // CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor to tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[],f32> // CHECK: } func.func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { %0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> return %0 : !torch.vtensor<[],f32> } // ----- // CHECK-LABEL: func.func @elementwise$binary( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-DAG: %[[BUILTIN_ARG0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK-DAG: %[[BUILTIN_ARG1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],f32> -> tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[BUILTIN_ARG0]], %[[C0]] : tensor // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[ARG0_DIM1:.*]] = tensor.dim %[[BUILTIN_ARG0]], %[[C1]] : tensor // CHECK: %[[C0_2:.*]] = arith.constant 0 : index // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[BUILTIN_ARG1]], %[[C0_2]] : tensor // CHECK: %[[LEGAL_SIZES:.*]] = arith.cmpi eq, %[[ARG0_DIM1]], %[[ARG1_DIM0]] : index // CHECK: assert %[[LEGAL_SIZES]], "mismatched size for broadcast" // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty(%[[ARG0_DIM0]], %[[ARG0_DIM1]]) : tensor // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%[[BUILTIN_ARG0]], %[[BUILTIN_ARG1]] : tensor, tensor) outs(%[[INIT_TENSOR]] : tensor) { // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32): // CHECK: %[[MUL:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32 // CHECK: linalg.yield %[[MUL]] : f32 // CHECK: } -> tensor // CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor to tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32> } // ----- // CHECK-LABEL: func.func @elementwise$ternary( // CHECK: linalg.generic {indexing_maps = [ // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>, // CHECK-SAME: affine_map<(d0, d1, d2) -> (d1, d2)>, // CHECK-SAME: affine_map<(d0, d1, d2) -> (d2)>, // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>] func.func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?],f32> { %0 = torch.aten.lerp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?,?],f32> return %0 : !torch.vtensor<[?,?,?],f32> } // ----- // CHECK-LABEL: func.func @elementwise$with_scalar_capture( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[BUILTIN_C1:.*]] = arith.constant 1 : i64 // CHECK: linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>] // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32): // CHECK: %[[ALPHA:.*]] = arith.sitofp %[[BUILTIN_C1]] : i64 to f32 // CHECK: %[[SCALED:.*]] = arith.mulf %[[RHS]], %[[ALPHA]] : f32 // CHECK: %[[RES:.*]] = arith.addf %[[LHS]], %[[SCALED]] : f32 // CHECK: linalg.yield %[[RES]] : f32 // CHECK: } -> tensor func.func @elementwise$with_scalar_capture(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> { %int1 = torch.constant.int 1 %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],f32> return %0 : !torch.vtensor<[?],f32> } // ----- // CHECK-LABEL: func.func @elementwise$static_1( // CHECK: linalg.generic {indexing_maps = [ // CHECK-SAME: affine_map<(d0) -> (d0)>, // CHECK-SAME: affine_map<(d0) -> (0)>, // CHECK-SAME: affine_map<(d0) -> (d0)>] func.func @elementwise$static_1(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[?],f32> { %1 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?],f32> return %1 : !torch.vtensor<[?],f32> } // ----- // CHECK-LABEL: func.func @elementwise_sinh // CHECK: linalg.generic // CHECK: math.sinh func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> { %0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> }