Extend tm_tensor.scatter op semantics to carry a unique_indices attribute

There are cases where the op may update the same indices multiple
times. In that case the updates cannot be parallelized; they have
to be executed sequentially. Add a boolean attribute to control
the behavior.

Also add test cases for invalid IR.
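
To illustrate the hazard (a hypothetical op, not part of the patch; shapes
and values invented): if two rows of %indices can address the same element
of %original, the region body becomes a read-modify-write, and the op must
be marked unique_indices(false) so the updates are applied one at a time:

tm_tensor.scatter unique_indices(false)
  ins(%updates, %indices : memref<2xi32>, memref<2x1xi32>)
  outs(%original : memref<8xi32>) {
  ^bb0(%update: i32, %current: i32):
    // If both rows of %indices hold [3], the two %sum values must be
    // accumulated into %original[3] sequentially.
    %sum = arith.addi %update, %current : i32
    tm_tensor.yield %sum : i32
  }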
pull/606/head
Vivek Khandelwal 2022-02-21 13:10:02 +05:30
parent 7023ee53e8
commit 5dbace239b
4 changed files with 442 additions and 24 deletions


@@ -118,9 +118,15 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
     The first dim of `updates` and `indices` is identical, since they represent
     the number of updates.
-    The rank of the `original`/`result` is `index_depth + rank(%updates) - 1`.
-    The first `index_depth` indices are derived from `indices` and the shape of
-    update value must match the rest shape of `original`.
+    The rank of the `original`/`result` can be anywhere between
+    `rank(%updates) - 1` and `index_depth + rank(%updates) - 1`. The first
+    `index_depth` indices are derived from `indices`, and the last
+    `rank(%original) - index_depth` dimensions of the update value must match
+    the last dimensions of `%original`, with the previous dims extending from
+    the index offsets.
+    The unique_indices attribute carries the information whether all the
+    indices are unique. If there are repeated indices, the first iteration
+    loop will be marked as a reduction.
     The shapes definition follows tensorflow operations except that it forces
     batch dims to be 1D. See more information in
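
As a concrete reading of the relaxed rule (a hypothetical op, shapes
invented for illustration): with index_depth = 2, the last
rank(%original) - index_depth = 1 dimension of the update value must equal
the last dimension of %original (6), while the middle update dim (3) only
needs to fit within the corresponding indexed dim of %original (5), since
it extends from the loaded index offset:

%0 = tm_tensor.scatter unique_indices(true)
       ins(%updates, %indices : tensor<8x3x6xf32>, tensor<8x2xi32>)
       outs(%original : tensor<4x5x6xf32>) {
     ^bb0(%update: f32, %current: f32):
       tm_tensor.yield %update : f32
     } -> tensor<4x5x6xf32>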
@@ -128,12 +134,14 @@
   }];
   let arguments = (ins
       Variadic<AnyRankedTensorOrMemRefType>:$inputs,
-      Variadic<AnyRankedTensorOrMemRefType>:$outputs
+      Variadic<AnyRankedTensorOrMemRefType>:$outputs,
+      DefaultValuedAttr<BoolAttr, "true">:$unique_indices
   );
   let results = (outs Variadic<AnyRankedTensor>:$results);
   let regions = (region AnyRegion:$region);
   let assemblyFormat = [{
-    attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+    attr-dict `unique_indices` `(` $unique_indices `)`
+    (`ins` `(` $inputs^ `:` type($inputs) `)`)?
     `outs` `(` $outputs `:` type($outputs) `)`
     $region (`->` type($results)^)?
   }];


@@ -295,20 +295,50 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
         "mismatch in shape of indices and update value at dim#0");
   }
   auto originalType = op.getOriginalType();
-  // indexDepth + update dims should match to original dims. The first dim of
-  // update is the number of updates.
-  if (originalType.getRank() != indexDepth + updateType.getRank() - 1) {
+  if (updateType.getRank() - 1 > originalType.getRank()) {
     return op.emitOpError(
-        "mismatch in rank of update value, index depth and original value");
+        "update value rank exceeds the rank of the original value");
   }
-  for (auto dim : llvm::seq<unsigned>(indexDepth, originalType.getRank())) {
-    // Offset one because the first dim is the number of updates.
-    if (updateType.getDimSize(1 + dim - indexDepth) !=
-        originalType.getDimSize(dim)) {
+  // indexDepth + update dims should cover the original dims. The first dim of
+  // update is the number of updates.
+  if (originalType.getRank() > indexDepth + updateType.getRank() - 1) {
+    return op.emitOpError(
+        "index depth and update value does not cover rank of original value");
+  }
+  // Validate that the non-indexed update dims cover the full slice size of
+  // the original tensor.
+  int64_t fullSliceDims = originalType.getRank() - indexDepth;
+  for (auto it :
+       llvm::zip(llvm::seq<unsigned>(indexDepth, originalType.getRank()),
+                 llvm::seq<unsigned>(updateType.getRank() - fullSliceDims,
+                                     updateType.getRank()))) {
+    int64_t originalDim = std::get<0>(it);
+    int64_t updateDim = std::get<1>(it);
+    if (updateType.getDimSize(updateDim) !=
+        originalType.getDimSize(originalDim)) {
       return op.emitOpError("mismatch in shape of update value dim#")
-             << (1 + dim - indexDepth) << " and original value at dim#" << dim;
+             << updateDim << " and original value at dim#" << originalDim;
     }
   }
+  // Check that the remaining update indices do not exceed the update length.
+  int64_t insertDims = originalType.getRank() - updateType.getRank() + 1;
+  for (auto it : llvm::zip(
+           llvm::seq<unsigned>(insertDims, indexDepth),
+           llvm::seq<unsigned>(1, updateType.getRank() - fullSliceDims))) {
+    int64_t originalDim = std::get<0>(it);
+    int64_t updateDim = std::get<1>(it);
+    if (updateType.getDimSize(updateDim) >
+        originalType.getDimSize(originalDim)) {
+      return op.emitOpError("indexed shape of update value dim#")
+             << updateDim << " exceeds original value at dim#" << originalDim
+             << " " << updateType.getDimSize(updateDim) << " "
+             << originalType.getDimSize(originalDim);
+    }
+  }
   Region &region = op.region();
   Block *body = &region.front();
   if (body->getNumArguments() != 2) {
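
Walking the scatter_partial_slices case (added to the tests below) through
these checks makes the arithmetic concrete; this is a reading of the code
above, not extra patch content:

// original: memref<2x64x12xf32> (rank 3), indices: memref<2x3xi32>
// (index_depth 3), updates: memref<2x1x12xf32> (rank 3).
// rank check:     rank(updates) - 1 = 2 <= rank(original) = 3
// coverage check: rank(original) = 3 <= index_depth + rank(updates) - 1 = 5
// fullSliceDims = 3 - 3 = 0, so the full-slice loop has nothing to match.
// insertDims = 3 - 3 + 1 = 1, so the second loop checks the indexed dims:
// update dim#1 = 1 <= original dim#1 = 64, and
// update dim#2 = 12 <= original dim#2 = 12, so the op verifies.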
@@ -349,6 +379,9 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
 SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
   SmallVector<StringRef> iteratorTypes(getUpdateType().getRank(),
                                        getParallelIteratorTypeName());
+  if (!unique_indices()) {
+    iteratorTypes[0] = getReductionIteratorTypeName();
+  }
   return iteratorTypes;
 }
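
Spelled out, the mapping this hook now implements (an illustrative summary,
not code from the patch):

// For an update value of rank N the op reports N loops:
//   unique_indices = true  -> ["parallel", "parallel", ..., "parallel"]
//   unique_indices = false -> ["reduction", "parallel", ..., "parallel"]
// Only the outer loop over updates changes; the inner slice loops stay
// parallel because each update writes disjoint elements within its slice.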
@@ -373,12 +406,26 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
   SmallVector<Value> loadIndices;
   loadIndices.push_back(ivs.front());
   loadIndices.push_back(Value());
+  // Populate with empty values.
+  auto originalTy = original().getType().cast<ShapedType>();
+  starts.resize(originalTy.getRank(), Value());
+  auto updateIvs = ivs.drop_front(1);
+  int64_t offset = starts.size() - updateIvs.size();
+  for (auto it : llvm::enumerate(updateIvs)) {
+    starts[it.index() + offset] = it.value();
+  }
   for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
     loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
     Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
-    starts.push_back(b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx));
+    Value cast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
+    if (starts[i])
+      cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
+    starts[i] = cast;
   }
-  starts.append(std::next(ivs.begin()), ivs.end());
   Value init = b.create<memref::LoadOp>(loc, original(), starts);
   BlockAndValueMapping bvm;
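
In IR terms, the start indices are now pre-populated with the slice
induction variables and the loaded scatter indices are added on top, as the
partial-slice CHECK lines later in this commit show; a sketch with
hypothetical SSA names (%i is the update iv, %j the slice iv):

%idx  = memref.load %indices[%i, %c1] : memref<2x3xi32>
%cast = arith.index_cast %idx : i32 to index
// starts[1] already holds the slice iv %j, so the loaded index is added
// to it to form the final offset into %original.
%off  = arith.addi %cast, %j : index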


@@ -105,7 +105,7 @@ func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
func @scatter_update_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) {
-  tm_tensor.scatter
+  tm_tensor.scatter unique_indices(true)
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
outs(%original : memref<8xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@@ -131,7 +131,7 @@ func @scatter_update_scalar_1D(
func @scatter_add_scalar_2D(
%original: memref<4x3xi32>, %indices: memref<3x2xi32>,
%updates: memref<3xi32>) {
-  tm_tensor.scatter
+  tm_tensor.scatter unique_indices(true)
ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@@ -162,7 +162,7 @@ func @scatter_add_scalar_2D(
func @scatter_update_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) {
-  tm_tensor.scatter
+  tm_tensor.scatter unique_indices(true)
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@@ -192,7 +192,7 @@ func @scatter_update_slice_2D(
func @scatter_add_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) {
-  tm_tensor.scatter
+  tm_tensor.scatter unique_indices(true)
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
outs(%original : memref<8xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@@ -221,7 +221,7 @@ func @scatter_add_scalar_1D(
func @scatter_add_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) {
-  tm_tensor.scatter
+  tm_tensor.scatter unique_indices(true)
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@@ -251,7 +251,7 @@ func @scatter_add_slice_2D(
func @scatter_update_scalar_dynamic_1D(
%original: memref<?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?xi32>) {
-  tm_tensor.scatter
+  tm_tensor.scatter unique_indices(true)
ins(%updates, %indices : memref<?xi32>, memref<?x1xi32>)
outs(%original : memref<?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@@ -277,7 +277,7 @@ func @scatter_update_scalar_dynamic_1D(
func @scatter_add_scalar_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x2xi32>,
%updates: memref<?xi32>) {
-  tm_tensor.scatter
+  tm_tensor.scatter unique_indices(true)
ins(%updates, %indices : memref<?xi32>, memref<?x2xi32>)
outs(%original : memref<?x?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@@ -308,7 +308,7 @@ func @scatter_add_scalar_dynamic_2D(
func @scatter_update_slice_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?x?xi32>) {
-  tm_tensor.scatter
+  tm_tensor.scatter unique_indices(true)
ins(%updates, %indices : memref<?x?xi32>, memref<?x1xi32>)
outs(%original : memref<?x?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@@ -330,3 +330,38 @@ func @scatter_update_slice_dynamic_2D(
// CHECK: %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK: %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK: memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// -----
func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) {
tm_tensor.scatter
unique_indices(true)
ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>)
outs(%arg0 : memref<2x64x12xf32>) {
^bb0(%arg3: f32, %arg4: f32):
tm_tensor.yield %arg4 : f32
}
return
}
// CHECK-LABEL: @scatter_partial_slices
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant
// CHECK-DAG: %[[C1:.+]] = arith.constant
// CHECK-DAG: %[[C2:.+]] = arith.constant
// CHECK-DAG: %[[C12:.+]] = arith.constant
// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK-NEXT: scf.for %[[ARG4:.+]] = %[[C0]] to %[[C1]] step %[[C1]] {
// CHECK-NEXT: scf.for %[[ARG5:.+]] = %[[C0]] to %[[C12]] step %[[C1]] {
// CHECK-NEXT: %[[LOAD0:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C0]]] : memref<2x3xi32>
// CHECK-NEXT: %[[CAST0:.+]] = arith.index_cast %[[LOAD0]] : i32 to index
// CHECK-NEXT: %[[LOAD1:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C1]]] : memref<2x3xi32>
// CHECK-NEXT: %[[CAST1:.+]] = arith.index_cast %[[LOAD1]] : i32 to index
// CHECK-NEXT: %[[ADD1:.+]] = arith.addi %[[CAST1]], %[[ARG4]] : index
// CHECK-NEXT: %[[LOAD2:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C2]]] : memref<2x3xi32>
// CHECK-NEXT: %[[CAST2:.+]] = arith.index_cast %[[LOAD2]] : i32 to index
// CHECK-NEXT: %[[ADD2:.+]] = arith.addi %[[CAST2]], %[[ARG5]] : index
// CHECK-NEXT: %[[LOAD3:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>
// CHECK-NEXT: memref.store %[[LOAD3]], %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>


@@ -0,0 +1,328 @@
// RUN: torch-mlir-dialects-opt -split-input-file -verify-diagnostics %s
func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : memref<?x?xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @scatter_mixed_tensor_memref(
%update : tensor<?x?xf32>, %indices : memref<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xf32>, memref<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @scatter_extra_outputs(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
// expected-error @+1 {{expected number of outputs to be same as the number of results}}
%0, %1 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
} -> tensor<?x?xf32>, tensor<?x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// -----
func @scatter_mixed_tensor_memref(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : memref<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
outs(%original : memref<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @scatter_mixed_tensor_memref(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> memref<?x?xf32> {
// expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'memref<?x?xf32>'}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
} -> memref<?x?xf32>
return %0 : memref<?x?xf32>
}
// -----
func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
%original : memref<?x?xf32>) {
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
tm_tensor.scatter unique_indices(true)
ins(%update, %indices : memref<?x?xf32>, tensor<?x1xi32>)
outs(%original : memref<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
}
return
}
// -----
func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : memref<?x1xi32>,
%original : tensor<?x?xf32>) {
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
tm_tensor.scatter unique_indices(true)
ins(%update, %indices : memref<?x?xf32>, memref<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
}
return
}
// -----
func @scatter_dim_mismatch(
%update : tensor<?x?xf32>, %indices : tensor<48x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xf32>, tensor<48x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @scatter_dim_mismatch(
%update : tensor<64x?xf32>, %indices : tensor<48x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<64x?xf32>, tensor<48x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @scatter_dim_mismatch(
%update : tensor<?x?x?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{op update value rank exceeds the rank of the original value}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?x?x?xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @scatter_dim_mismatch(
%update : tensor<?x4xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of update value dim#1 and original value at dim#1}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x4xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{expected region to have scalar argument of integer or float types}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi32>) {
^bb0(%arg1: index, %arg2: index):
%1 = arith.addi %arg1, %arg2 : index
%2 = arith.index_cast %1 : index to i32
tm_tensor.yield %2 : i32
} -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
// -----
func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi32>) {
^bb0(%arg1: i64, %arg2: i32):
%1 = arith.trunci %arg1 : i64 to i32
%2 = arith.addi %1, %arg2 : i32
tm_tensor.yield %2 : i32
} -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
// -----
func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi32>) {
^bb0(%arg1: i32, %arg2: i64):
%1 = arith.trunci %arg2 : i64 to i32
%2 = arith.addi %1, %arg1 : i32
tm_tensor.yield %2 : i32
} -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
// -----
func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i32, %arg2: i64):
%1 = arith.extsi %arg1 : i32 to i64
%2 = arith.addi %1, %arg2 : i64
tm_tensor.yield %2 : i64
} -> tensor<?x?xi64>
return %0 : tensor<?x?xi64>
}
// -----
func @scatter_region_type_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{expected region to have two arguments}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3 : i64):
%1 = arith.addi %arg1, %arg2 : i64
tm_tensor.yield %1 : i64
} -> tensor<?x?xi64>
return %0 : tensor<?x?xi64>
}
// -----
func @scatter_yield_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64):
%1 = arith.addi %arg1, %arg2 : i64
%2 = arith.trunci %1 : i64 to i32
// expected-error @+1 {{mismatch in type of yielded value 'i32' and argument of the region 'i64'}}
tm_tensor.yield %2 : i32
} -> tensor<?x?xi64>
return %0 : tensor<?x?xi64>
}
// -----
func @scatter_yield_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64):
%1 = arith.addi %arg1, %arg2 : i64
%2 = arith.trunci %1 : i64 to i32
// expected-error @+1 {{expected region to yield a single value}}
tm_tensor.yield %1, %2 : i64, i32
} -> tensor<?x?xi64>
return %0 : tensor<?x?xi64>
}
// -----
func @scatter_index_depth_dynamic(
%update : tensor<?x?xi64>, %indices : tensor<?x?xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{expected index depth is static}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xi64>, tensor<?x?xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64):
%1 = arith.addi %arg1, %arg2 : i64
%2 = arith.trunci %1 : i64 to i32
tm_tensor.yield %1, %2 : i64, i32
} -> tensor<?x?xi64>
return %0 : tensor<?x?xi64>
}
// -----
func @scatter_original_rank_mismatch(
%update : tensor<?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{op index depth and update value does not cover rank of original value}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?xi64>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64):
%1 = arith.addi %arg1, %arg2 : i64
%2 = arith.trunci %1 : i64 to i32
tm_tensor.yield %1, %2 : i64, i32
} -> tensor<?x?xi64>
return %0 : tensor<?x?xi64>
}