mirror of https://github.com/llvm/torch-mlir
parent
6e8c7bed4b
commit
2374b9e02d
|
@ -1 +1 @@
|
|||
Subproject commit d418a03e01e6a31b51b0c9dd42ba46da6c47f89d
|
||||
Subproject commit e813750354bbc08551cf23ff559a54b4a9ea1f29
|
|
@ -1 +1 @@
|
|||
Subproject commit c28d55e91b4a5daaff18a33ce7e9bbd0f171256a
|
||||
Subproject commit d40285ef3db0687e3f1e2bb0d716d748485a9739
|
|
@ -10,7 +10,7 @@
|
|||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
|
||||
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||
|
|
|
@ -325,7 +325,8 @@ public:
|
|||
lhsContractingDim, rhsContractingDim);
|
||||
output = rewriter
|
||||
.create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
|
||||
dotDimensionNumbers, nullptr)
|
||||
dotDimensionNumbers, nullptr,
|
||||
nullptr)
|
||||
.getResult();
|
||||
return success();
|
||||
}
|
||||
|
@ -494,7 +495,7 @@ public:
|
|||
/*lhsContractingDimensions=*/{lhsContractingDim},
|
||||
/*rhsContractingDimensions=*/{rhsContractingDim});
|
||||
Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
|
||||
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
|
||||
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr, nullptr);
|
||||
|
||||
Value matmulPlusBias = matmulOutput;
|
||||
if (!isa<Torch::NoneType>(biasTy)) {
|
||||
|
|
|
@ -2840,8 +2840,12 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op, "Not all dims are valid");
|
||||
}
|
||||
|
||||
auto transposeDimsConst = mlir::tosa::getConstTensor<int64_t>(
|
||||
rewriter, op.getOperation(), dimListInt, {selfRank});
|
||||
SmallVector<int32_t> dimListInt32;
|
||||
for (auto v : dimListInt)
|
||||
dimListInt32.push_back(v);
|
||||
|
||||
auto transposeDimsConst = mlir::tosa::getConstTensor<int32_t>(
|
||||
rewriter, op.getOperation(), dimListInt32, {selfRank});
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
|
||||
|
|
|
@ -3169,9 +3169,13 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
|
|||
if torch_version_for_comparison() > version.parse("2.4.0.dev"):
|
||||
STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
|
||||
"ElementwiseCreateComplexModule_basic",
|
||||
"ElementwiseTanIntModule_basic",
|
||||
"ElementwiseTanModule_basic",
|
||||
}
|
||||
FX_IMPORTER_STABLEHLO_XFAIL_SET = FX_IMPORTER_STABLEHLO_XFAIL_SET | {
|
||||
"ElementwiseCreateComplexModule_basic",
|
||||
"ElementwiseTanIntModule_basic",
|
||||
"ElementwiseTanModule_basic",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -696,8 +696,8 @@ func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !
|
|||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi32>) -> tensor<3x2x4xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
|
||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32>
|
||||
// CHECK: }
|
||||
|
@ -890,15 +890,15 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
|
|||
|
||||
// CHECK-LABEL: @torch.aten.max.dim$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>)
|
||||
// CHECK: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[VAL_I2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
// CHECK-DAG: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
// CHECK-DAG: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK-DAG: %[[VAL_TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK-DAG: %[[VAL_I2:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK-DAG: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||
// CHECK-DAG: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||
// CHECK-DAG: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK-DAG: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
// CHECK: return %[[VAL_6]] : tensor<3x2x1xf32>
|
||||
func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
|
@ -1378,16 +1378,16 @@ func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.v
|
|||
|
||||
// CHECK-LABEL: func.func @torch.aten.min.dim$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
// CHECK-DAG: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK-DAG: %[[VAL_3:.*]] = torch.constant.bool true
|
||||
// CHECK-DAG: %[[VAL_4:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK-DAG: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
|
||||
// CHECK-DAG: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||
// CHECK-DAG: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||
// CHECK-DAG: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK-DAG: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
// CHECK: return %[[VAL_10]] : tensor<3x2x1xf32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||
|
|
|
@ -4,17 +4,17 @@
|
|||
// CHECK-LABEL: func.func @scan_1d_inclusive(
|
||||
// CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
|
||||
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||
// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
|
||||
// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
||||
// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
||||
// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
||||
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
||||
// CHECK: tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>)
|
||||
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
|
||||
// CHECK: ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
|
||||
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
|
||||
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
|
||||
// CHECK: }
|
||||
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
||||
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
|
||||
func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true)
|
||||
|
@ -32,8 +32,10 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
|
|||
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||
// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref<i32>
|
||||
// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
||||
// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
||||
// CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
||||
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
||||
// CHECK: memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref<i32> to memref<i32>
|
||||
// CHECK: tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>)
|
||||
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
|
||||
|
@ -41,8 +43,6 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
|
|||
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
|
||||
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
|
||||
// CHECK: }
|
||||
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
||||
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
|
||||
func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
|
||||
|
@ -62,14 +62,14 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
|
|||
// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
|
||||
// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
|
||||
// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
|
||||
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||
// CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||
// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
||||
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
|
||||
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
||||
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
||||
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
|
||||
// CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32
|
||||
// CHECK: }
|
||||
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
||||
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
|
||||
func.func @scatter_update_scalar_1D(
|
||||
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
|
||||
|
@ -90,7 +90,8 @@ func.func @scatter_update_scalar_1D(
|
|||
// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
|
||||
// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
|
||||
// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
|
||||
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||
// CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||
// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
||||
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
|
||||
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
||||
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
||||
|
@ -99,7 +100,6 @@ func.func @scatter_update_scalar_1D(
|
|||
// CHECK: %[[ADD:.*]] = arith.addi %[[ORIG_SCALAR]], %[[CST1]] : i32
|
||||
// CHECK: tm_tensor.yield %[[ADD]] : i32
|
||||
// CHECK: }
|
||||
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
||||
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
|
||||
func.func @scatter_add_scalar_1D(
|
||||
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
|
||||
|
|
|
@ -34,5 +34,5 @@ def test_enable_ir_printing():
|
|||
)
|
||||
|
||||
|
||||
# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize)
|
||||
# CHECK: // -----// IR Dump After Inliner (inline)
|
||||
# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} {
|
||||
|
|
Loading…
Reference in New Issue