// RUN: torch-mlir-opt -canonicalize -split-input-file %s | FileCheck %s

// Canonicalization test for tm_tensor.scan when its operands are produced by
// tensor.cast ops. The input and the first output are cast from the static
// type tensor<128xi32> to the dynamic type tensor<?xi32> before being handed
// to the scan. After -canonicalize, the directives below expect the scan to
// consume operands of the static type tensor<128xi32> directly.

// CHECK-LABEL: func.func @tensor.cast(
func.func @tensor.cast(%arg0: tensor<128xi32>) -> tensor<128xi32> {
  %init = tensor.empty() : tensor<128xi32>
  // Second scan output (the accumulator). It is rank-0 because the rank-1
  // input is scanned along its only dimension.
  %c0 = tensor.empty() : tensor<i32>
  // Erase the static shape so the scan initially sees dynamic operands.
  %casted_arg0 = tensor.cast %arg0 : tensor<128xi32> to tensor<?xi32>
  %casted_init = tensor.cast %init : tensor<128xi32> to tensor<?xi32>
// CHECK:      tm_tensor.scan
// CHECK-SAME:   ins(%{{[a-zA-Z0-9]*}} : tensor<128xi32>)
// CHECK-SAME:   outs(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<128xi32>, tensor<i32>)
  %0, %1 = tm_tensor.scan
    dimension(0) inclusive(true)
    ins(%casted_arg0 : tensor<?xi32>)
    outs(%casted_init, %c0 : tensor<?xi32>, tensor<i32>) {
      ^bb0(%barg0 : i32, %barg1 : i32, %barg2 : i32):
        // Combine the first two block arguments with integer addition and
        // yield the sum.
        %sum = arith.addi %barg0, %barg1 : i32
        tm_tensor.yield %sum : i32
  } -> tensor<?xi32>, tensor<i32>
  // Cast the scan result back to the static type required by the signature.
  %2 = tensor.cast %0 : tensor<?xi32> to tensor<128xi32>
  return %2 : tensor<128xi32>
}