mirror of https://github.com/llvm/torch-mlir
Extend tm_tensor.scatter op semantic to carry unique_indices attribute
There are cases where the op may update the same indices multiple times. In this context, we can not parallelize updates. Instead, we have to execute them sequentially. Adding a boolean attribute to control the behavior. Also adding test cases for invalid IR.pull/606/head
parent
7023ee53e8
commit
5dbace239b
|
@ -118,9 +118,15 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
|
|||
The first dim of `updates` and `indices` is identical, since they represent
|
||||
the number of updates.
|
||||
|
||||
The rank of the `original`/`result` is `index_depth + rank(%updates) - 1`.
|
||||
The first `index_depth` indices are derived from `indices` and the shape of
|
||||
update value must match the rest shape of `original`.
|
||||
The rank of the `original`/`result` is at least
|
||||
`index_depth + rank(%updates) - 1`. The first `index_depth` indices are
|
||||
derived from `indices` and the shape of update value has the last
|
||||
rank(%original) - index_depth values match %(originals) last dimensions,
|
||||
with the previous dims extending from the index offsets.
|
||||
|
||||
The unique_indices attribute carries the information whether all the indices
|
||||
are unique. If there are repeated indices, the first iteration loop will be
|
||||
marked as reduction.
|
||||
|
||||
The shapes definition follows tensorflow operations execept that it force
|
||||
batch dims to be 1D. See more information in
|
||||
|
@ -128,12 +134,14 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
|
|||
}];
|
||||
let arguments = (ins
|
||||
Variadic<AnyRankedTensorOrMemRefType>:$inputs,
|
||||
Variadic<AnyRankedTensorOrMemRefType>:$outputs
|
||||
Variadic<AnyRankedTensorOrMemRefType>:$outputs,
|
||||
DefaultValuedAttr<BoolAttr, "true">:$unique_indices
|
||||
);
|
||||
let results = (outs Variadic<AnyRankedTensor>:$results);
|
||||
let regions = (region AnyRegion:$region);
|
||||
let assemblyFormat = [{
|
||||
attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
|
||||
attr-dict `unique_indices` `(` $unique_indices `)`
|
||||
(`ins` `(` $inputs^ `:` type($inputs) `)`)?
|
||||
`outs` `(` $outputs `:` type($outputs) `)`
|
||||
$region (`->` type($results)^)?
|
||||
}];
|
||||
|
|
|
@ -295,20 +295,50 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
|
|||
"mismatch in shape of indices and update value at dim#0");
|
||||
}
|
||||
auto originalType = op.getOriginalType();
|
||||
// indexDepth + update dims should match to original dims. The first dim of
|
||||
// update is the number of updates.
|
||||
if (originalType.getRank() != indexDepth + updateType.getRank() - 1) {
|
||||
if (updateType.getRank() - 1 > originalType.getRank()) {
|
||||
return op.emitOpError(
|
||||
"mismatch in rank of update value, index depth and original value");
|
||||
"update value rank exceeds the rank of the original value");
|
||||
}
|
||||
for (auto dim : llvm::seq<unsigned>(indexDepth, originalType.getRank())) {
|
||||
// Offset one because the first dim is the number of updates.
|
||||
if (updateType.getDimSize(1 + dim - indexDepth) !=
|
||||
originalType.getDimSize(dim)) {
|
||||
|
||||
// indexDepth + update dims should cover the original dims. The first dim of
|
||||
// update is the number of updates.
|
||||
if (originalType.getRank() > indexDepth + updateType.getRank() - 1) {
|
||||
return op.emitOpError(
|
||||
"index depth and update value does not cover rank of original value");
|
||||
}
|
||||
|
||||
// Validate the non-indexed update dims covier the full slice size of the
|
||||
// original tensor.
|
||||
int64_t fullSliceDims = originalType.getRank() - indexDepth;
|
||||
for (auto it :
|
||||
llvm::zip(llvm::seq<unsigned>(indexDepth, originalType.getRank()),
|
||||
llvm::seq<unsigned>(updateType.getRank() - fullSliceDims,
|
||||
updateType.getRank()))) {
|
||||
int64_t originalDim = std::get<0>(it);
|
||||
int64_t updateDim = std::get<1>(it);
|
||||
if (updateType.getDimSize(updateDim) !=
|
||||
originalType.getDimSize(originalDim)) {
|
||||
return op.emitOpError("mismatch in shape of update value dim#")
|
||||
<< (1 + dim - indexDepth) << " and original value at dim#" << dim;
|
||||
<< updateDim << " and original value at dim#" << originalDim;
|
||||
}
|
||||
}
|
||||
|
||||
// Check that the remaining update indices do not exceed the update length.
|
||||
int64_t insertDims = originalType.getRank() - updateType.getRank() + 1;
|
||||
for (auto it : llvm::zip(
|
||||
llvm::seq<unsigned>(insertDims, indexDepth),
|
||||
llvm::seq<unsigned>(1, updateType.getRank() - fullSliceDims))) {
|
||||
int64_t originalDim = std::get<0>(it);
|
||||
int64_t updateDim = std::get<1>(it);
|
||||
if (updateType.getDimSize(updateDim) >
|
||||
originalType.getDimSize(originalDim)) {
|
||||
return op.emitOpError("indexed shape of update value dim#")
|
||||
<< updateDim << " exceeds original value at dim#" << originalDim
|
||||
<< " " << updateType.getDimSize(updateDim) << " "
|
||||
<< originalType.getDimSize(originalDim);
|
||||
}
|
||||
}
|
||||
|
||||
Region ®ion = op.region();
|
||||
Block *body = ®ion.front();
|
||||
if (body->getNumArguments() != 2) {
|
||||
|
@ -349,6 +379,9 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
|
|||
SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
|
||||
SmallVector<StringRef> iteratorTypes(getUpdateType().getRank(),
|
||||
getParallelIteratorTypeName());
|
||||
if (!unique_indices()) {
|
||||
iteratorTypes[0] = getReductionIteratorTypeName();
|
||||
}
|
||||
return iteratorTypes;
|
||||
}
|
||||
|
||||
|
@ -373,12 +406,26 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
|
|||
SmallVector<Value> loadIndices;
|
||||
loadIndices.push_back(ivs.front());
|
||||
loadIndices.push_back(Value());
|
||||
|
||||
// Populate with empty values.
|
||||
auto originalTy = original().getType().cast<ShapedType>();
|
||||
starts.resize(originalTy.getRank(), Value());
|
||||
auto updateIvs = ivs.drop_front(1);
|
||||
|
||||
int64_t offset = starts.size() - updateIvs.size();
|
||||
for (auto it : llvm::enumerate(updateIvs)) {
|
||||
starts[it.index() + offset] = it.value();
|
||||
}
|
||||
|
||||
for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
|
||||
loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
|
||||
Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
|
||||
starts.push_back(b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx));
|
||||
Value cast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
|
||||
|
||||
if (starts[i]) cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
|
||||
starts[i] = cast;
|
||||
}
|
||||
starts.append(std::next(ivs.begin()), ivs.end());
|
||||
|
||||
Value init = b.create<memref::LoadOp>(loc, original(), starts);
|
||||
|
||||
BlockAndValueMapping bvm;
|
||||
|
|
|
@ -105,7 +105,7 @@ func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
|
|||
func @scatter_update_scalar_1D(
|
||||
%original: memref<8xi32>, %indices: memref<3x1xi32>,
|
||||
%updates: memref<3xi32>) {
|
||||
tm_tensor.scatter
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
|
||||
outs(%original : memref<8xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -131,7 +131,7 @@ func @scatter_update_scalar_1D(
|
|||
func @scatter_add_scalar_2D(
|
||||
%original: memref<4x3xi32>, %indices: memref<3x2xi32>,
|
||||
%updates: memref<3xi32>) {
|
||||
tm_tensor.scatter
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
|
||||
outs(%original : memref<4x3xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -162,7 +162,7 @@ func @scatter_add_scalar_2D(
|
|||
func @scatter_update_slice_2D(
|
||||
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
|
||||
%updates: memref<2x3xi32>) {
|
||||
tm_tensor.scatter
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
|
||||
outs(%original : memref<4x3xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -192,7 +192,7 @@ func @scatter_update_slice_2D(
|
|||
func @scatter_add_scalar_1D(
|
||||
%original: memref<8xi32>, %indices: memref<3x1xi32>,
|
||||
%updates: memref<3xi32>) {
|
||||
tm_tensor.scatter
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
|
||||
outs(%original : memref<8xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -221,7 +221,7 @@ func @scatter_add_scalar_1D(
|
|||
func @scatter_add_slice_2D(
|
||||
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
|
||||
%updates: memref<2x3xi32>) {
|
||||
tm_tensor.scatter
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
|
||||
outs(%original : memref<4x3xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -251,7 +251,7 @@ func @scatter_add_slice_2D(
|
|||
func @scatter_update_scalar_dynamic_1D(
|
||||
%original: memref<?xi32>, %indices: memref<?x1xi32>,
|
||||
%updates: memref<?xi32>) {
|
||||
tm_tensor.scatter
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
ins(%updates, %indices : memref<?xi32>, memref<?x1xi32>)
|
||||
outs(%original : memref<?xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -277,7 +277,7 @@ func @scatter_update_scalar_dynamic_1D(
|
|||
func @scatter_add_scalar_dynamic_2D(
|
||||
%original: memref<?x?xi32>, %indices: memref<?x2xi32>,
|
||||
%updates: memref<?xi32>) {
|
||||
tm_tensor.scatter
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
ins(%updates, %indices : memref<?xi32>, memref<?x2xi32>)
|
||||
outs(%original : memref<?x?xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -308,7 +308,7 @@ func @scatter_add_scalar_dynamic_2D(
|
|||
func @scatter_update_slice_dynamic_2D(
|
||||
%original: memref<?x?xi32>, %indices: memref<?x1xi32>,
|
||||
%updates: memref<?x?xi32>) {
|
||||
tm_tensor.scatter
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
ins(%updates, %indices : memref<?x?xi32>, memref<?x1xi32>)
|
||||
outs(%original : memref<?x?xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -330,3 +330,38 @@ func @scatter_update_slice_dynamic_2D(
|
|||
// CHECK: %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
|
||||
// CHECK: %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
|
||||
// CHECK: memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) {
|
||||
tm_tensor.scatter
|
||||
unique_indices(true)
|
||||
ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>)
|
||||
outs(%arg0 : memref<2x64x12xf32>) {
|
||||
^bb0(%arg3: f32, %arg4: f32):
|
||||
tm_tensor.yield %arg4 : f32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @scatter_partial_slices
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant
|
||||
// CHECK-DAG: %[[C2:.+]] = arith.constant
|
||||
// CHECK-DAG: %[[C12:.+]] = arith.constant
|
||||
// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
|
||||
// CHECK-NEXT: scf.for %[[ARG4:.+]] = %[[C0]] to %[[C1]] step %[[C1]] {
|
||||
// CHECK-NEXT: scf.for %[[ARG5:.+]] = %[[C0]] to %[[C12]] step %[[C1]] {
|
||||
// CHECK-NEXT: %[[LOAD0:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C0]]] : memref<2x3xi32>
|
||||
// CHECK-NEXT: %[[CAST0:.+]] = arith.index_cast %[[LOAD0]] : i32 to index
|
||||
// CHECK-NEXT: %[[LOAD1:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C1]]] : memref<2x3xi32>
|
||||
// CHECK-NEXT: %[[CAST1:.+]] = arith.index_cast %[[LOAD1]] : i32 to index
|
||||
// CHECK-NEXT: %[[ADD1:.+]] = arith.addi %[[CAST1]], %[[ARG4]] : index
|
||||
// CHECK-NEXT: %[[LOAD2:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C2]]] : memref<2x3xi32>
|
||||
// CHECK-NEXT: %[[CAST2:.+]] = arith.index_cast %[[LOAD2]] : i32 to index
|
||||
// CHECK-NEXT: %[[ADD2:.+]] = arith.addi %[[CAST2]], %[[ARG5]] : index
|
||||
// CHECK-NEXT: %[[LOAD3:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>
|
||||
// CHECK-NEXT: memref.store %[[LOAD3]], %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>
|
||||
|
|
|
@ -0,0 +1,328 @@
|
|||
// RUN: torch-mlir-dialects-opt -split-input-file -verify-diagnostics %s
|
||||
|
||||
func @scatter_mixed_tensor_memref(
|
||||
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : memref<?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_mixed_tensor_memref(
|
||||
%update : tensor<?x?xf32>, %indices : memref<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xf32>, memref<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_extra_outputs(
|
||||
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||
// expected-error @+1 {{expected number of outputs to be same as the number of results}}
|
||||
%0, %1 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
} -> tensor<?x?xf32>, tensor<?x?xf32>
|
||||
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_mixed_tensor_memref(
|
||||
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : memref<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : memref<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_mixed_tensor_memref(
|
||||
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> memref<?x?xf32> {
|
||||
// expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'memref<?x?xf32>'}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
} -> memref<?x?xf32>
|
||||
return %0 : memref<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_mixed_tensor_memref(
|
||||
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : memref<?x?xf32>) {
|
||||
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : memref<?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : memref<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_mixed_tensor_memref(
|
||||
%update : memref<?x?xf32>, %indices : memref<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) {
|
||||
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : memref<?x?xf32>, memref<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_dim_mismatch(
|
||||
%update : tensor<?x?xf32>, %indices : tensor<48x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xf32>, tensor<48x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_dim_mismatch(
|
||||
%update : tensor<64x?xf32>, %indices : tensor<48x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<64x?xf32>, tensor<48x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_dim_mismatch(
|
||||
%update : tensor<?x?x?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{op update value rank exceeds the rank of the original value}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?x?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_dim_mismatch(
|
||||
%update : tensor<?x4xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{mismatch in shape of update value dim#1 and original value at dim#1}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x4xf32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_region_type_mismatch(
|
||||
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
// expected-error @+1 {{expected region to have scalar argument of integer or float types}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi32>) {
|
||||
^bb0(%arg1: index, %arg2: index):
|
||||
%1 = arith.addi %arg1, %arg2 : index
|
||||
%2 = arith.index_cast %1 : index to i32
|
||||
tm_tensor.yield %2 : i32
|
||||
} -> tensor<?x?xi32>
|
||||
return %0 : tensor<?x?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_region_type_mismatch(
|
||||
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
// expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi32>) {
|
||||
^bb0(%arg1: i64, %arg2: i32):
|
||||
%1 = arith.trunci %arg1 : i64 to i32
|
||||
%2 = arith.addi %1, %arg2 : i32
|
||||
tm_tensor.yield %2 : i32
|
||||
} -> tensor<?x?xi32>
|
||||
return %0 : tensor<?x?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_region_type_mismatch(
|
||||
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
// expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi32>) {
|
||||
^bb0(%arg1: i32, %arg2: i64):
|
||||
%1 = arith.trunci %arg2 : i64 to i32
|
||||
%2 = arith.addi %1, %arg1 : i32
|
||||
tm_tensor.yield %2 : i32
|
||||
} -> tensor<?x?xi32>
|
||||
return %0 : tensor<?x?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_region_type_mismatch(
|
||||
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
// expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i32, %arg2: i64):
|
||||
%1 = arith.extsi %arg1 : i32 to i64
|
||||
%2 = arith.addi %1, %arg2 : i64
|
||||
tm_tensor.yield %2 : i64
|
||||
} -> tensor<?x?xi64>
|
||||
return %0 : tensor<?x?xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_region_type_mismatch(
|
||||
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
// expected-error @+1 {{expected region to have two arguments}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i64, %arg2: i64, %arg3 : i64):
|
||||
%1 = arith.addi %arg1, %arg2 : i64
|
||||
tm_tensor.yield %1 : i64
|
||||
} -> tensor<?x?xi64>
|
||||
return %0 : tensor<?x?xi64>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_yield_mismatch(
|
||||
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i64, %arg2: i64):
|
||||
%1 = arith.addi %arg1, %arg2 : i64
|
||||
%2 = arith.trunci %1 : i64 to i32
|
||||
// expected-error @+1 {{mismatch in type of yielded value 'i32' and argument of the region 'i64'}}
|
||||
tm_tensor.yield %2 : i32
|
||||
} -> tensor<?x?xi64>
|
||||
return %0 : tensor<?x?xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_yield_mismatch(
|
||||
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i64, %arg2: i64):
|
||||
%1 = arith.addi %arg1, %arg2 : i64
|
||||
%2 = arith.trunci %1 : i64 to i32
|
||||
// expected-error @+1 {{expected region to yield a single value}}
|
||||
tm_tensor.yield %1, %2 : i64, i32
|
||||
} -> tensor<?x?xi64>
|
||||
return %0 : tensor<?x?xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_index_depth_dynamic(
|
||||
%update : tensor<?x?xi64>, %indices : tensor<?x?xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
// expected-error @+1 {{expected index depth is static}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi64>, tensor<?x?xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i64, %arg2: i64):
|
||||
%1 = arith.addi %arg1, %arg2 : i64
|
||||
%2 = arith.trunci %1 : i64 to i32
|
||||
tm_tensor.yield %1, %2 : i64, i32
|
||||
} -> tensor<?x?xi64>
|
||||
return %0 : tensor<?x?xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_original_rank_mismatch(
|
||||
%update : tensor<?xi64>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
// expected-error @+1 {{op index depth and update value does not cover rank of original value}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
ins(%update, %indices : tensor<?xi64>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i64, %arg2: i64):
|
||||
%1 = arith.addi %arg1, %arg2 : i64
|
||||
%2 = arith.trunci %1 : i64 to i32
|
||||
tm_tensor.yield %1, %2 : i64, i32
|
||||
} -> tensor<?x?xi64>
|
||||
return %0 : tensor<?x?xi64>
|
||||
}
|
Loading…
Reference in New Issue