mirror of https://github.com/llvm/torch-mlir
87 lines
6.1 KiB
MLIR
87 lines
6.1 KiB
MLIR
|
// RUN: npcomp-opt <%s -convert-torch-to-linalg -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s
|
||
|
|
||
|
|
||
|
// Unary elementwise lowering: `torch.aten.tanh` on a rank-0 f32 value tensor
// is expected to become a loop-free `linalg.generic` (empty iterator list)
// whose body applies `math.tanh`, followed by a `tensor.cast` and a conversion
// back to `!torch.vtensor`.
// CHECK-LABEL:   func @elementwise$unary(
// CHECK-SAME:                            %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
// CHECK:           %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK:           %[[INIT_TENSOR:.*]] = linalg.init_tensor [] : tensor<f32>
// CHECK:           %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
// CHECK:           ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
// CHECK:             %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32
// CHECK:             linalg.yield %[[TANH]] : f32
// CHECK:           } -> tensor<f32>
// The cast operand refers back to the previously captured GENERIC value (a
// use, not a fresh definition), so this line only matches when the cast
// consumes the result of the linalg.generic above.
// CHECK:           %[[CASTED:.*]] = tensor.cast %[[GENERIC]] : tensor<f32> to tensor<f32>
// CHECK:           %[[RESULT:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK:           return %[[RESULT]] : !torch.vtensor<[],f32>
// CHECK:         }
func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
  %0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
  return %0 : !torch.vtensor<[],f32>
}
|
||
|
|
||
|
// Binary elementwise lowering with broadcasting: `torch.aten.mul.Tensor` of a
// rank-2 and a rank-1 f32 tensor. The expected output queries the dynamic
// dimensions, asserts that the rank-1 operand's size equals the trailing
// dimension of the rank-2 operand, and emits a 2-D parallel `linalg.generic`
// whose body is a single `mulf`.
// CHECK-LABEL:   func @elementwise$binary(
// CHECK-SAME:                             %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME:                             %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK:           %[[BUILTIN_ARG0:.*]] = torch.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK:           %[[BUILTIN_ARG1:.*]] = torch.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],f32> -> tensor<?xf32>
// CHECK:           %[[C0:.*]] = constant 0 : index
// CHECK:           %[[ARG0_DIM0:.*]] = memref.dim %[[BUILTIN_ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK:           %[[C1:.*]] = constant 1 : index
// CHECK:           %[[ARG0_DIM1:.*]] = memref.dim %[[BUILTIN_ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK:           %[[C0_2:.*]] = constant 0 : index
// CHECK:           %[[ARG1_DIM0:.*]] = memref.dim %[[BUILTIN_ARG1]], %[[C0_2]] : tensor<?xf32>
// CHECK:           %[[LEGAL_SIZES:.*]] = cmpi eq, %[[ARG0_DIM1]], %[[ARG1_DIM0]] : index
// CHECK:           assert %[[LEGAL_SIZES]], "mismatched size for broadcast"
// CHECK:           %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[ARG0_DIM0]], %[[ARG0_DIM1]]] : tensor<?x?xf32>
// CHECK:           %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%[[BUILTIN_ARG0]], %[[BUILTIN_ARG1]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[INIT_TENSOR]] : tensor<?x?xf32>) {
// CHECK:           ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
// CHECK:             %[[MUL:.*]] = mulf %[[LHS]], %[[RHS]] : f32
// CHECK:             linalg.yield %[[MUL]] : f32
// CHECK:           } -> tensor<?x?xf32>
// The cast operand refers back to the previously captured GENERIC value (a
// use, not a fresh definition), so this line only matches when the cast
// consumes the result of the linalg.generic above.
// CHECK:           %[[CASTED:.*]] = tensor.cast %[[GENERIC]] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK:           %[[RESULT:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK:           return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}
|
||
|
|
||
|
// Ternary elementwise lowering: `torch.aten.lerp.Tensor` over operands of
// rank 3, 2 and 1. Only the indexing maps of the resulting `linalg.generic`
// are verified: the rank-2 operand is indexed by (d1, d2), the rank-1 operand
// by (d2), and the rank-3 operand and the output by all three dimensions.
// CHECK-LABEL:   func @elementwise$ternary(
// CHECK:           linalg.generic {indexing_maps = [
// CHECK-SAME:        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
// CHECK-SAME:        affine_map<(d0, d1, d2) -> (d1, d2)>,
// CHECK-SAME:        affine_map<(d0, d1, d2) -> (d2)>,
// CHECK-SAME:        affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
func @elementwise$ternary(%rank3: !torch.vtensor<[?,?,?],f32>, %rank2: !torch.vtensor<[?,?],f32>, %rank1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?],f32> {
  %lerp = torch.aten.lerp.Tensor %rank3, %rank2, %rank1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?,?],f32>
  return %lerp : !torch.vtensor<[?,?,?],f32>
}
|
||
|
|
||
|
// Elementwise lowering that captures a scalar: `torch.aten.add.Tensor` with an
// integer alpha of 1. The expected output converts the torch int with
// `torch.to_i64` before the `linalg.generic`, then inside the body converts it
// to f32 (`sitofp`), scales the right-hand element by it and adds the
// left-hand element. The rank-0 operand is indexed with an empty map.
// CHECK-LABEL:   func @elementwise$with_scalar_capture(
// CHECK-SAME:                                          %[[ARG_VEC:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME:                                          %[[ARG_SCALAR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
// CHECK:           %[[ALPHA_INT:.*]] = torch.constant.int 1
// CHECK:           %[[ALPHA_I64:.*]] = torch.to_i64 %[[ALPHA_INT]]
// CHECK:           linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>]
// CHECK:           ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
// CHECK:             %[[ALPHA_F32:.*]] = sitofp %[[ALPHA_I64]] : i64 to f32
// CHECK:             %[[SCALED:.*]] = mulf %[[RHS]], %[[ALPHA_F32]] : f32
// CHECK:             %[[SUM:.*]] = addf %[[LHS]], %[[SCALED]] : f32
// CHECK:             linalg.yield %[[SUM]] : f32
// CHECK:           } -> tensor<?xf32>
func @elementwise$with_scalar_capture(%lhs: !torch.vtensor<[?],f32>, %rhs: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
  %alpha = torch.constant.int 1
  %sum = torch.aten.add.Tensor %lhs, %rhs, %alpha : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],f32>
  return %sum : !torch.vtensor<[?],f32>
}
|
||
|
|
||
|
// Elementwise lowering with a statically sized operand:
// `torch.aten.mul.Tensor` of a dynamically sized rank-1 tensor and a tensor of
// static size 1. Only the indexing maps are verified: the size-1 operand is
// expected to be indexed with the constant 0 rather than with d0.
// CHECK-LABEL:   func @elementwise$static_1(
// CHECK:           linalg.generic {indexing_maps = [
// CHECK-SAME:        affine_map<(d0) -> (d0)>,
// CHECK-SAME:        affine_map<(d0) -> (0)>,
// CHECK-SAME:        affine_map<(d0) -> (d0)>]
func @elementwise$static_1(%dynamic: !torch.vtensor<[?],f32>, %unit: !torch.vtensor<[1],f32>) -> !torch.vtensor<[?],f32> {
  %product = torch.aten.mul.Tensor %dynamic, %unit : !torch.vtensor<[?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?],f32>
  return %product : !torch.vtensor<[?],f32>
}
|