Includes stablehlo bump
pull/3766/head
Rob Suderman 2024-10-04 12:08:35 -07:00 committed by GitHub
parent 6e8c7bed4b
commit 2374b9e02d
9 changed files with 51 additions and 42 deletions

@@ -1 +1 @@
-Subproject commit d418a03e01e6a31b51b0c9dd42ba46da6c47f89d
+Subproject commit e813750354bbc08551cf23ff559a54b4a9ea1f29

externals/stablehlo vendored

@@ -1 +1 @@
-Subproject commit c28d55e91b4a5daaff18a33ce7e9bbd0f171256a
+Subproject commit d40285ef3db0687e3f1e2bb0d716d748485a9739


@@ -10,7 +10,7 @@
 #ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
 #define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
 
-#include "mlir/Dialect/Quant/QuantTypes.h"      // from @llvm-project
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"   // from @llvm-project
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"          // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"               // from @llvm-project


@@ -325,7 +325,8 @@ public:
             lhsContractingDim, rhsContractingDim);
     output = rewriter
                  .create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
-                                                  dotDimensionNumbers, nullptr)
+                                                  dotDimensionNumbers, nullptr,
+                                                  nullptr)
                  .getResult();
     return success();
   }
@@ -494,7 +495,7 @@ public:
         /*lhsContractingDimensions=*/{lhsContractingDim},
         /*rhsContractingDimensions=*/{rhsContractingDim});
     Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
-        op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
+        op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr, nullptr);
     Value matmulPlusBias = matmulOutput;
     if (!isa<Torch::NoneType>(biasTy)) {
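Both call sites gain a trailing nullptr because the bumped StableHLO revision appears to add another optional attribute (presumably the new dot algorithm) to stablehlo.dot_general, so the C++ builder takes one more argument after the precision config. A hedged sketch of the updated call shape; the argument-role comments are assumptions, not taken from this patch:

    // Sketch only: assumes the post-bump builder ends with two optional
    // attributes (precision config, then dot algorithm). Passing nullptr for
    // both preserves the pre-bump behavior.
    Value matmul = rewriter
                       .create<stablehlo::DotGeneralOp>(
                           op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers,
                           /*precision_config=*/nullptr,
                           /*algorithm=*/nullptr)
                       .getResult();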


@@ -2840,8 +2840,12 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "Not all dims are valid");
   }
 
-  auto transposeDimsConst = mlir::tosa::getConstTensor<int64_t>(
-      rewriter, op.getOperation(), dimListInt, {selfRank});
+  SmallVector<int32_t> dimListInt32;
+  for (auto v : dimListInt)
+    dimListInt32.push_back(v);
+
+  auto transposeDimsConst = mlir::tosa::getConstTensor<int32_t>(
+      rewriter, op.getOperation(), dimListInt32, {selfRank});
 
   rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
       op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
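The permutation constant is now materialized as i32 rather than i64, matching the updated TOSA test expectations below (tosa.transpose taking a tensor<3xi32> operand); the explicit loop simply narrows each dimension index. As a design note, a sketch (not part of this patch) of the same narrowing written with LLVM's ADT helpers:

    #include "llvm/ADT/STLExtras.h"   // llvm::map_range, llvm::to_vector
    #include "llvm/ADT/SmallVector.h"

    // Narrow the int64_t permutation indices to int32_t before building the
    // TOSA constant; equivalent to the push_back loop in the hunk above.
    SmallVector<int32_t> dimListInt32 = llvm::to_vector(llvm::map_range(
        dimListInt, [](int64_t d) { return static_cast<int32_t>(d); }));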


@@ -3169,9 +3169,13 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
 if torch_version_for_comparison() > version.parse("2.4.0.dev"):
     STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
         "ElementwiseCreateComplexModule_basic",
+        "ElementwiseTanIntModule_basic",
+        "ElementwiseTanModule_basic",
     }
     FX_IMPORTER_STABLEHLO_XFAIL_SET = FX_IMPORTER_STABLEHLO_XFAIL_SET | {
         "ElementwiseCreateComplexModule_basic",
+        "ElementwiseTanIntModule_basic",
+        "ElementwiseTanModule_basic",
     }


@@ -696,8 +696,8 @@ func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !
 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 0
 // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
-// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32>
+// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi32>) -> tensor<3x2x4xf32>
 // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
 // CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32>
 // CHECK: }
@@ -890,15 +890,15 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
 // CHECK-LABEL: @torch.aten.max.dim$basic(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>)
-// CHECK: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
-// CHECK: %[[VAL_TRUE:.*]] = torch.constant.bool true
-// CHECK: %[[VAL_I2:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
-// CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
-// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
-// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
-// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
+// CHECK-DAG: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
+// CHECK-DAG: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
+// CHECK-DAG: %[[VAL_TRUE:.*]] = torch.constant.bool true
+// CHECK-DAG: %[[VAL_I2:.*]] = torch.constant.int 2
+// CHECK-DAG: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
+// CHECK-DAG: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
+// CHECK-DAG: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
+// CHECK-DAG: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
+// CHECK-DAG: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
 // CHECK: return %[[VAL_6]] : tensor<3x2x1xf32>
 func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
   %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
@@ -1378,16 +1378,16 @@ func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.v
 // CHECK-LABEL: func.func @torch.aten.min.dim$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
-// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
-// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
-// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
-// CHECK: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
-// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
-// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
-// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
+// CHECK-DAG: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
+// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
+// CHECK-DAG: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK-DAG: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK-DAG: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
+// CHECK-DAG: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
+// CHECK-DAG: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
+// CHECK-DAG: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
+// CHECK-DAG: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
 // CHECK: return %[[VAL_10]] : tensor<3x2x1xf32>
 // CHECK: }
 func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {


@@ -4,17 +4,17 @@
 // CHECK-LABEL: func.func @scan_1d_inclusive(
 // CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
 // CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
-// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
-// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
-// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
+// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
+// CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
+// CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
+// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
+// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
 // CHECK: tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>)
 // CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
 // CHECK: ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
 // CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
 // CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
 // CHECK: }
-// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
-// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
 // CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
 func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
   %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true)
@@ -32,8 +32,10 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
 // CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
 // CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
 // CHECK-DAG: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref<i32>
-// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
-// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
+// CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
+// CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
+// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
+// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
 // CHECK: memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref<i32> to memref<i32>
 // CHECK: tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>)
 // CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
@@ -41,8 +43,6 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
 // CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
 // CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
 // CHECK: }
-// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
-// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
 // CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
 func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
   %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
@@ -62,14 +62,14 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
 // CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
 // CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
 // CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
-// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
+// CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
+// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
 // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
 // CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
 // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
 // CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
 // CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32
 // CHECK: }
-// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
 // CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
 func.func @scatter_update_scalar_1D(
     %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
@@ -90,7 +90,8 @@ func.func @scatter_update_scalar_1D(
 // CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
 // CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
 // CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
-// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
+// CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
+// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
 // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
 // CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
 // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
@@ -99,7 +100,6 @@ func.func @scatter_update_scalar_1D(
 // CHECK: %[[ADD:.*]] = arith.addi %[[ORIG_SCALAR]], %[[CST1]] : i32
 // CHECK: tm_tensor.yield %[[ADD]] : i32
 // CHECK: }
-// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
 // CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
 func.func @scatter_add_scalar_1D(
     %original: tensor<8xi32>, %indices: tensor<3x1xi32>,


@@ -34,5 +34,5 @@ def test_enable_ir_printing():
     )
-# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize)
+# CHECK: // -----// IR Dump After Inliner (inline)
 # CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} {