mirror of https://github.com/llvm/torch-mlir
25 lines
1003 B
MLIR
25 lines
1003 B
MLIR
|
// RUN: torch-mlir-opt -canonicalize -split-input-file %s | FileCheck %s
|
||
|
|
||
|
// The input and the first output reach tm_tensor.scan through casts to
// tensor<?xi32>; the expectations below require the scan to be printed with
// the static tensor<128xi32> types on its ins/outs operands instead.
// CHECK-LABEL: func.func @tensor.cast(
func.func @tensor.cast(%arg0: tensor<128xi32>) -> tensor<128xi32> {
  // Statically shaped destinations: the 128-element output and the rank-0
  // second output of the scan.
  %out_static = tensor.empty() : tensor<128xi32>
  %scalar_out = tensor.empty() : tensor<i32>

  // Drop the static extent so the scan is written against tensor<?xi32>.
  %in_dyn = tensor.cast %arg0 : tensor<128xi32> to tensor<?xi32>
  %out_dyn = tensor.cast %out_static : tensor<128xi32> to tensor<?xi32>

  // CHECK: tm_tensor.scan
  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<128xi32>)
  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<128xi32>, tensor<i32>)
  %scan:2 = tm_tensor.scan dimension(0) inclusive(true)
      ins(%in_dyn : tensor<?xi32>)
      outs(%out_dyn, %scalar_out : tensor<?xi32>, tensor<i32>) {
    // The region adds its first two arguments; the third one is not used.
    ^bb0(%lhs : i32, %rhs : i32, %unused : i32):
      %total = arith.addi %lhs, %rhs : i32
      tm_tensor.yield %total : i32
  } -> tensor<?xi32>, tensor<i32>

  // Cast the first scan result back to the static type the function returns.
  %result = tensor.cast %scan#0 : tensor<?xi32> to tensor<128xi32>
  return %result : tensor<128xi32>
}
|