[torch] Rework lowering to tm_tensor.scatter to stop serialization (#2940)

We collapsed and broadcasted scatter indices to a single element
version. We should instead use `tm_tensor.scatter`'s support for
multiple indices and its implicitly broadcast behavior. This avoids
the serialization and the materialization of a needlessly large indices tensor.
pull/2960/head
Rob Suderman 2024-02-27 11:46:57 -08:00 committed by GitHub
parent d628b5fd06
commit e30a083aff
8 changed files with 380 additions and 441 deletions
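For reference, a sketch of the reworked lowering target, with illustrative shapes that mirror the updated tests further down: a multi-column indices tensor addresses several dimensions of `original` at once through `dimension_map`, so the lowering no longer needs to serialize updates or materialize a flattened single-column index tensor.

%0 = tm_tensor.scatter {dimension_map = array<i64: 0, 1>} unique_indices(false)
  ins(%updates, %indices : tensor<3xi32>, tensor<3x2xi32>)
  outs(%original : tensor<4x3xi32>) {
^bb0(%update: i32, %current: i32):
  // Each row of %indices is a (dim0, dim1) coordinate into %original.
  %sum = arith.addi %update, %current : i32
  tm_tensor.yield %sum : i32
} -> tensor<4x3xi32>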

View File

@ -137,6 +137,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
let arguments = (ins
Variadic<AnyRankedTensorOrMemRefType>:$inputs,
Variadic<AnyRankedTensorOrMemRefType>:$outputs,
DenseI64ArrayAttr:$dimension_map,
DefaultValuedAttr<BoolAttr, "true">:$unique_indices
);
let results = (outs Variadic<AnyRankedTensor>:$results);

View File

@ -200,15 +200,30 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
scatterInputsVector[indexType.getRank()]);
}
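// Build the identity dimension map [0, 1, ..., depth-1], where depth is the
// second dimension (index depth) of the indices tensor.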
static llvm::SmallVector<int64_t> createDefaultDimMap(Value indices) {
llvm::SmallVector<int64_t> dmap;
if (auto iTy = dyn_cast<BaseTensorType>(indices.getType()))
dmap.resize(iTy.getSizes()[1]);
if (auto iTy = dyn_cast<RankedTensorType>(indices.getType()))
dmap.resize(iTy.getDimSize(1));
for (int i = 0, s = dmap.size(); i < s; ++i)
dmap[i] = i;
return dmap;
}
static Value createTMTensorScatterOp(
OpBuilder &b, Location loc, Value updates, Value indices, Value original,
bool uniqueIndices,
llvm::ArrayRef<int64_t> dimensionsMap, bool uniqueIndices,
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap);
auto originalTensorType = original.getType().cast<RankedTensorType>();
Type originalElementType = originalTensorType.getElementType();
auto scatterOp = b.create<TMTensor::ScatterOp>(
loc, originalTensorType, ValueRange{updates, indices},
ValueRange{original}, uniqueIndices);
ValueRange{original}, dimensionsMapAttr, uniqueIndices);
Region &scatterOpRegion = scatterOp.getRegion();
auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
@ -334,7 +349,7 @@ public:
src, dim);
Value scatterOp = createTMTensorScatterOp(
rewriter, loc, updates, indices, self,
/*uniqueIndices=*/false,
/*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value updatesElement,
Value inputElement) {
b.create<TMTensor::YieldOp>(loc, updatesElement);
@ -455,7 +470,7 @@ public:
Value scatterOp = createTMTensorScatterOp(
rewriter, loc, updatesTensor, indices, bincountTensor,
/*uniqueIndices=*/false,
/*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value _, Value bincountElem) {
Value add = b.create<arith::AddIOp>(loc, bincountElem, constantOne);
b.create<TMTensor::YieldOp>(loc, add);
@ -466,235 +481,200 @@ public:
};
} // namespace
// """Create a map from each dimension of the input tensor to the
// subspace that dimension corresponds to in the result shape one gets
// from indexing the tensor with the optional index tensors.
//
// Note: Index tensors are first broadcasted to a common shape before
// creating the mapping. So the index of every index tensor will map to
// the same dimensions in the result shape.
//
// For example:
// indices = [None, None, torch.randint(4, (6, 1)), torch.randint(5, (7,))]
// indexBroadcastShapeValue = [6, 7]
// map = {0: [0], 1: [1], 2: [2, 3], 3: [2, 3]}
static SmallVector<SmallVector<int64_t>>
getInputShapeToOutputShapeMap(SmallVector<Value> optionalIndices,
SmallVector<Value> indexBroadcastShapeValue) {
SmallVector<Value> indices;
for (Value index : optionalIndices) {
if (!index.getType().isa<Torch::NoneType>())
indices.push_back(index);
}
namespace {
unsigned broadcastRank = indexBroadcastShapeValue.size();
unsigned numIndexTensors = indices.size();
int64_t indexOfFirstIndexTensor = -1;
SmallVector<SmallVector<int64_t>> result;
for (unsigned i = 0; i < optionalIndices.size(); i++) {
if (optionalIndices[i].getType().isa<Torch::NoneType>()) {
unsigned val = i;
if (indexOfFirstIndexTensor >= 0)
val += broadcastRank - numIndexTensors;
result.push_back({val});
} else {
if (indexOfFirstIndexTensor < 0)
indexOfFirstIndexTensor = i;
SmallVector<int64_t> outputIndices;
for (unsigned j = indexOfFirstIndexTensor;
j < (indexOfFirstIndexTensor + broadcastRank); j++)
outputIndices.push_back(j);
result.push_back(outputIndices);
}
}
return result;
}
static std::tuple<SmallVector<Value>, SmallVector<int64_t>>
getIndicesFinalShape(ConversionPatternRewriter &rewriter, Location loc,
Value input, SmallVector<Value> optionalIndices,
SmallVector<int64_t> inputShapeInt,
SmallVector<Value> inputShapeValue,
SmallVector<int64_t> indexBroadcastShapeInt,
SmallVector<Value> indexBroadcastShapeValue) {
SmallVector<Value> result;
SmallVector<int64_t> resultInt;
bool handledIndexTensorSpace = false;
for (unsigned i = 0; i < inputShapeValue.size(); i++) {
if (optionalIndices[i].getType().isa<Torch::NoneType>()) {
result.push_back(inputShapeValue[i]);
resultInt.push_back(inputShapeInt[i]);
} else {
if (!handledIndexTensorSpace) {
handledIndexTensorSpace = true;
for (unsigned j = 0; j < indexBroadcastShapeValue.size(); j++) {
result.push_back(indexBroadcastShapeValue[j]);
resultInt.push_back(indexBroadcastShapeInt[j]);
}
}
}
}
return std::make_tuple(result, resultInt);
}
static FailureOr<Value>
getScatterIndices(Aten_IndexPutImplOp op, ConversionPatternRewriter &rewriter,
Type indicesDtype, SmallVector<Value> optionalIndices,
SmallVector<int64_t> indexBroadcastShapeInt,
SmallVector<Value> indexBroadcastShapeValue) {
Location loc = op.getLoc();
MLIRContext *context = op->getContext();
Value input = op.getSelf();
SmallVector<SmallVector<int64_t>> shapeMap =
getInputShapeToOutputShapeMap(optionalIndices, indexBroadcastShapeValue);
SmallVector<int64_t> inputShapeInt{
input.getType().cast<BaseTensorType>().getSizes()};
int64_t inputRank = inputShapeInt.size();
SmallVector<Value> inputShapeValue;
for (unsigned i = 0; i < inputShapeInt.size(); i++) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
inputShapeValue.push_back(
rewriter.createOrFold<AtenSizeIntOp>(loc, input, dim));
}
auto finalShapeResult = getIndicesFinalShape(
rewriter, loc, input, optionalIndices, inputShapeInt, inputShapeValue,
indexBroadcastShapeInt, indexBroadcastShapeValue);
SmallVector<Value> finalShapeValue = std::get<0>(finalShapeResult);
SmallVector<int64_t> finalShapeInt = std::get<1>(finalShapeResult);
Value torchCstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
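// Broadcast the index tensors to a common shape, flatten each into a single
// batch dimension, and concatenate them along dim 1, yielding a
// [batch, depth] indices tensor (depth == number of index tensors).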
Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
OpBuilder b) {
llvm::SmallVector<Value> indices(indicesRef);
// Declare commonly used constants up front:
Value torchCstZero =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(0));
Value torchCstOne =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(1));
Value torchCstNegOne =
b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(-1));
Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
indexBroadcastShapeValue);
// Calculating index count.
int64_t indexCount = 1;
if (llvm::all_of(finalShapeInt,
[](int64_t shape) { return shape != kUnknownSize; })) {
for (int64_t i : finalShapeInt)
indexCount *= i;
} else {
indexCount = kUnknownSize;
// Determine the broadcast sizes and materialize missing implicit end
// dimensions:
int64_t indicesRank = 0;
for (auto index : indices) {
auto indexTy = cast<Torch::ValueTensorType>(index.getType());
int64_t rank = indexTy.getSizes().size();
indicesRank = std::max(rank, indicesRank);
}
Value indexCountValue = finalShapeValue[0];
for (unsigned i = 1; i < finalShapeValue.size(); i++)
indexCountValue =
rewriter.create<AtenMulIntOp>(loc, indexCountValue, finalShapeValue[i]);
auto maxDim = [](int64_t dim0, int64_t dim1) {
if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
return Torch::kUnknownSize;
return std::max(dim0, dim1);
};
ValueTensorType flattenIndicesType =
ValueTensorType::get(context, llvm::ArrayRef(indexCount), indicesDtype);
Value flattenEndDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(finalShapeInt.size() - 1));
llvm::SmallVector<Value> broadcastSizes(indicesRank, torchCstOne);
llvm::SmallVector<int64_t> broadcastShape(indicesRank, 0);
for (auto index : indices) {
auto indexTy = cast<Torch::ValueTensorType>(index.getType());
auto shape = indexTy.getSizes();
int32_t rank = shape.size();
SmallVector<Value> broadcastedIndices;
for (unsigned i = 0; i < optionalIndices.size(); i++) {
Value broadcastedIndexTensor;
if (optionalIndices[i].getType().isa<Torch::NoneType>()) {
Value torchCstDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
Value inputDim = rewriter.create<AtenSizeIntOp>(loc, input, torchCstDim);
ValueTensorType tensorType = ValueTensorType::get(
context, llvm::ArrayRef(inputShapeInt[i]), indicesDtype);
broadcastedIndexTensor = rewriter.create<AtenArangeStartStepOp>(
loc, tensorType, /*start=*/torchCstZero, /*end=*/inputDim,
/*step=*/torchCstOne,
/*dtype=*/torchCstNone,
/*layout=*/torchCstNone,
/*device=*/torchCstNone,
/*pin_memory=*/torchCstNone);
} else {
ValueTensorType tensorType = ValueTensorType::get(
context, llvm::ArrayRef(indexBroadcastShapeInt), indicesDtype);
broadcastedIndexTensor = rewriter.create<AtenBroadcastToOp>(
loc, tensorType, optionalIndices[i], indexBroadcastShapeTorchList);
for (int32_t j = 0; j < rank; ++j) {
Value dim = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(j));
auto sizeOp = b.create<Torch::AtenSizeIntOp>(loc, index, dim);
auto size = shape[j];
int32_t idx = broadcastShape.size() - rank + j;
broadcastSizes[idx] =
b.create<Torch::PrimMaxIntOp>(loc, sizeOp, broadcastSizes[idx]);
broadcastShape[idx] = maxDim(size, broadcastShape[idx]);
}
// spotlight_indices(final_shape, shape_map[i]):
// Turn all values in `final_shape` to `1` except for those with index in
// `indices`.
// for j in range(len(final_shape)):
// if j not in indices:
// final_shape[j] = 1
// This is equivalent to unsqueezing the index tensor at the dimension `j`
// not in indices.
for (unsigned j = 0; j < finalShapeInt.size(); j++) {
if (llvm::find(shapeMap[i], j) == shapeMap[i].end()) {
Value unsqueezeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(j));
auto unsqueezedInfo =
unsqueezeTensor(rewriter, op, broadcastedIndexTensor,
/*dim=*/unsqueezeDim);
if (failed(unsqueezedInfo)) {
return rewriter.notifyMatchFailure(
op, "cannot generate unsqueeze tensor op");
}
broadcastedIndexTensor = *unsqueezedInfo;
}
}
// Performing broadcast to final shape.
Value broadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
finalShapeValue);
ValueTensorType broadcastTensorType = ValueTensorType::get(
context, llvm::ArrayRef(finalShapeInt), indicesDtype);
broadcastedIndexTensor = rewriter.create<AtenBroadcastToOp>(
loc, broadcastTensorType, broadcastedIndexTensor,
broadcastShapeTorchList);
// Flattening the tensor.
broadcastedIndexTensor = rewriter.create<AtenFlattenUsingIntsOp>(
loc, flattenIndicesType, broadcastedIndexTensor, torchCstZero,
flattenEndDim);
broadcastedIndices.push_back(broadcastedIndexTensor);
}
// Stacking broadcasted indices.
Value scatterIndices;
// The operation torch.stack([a, b], dim=0) is decomposed into:
// torch.cat([a.unsqueeze(dim=0), b.unsqueeze(dim=0)], dim=0)
// Unsqueeze all tensors before concatenating.
SmallVector<Value> unsqueezedIndexTensors;
for (Value tensor : broadcastedIndices) {
auto unsqueezedInfo =
unsqueezeTensor(rewriter, op, tensor, /*dim=*/torchCstZero);
if (failed(unsqueezedInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor op");
}
unsqueezedIndexTensors.push_back(*unsqueezedInfo);
auto mulDim = [](int64_t dim0, int64_t dim1) {
if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
return Torch::kUnknownSize;
return dim0 * dim1;
};
int64_t scatterBatchCount = 1;
for (auto dim : broadcastShape) {
scatterBatchCount = mulDim(scatterBatchCount, dim);
}
// Broadcast together and flatten to batch values:
Value broadcastSizeList = b.create<PrimListConstructOp>(
loc, Torch::ListType::get(b.getType<Torch::IntType>()), broadcastSizes);
for (Value &index : indices) {
auto indexTy = cast<Torch::ValueTensorType>(index.getType());
auto expandTy = b.getType<Torch::ValueTensorType>(
broadcastShape, indexTy.getOptionalDtype());
index = b.create<Torch::AtenBroadcastToOp>(loc, expandTy, index,
broadcastSizeList);
auto flattenTy = b.getType<Torch::ValueTensorType>(
scatterBatchCount, indexTy.getOptionalDtype());
index = b.create<Torch::AtenFlattenUsingIntsOp>(
loc, flattenTy, index, torchCstZero, torchCstNegOne);
}
// Unsqueeze so we have a 1 dim to concat along:
for (Value &tensor : indices) {
auto btt = cast<Torch::BaseTensorType>(tensor.getType());
if (!btt.hasSizes())
return nullptr;
llvm::SmallVector<int64_t> shape(btt.getSizes());
shape.push_back(1);
auto unsqueezeTy = b.getType<Torch::ValueTensorType>(shape, btt.getDtype());
Value unsqueezed =
b.create<AtenUnsqueezeOp>(loc, unsqueezeTy, tensor, torchCstOne);
tensor = unsqueezed;
}
BaseTensorType unsqueezedTensorType =
unsqueezedIndexTensors[0].getType().cast<BaseTensorType>();
Value concatIndicesTorchList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(unsqueezedTensorType), unsqueezedIndexTensors);
ValueTensorType concatIndicesType = ValueTensorType::get(
context, llvm::ArrayRef({inputRank, indexCount}), indicesDtype);
scatterIndices = rewriter.create<AtenCatOp>(
loc, concatIndicesType, concatIndicesTorchList, torchCstZero);
ValueTensorType transposedIndicesType = ValueTensorType::get(
context, llvm::ArrayRef({indexCount, inputRank}), indicesDtype);
scatterIndices = rewriter.create<AtenTransposeIntOp>(
loc, transposedIndicesType, scatterIndices, torchCstZero, torchCstOne);
return scatterIndices;
indices[0].getType().cast<BaseTensorType>();
Value indicesTorchList = b.create<PrimListConstructOp>(
loc, Torch::ListType::get(unsqueezedTensorType), indices);
llvm::SmallVector<int64_t, 2> concatShape{
unsqueezedTensorType.getSizes()[0], static_cast<int64_t>(indices.size())};
ValueTensorType concatIndicesType = b.getType<ValueTensorType>(
llvm::ArrayRef(concatShape), unsqueezedTensorType.getDtype());
return b.create<AtenCatOp>(loc, concatIndicesType, indicesTorchList,
torchCstOne);
}
// Helper that collapses the batch dimensions into one and moves the collapsed
// dimension to the front of the array.
static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
int64_t count, OpBuilder b) {
if (batch == 0 && count == 1)
return values;
auto valuesTy = cast<Torch::ValueTensorType>(values.getType());
auto inShape = valuesTy.getSizes();
llvm::SmallVector<int64_t> outShape;
llvm::SmallVector<Value> outDims;
// We need a length-1 dim at the start to transpose the batch to:
if (batch != 0) {
outDims.push_back(b.create<Torch::ConstantIntOp>(loc, 1));
outShape.push_back(1);
}
// Dimensions before the batch stay the same:
for (int i = 0; i <= batch; i++) {
auto k = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(i));
auto dim = b.create<Torch::AtenSizeIntOp>(loc, values, k);
outDims.push_back(dim);
outShape.push_back(inShape[i]);
}
auto mulI = [](int64_t dim0, int64_t dim1) {
if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
return Torch::kUnknownSize;
return dim0 * dim1;
};
// Determine the collapse size of the batch dimension:
for (int i = 1; i < count; i++) {
outShape.back() = mulI(outShape.back(), inShape[batch + i]);
auto k =
b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(batch + i));
auto dim = b.create<Torch::AtenSizeIntOp>(loc, values, k);
outDims.back() = b.create<Torch::AtenMulIntOp>(loc, dim, outDims.back());
}
// Add the dimensions after the batch dims:
for (int i = batch + count, s = inShape.size(); i < s; ++i) {
auto k = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(i));
auto dim = b.create<Torch::AtenSizeIntOp>(loc, values, k);
outDims.push_back(dim);
outShape.push_back(inShape[i]);
}
Value outDimsList = b.create<PrimListConstructOp>(
loc, Torch::ListType::get(b.getType<Torch::IntType>()), outDims);
valuesTy =
b.getType<Torch::ValueTensorType>(outShape, valuesTy.getOptionalDtype());
values = b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
// Batch is already at the front, no need to transpose:
if (batch == 0)
return values;
std::swap(outDims[0], outDims[batch + 1]);
std::swap(outShape[0], outShape[batch + 1]);
Value dim0 = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(0));
Value dimB =
b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(batch + 1));
valuesTy =
b.getType<Torch::ValueTensorType>(outShape, valuesTy.getOptionalDtype());
values =
b.create<Torch::AtenTransposeIntOp>(loc, valuesTy, values, dim0, dimB);
outDims.clear();
outShape.clear();
auto transposeShape = valuesTy.getSizes();
int64_t transposeRank = transposeShape.size();
for (int i = 0; i < transposeRank; ++i) {
if (i == batch + 1)
continue;
Value k = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(i));
outDims.push_back(b.create<AtenSizeIntOp>(loc, values, k));
outShape.push_back(transposeShape[i]);
}
valuesTy =
b.getType<Torch::ValueTensorType>(outShape, valuesTy.getOptionalDtype());
outDimsList = b.create<PrimListConstructOp>(
loc, Torch::ListType::get(b.getType<Torch::IntType>()), outDims);
return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
}
namespace {
class ConvertAten_IndexPutImplOp
: public OpConversionPattern<Aten_IndexPutImplOp> {
public:
@ -706,11 +686,11 @@ public:
return failure();
Location loc = op.getLoc();
MLIRContext *context = op->getContext();
Value input = adaptor.getSelf();
Value values = adaptor.getValues();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType valuesType = values.getType().cast<RankedTensorType>();
int64_t inputRank = inputType.getRank();
Value input = op.getSelf();
Value values = op.getValues();
auto inputType = cast<ValueTensorType>(input.getType());
auto valuesType = cast<ValueTensorType>(values.getType());
int64_t inputRank = inputType.getSizes().size();
auto valuesTensorType = op.getValues().getType().cast<BaseTensorType>();
auto resultType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
@ -737,190 +717,107 @@ public:
op, "Expected accumulate to be constant bool.");
// The element type of the `input` and `values` should be same.
if (inputType.getElementType() != valuesType.getElementType())
if (inputType.getDtype() != valuesType.getDtype())
return rewriter.notifyMatchFailure(
op, "Input element type should be same as the values element type.");
SmallVector<Value> optionalIndicesList;
getListConstructElements(op.getIndices(), optionalIndicesList);
int64_t optionalIndicesCount = optionalIndicesList.size();
// The size of the list of the index tensors should not be greater than the
// input rank.
if ((int64_t)optionalIndicesList.size() > inputRank)
if (optionalIndicesCount > inputRank)
return rewriter.notifyMatchFailure(
op, "Indices list size should not be greater than the input rank.");
Value torchCstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
unsigned sizeOptionalIndicesList = optionalIndicesList.size();
SmallVector<int64_t> nonNoneIndexTensorDim;
unsigned numNonNoneIndices;
if (sizeOptionalIndicesList == 0)
if (optionalIndicesCount == 0)
return rewriter.notifyMatchFailure(op, "Indices list must not be empty.");
for (unsigned i = 0; i < optionalIndicesList.size(); i++) {
if (!optionalIndicesList[i].getType().isa<Torch::NoneType>()) {
nonNoneIndexTensorDim.push_back(i);
}
// Filter to available indices and get the indicesMap:
SmallVector<Value> indicesList;
SmallVector<int64_t> indicesMap;
int64_t numBatchDims = 0;
for (int i = 0, s = optionalIndicesList.size(); i < s; ++i) {
if (isa<Torch::NoneType>(optionalIndicesList[i].getType()))
continue;
indicesList.push_back(optionalIndicesList[i]);
indicesMap.push_back(i);
auto indexTy = cast<ValueTensorType>(indicesList.back().getType());
numBatchDims = std::max(static_cast<int64_t>(indexTy.getSizes().size()),
numBatchDims);
}
numNonNoneIndices = nonNoneIndexTensorDim.size();
if (numNonNoneIndices > 2) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non none index tensors less than or equal to 2 "
"supported only");
} else if (numNonNoneIndices == 2 &&
nonNoneIndexTensorDim[0] != nonNoneIndexTensorDim[1] - 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: case of 2 non none index tensors is supported "
"only when both the tensors are along consecutive dimensions");
}
// Value broadcasting semantics place the batch dimensions up front when the
// index tensors are not consecutive; otherwise the batch stays at the index
// tensors' location:
int64_t batchDim = 0;
for (int s = optionalIndicesList.size(); batchDim < s; ++batchDim)
if (!isa<Torch::NoneType>(optionalIndicesList[batchDim].getType()))
break;
// Padding the indices list with none values.
if (sizeOptionalIndicesList < inputRank) {
for (unsigned i = 0; i < (inputRank - sizeOptionalIndicesList); i++)
optionalIndicesList.push_back(torchCstNone);
}
int64_t nextNone = batchDim;
for (int s = optionalIndicesList.size(); nextNone < s; ++nextNone)
if (isa<Torch::NoneType>(optionalIndicesList[nextNone].getType()))
break;
SmallVector<int64_t> indexBroadcastShapeInt{
optionalIndicesList[nonNoneIndexTensorDim[0]]
.getType()
.cast<BaseTensorType>()
.getSizes()};
SmallVector<Value> indexBroadcastShapeValue;
if (numNonNoneIndices == 2) {
computeBroadcastShape(rewriter, loc,
optionalIndicesList[nonNoneIndexTensorDim[0]],
optionalIndicesList[nonNoneIndexTensorDim[1]],
indexBroadcastShapeInt, indexBroadcastShapeValue);
} else {
// It means there's only one index tensor and broadcast shape is same as
// that index tensor' shape.
for (unsigned i = 0; i < indexBroadcastShapeInt.size(); i++) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
indexBroadcastShapeValue.push_back(rewriter.createOrFold<AtenSizeIntOp>(
loc, optionalIndicesList[nonNoneIndexTensorDim[0]], dim));
}
}
for (int s = optionalIndicesList.size(); nextNone < s; ++nextNone)
if (!isa<Torch::NoneType>(optionalIndicesList[nextNone].getType()))
batchDim = 0;
Type indicesDtype = optionalIndicesList[nonNoneIndexTensorDim[0]]
.getType()
.cast<BaseTensorType>()
.getDtype();
// Indices are extended, catted, and collapsed into a [batch, depth] tensor:
Value indices = combinePutIndices(loc, indicesList, rewriter);
// This implementation is done to get the scatter indices:
// Move batch dimensions to the front and collapse into a single dim:
values =
collapseAndMoveBatchDims(loc, values, batchDim, numBatchDims, rewriter);
valuesType = cast<Torch::ValueTensorType>(values.getType());
// def get_broadcast_shape(tensors):
// return list(torch.broadcast_tensors(*tensors)[0].shape)
// def get_input_shape_to_output_shape_map(optional_index_tensors:
// list[Optional[torch.Tensor]]):
// index_tensors = list(filter(lambda x: x is not None,
// optional_index_tensors)) broadcast_rank =
// len(get_broadcast_shape(index_tensors)) num_of_index_tensors =
// len(index_tensors) index_of_first_index_tensor: Optional[int] = None
// result = {}
// for i, index in enumerate(optional_index_tensors):
// if index is None:
// val = i
// if index_of_first_index_tensor is not None:
// val += broadcast_rank - num_of_index_tensors
// result[i] = [val]
// else:
// if index_of_first_index_tensor is None:
// index_of_first_index_tensor = i
// output_indices = list(range(index_of_first_index_tensor,
// index_of_first_index_tensor +
// broadcast_rank))
// result[i] = output_indices
// return result
// def spotlight_indices(shape, indices: list[int]):
// """Turn all values in `shape` to `1` except for those with index in
// `indices`.""" shape = shape.copy() for i in range(len(shape)):
// if i not in indices:
// shape[i] = 1
// return shape
// def get_final_shape(input, optional_index_tensors:
// list[Optional[torch.Tensor]]):
// index_tensors = list(filter(lambda x: x is not None,
// optional_index_tensors)) index_tensors_broadcast_shape =
// get_broadcast_shape(index_tensors) result = []
// handled_index_tensor_space = False
// for e, i in enumerate(input.shape):
// if optional_index_tensors[e] is None:
// result.append(i)
// else:
// if not handled_index_tensor_space:
// handled_index_tensor_space = True
// result += index_tensors_broadcast_shape
// return result
// def get_scatter_indices(input, optional_index_tensors:
// list[Optional[torch.Tensor]]):
// assert len(input.size()) == len(optional_index_tensors), "Pad indices
// with None" shape_map =
// get_input_shape_to_output_shape_map(optional_index_tensors)
// index_tensors = list(filter(lambda x: x is not None,
// optional_index_tensors)) index_tensors_broadcast_shape =
// get_broadcast_shape(index_tensors) final_shape =
// get_final_shape(input, optional_index_tensors)
// broadcasted_index_tensors = []
// for e, optional_index_tensor in enumerate(optional_index_tensors):
// if optional_index_tensor is None:
// tensor_to_broadcast = torch.arange(0, input.size(e))
// else:
// tensor_to_broadcast =
// optional_index_tensor.broadcast_to(index_tensors_broadcast_shape)
// broadcasted_index_tensor = \
// tensor_to_broadcast.reshape(spotlight_indices(final_shape, shape_map[e]))\
// .broadcast_to(final_shape)\
// .flatten()
// broadcasted_index_tensors.append(broadcasted_index_tensor)
// return torch.stack(broadcasted_index_tensors, dim=0).t()
auto scatterIndicesInfo =
getScatterIndices(op, rewriter, indicesDtype, optionalIndicesList,
indexBroadcastShapeInt, indexBroadcastShapeValue);
if (failed(scatterIndicesInfo)) {
return rewriter.notifyMatchFailure(
op, "cannot generate scatter indices for index put op");
}
Value indexTensor = *scatterIndicesInfo;
// Flattening the values tensor.
Value torchCstZero = rewriter.create<Torch::ConstantIntOp>(
// Materialize out the length-1 dimensions:
Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value flattenedValuesTensorLastDim = rewriter.create<Torch::ConstantIntOp>(
loc,
rewriter.getI64IntegerAttr(valuesTensorType.getSizes().size() - 1));
SmallVector<int64_t> valuesShapeInt{valuesTensorType.getSizes()};
int64_t valuesCount = 1;
if (llvm::all_of(valuesShapeInt,
[](int64_t shape) { return shape != kUnknownSize; })) {
for (int64_t i : valuesShapeInt)
valuesCount *= i;
} else {
valuesCount = kUnknownSize;
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
llvm::SmallVector<int64_t> valuesShape{valuesType.getSizes().front()};
llvm::SmallVector<Value> valuesDims;
valuesDims.push_back(
rewriter.create<Torch::AtenSizeIntOp>(loc, values, zero));
int vDim = 1;
for (int i = 0, s = inputType.getSizes().size(); i < s; ++i) {
if (i < optionalIndicesCount &&
!isa<Torch::NoneType>(optionalIndicesList[i].getType())) {
valuesDims.push_back(one);
valuesShape.push_back(1);
continue;
}
Value k = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(vDim));
valuesDims.push_back(
rewriter.create<Torch::AtenSizeIntOp>(loc, values, k));
valuesShape.push_back(inputType.getSizes()[i]);
vDim++;
}
auto flattenedValuesTensorType = ValueTensorType::get(
context, llvm::ArrayRef(valuesCount), valuesTensorType.getDtype());
Value flattenedValuesTensor = rewriter.create<AtenFlattenUsingIntsOp>(
loc, flattenedValuesTensorType, op.getValues(), torchCstZero,
flattenedValuesTensorLastDim);
values = typeConverter->materializeTargetConversion(
rewriter, loc,
typeConverter->convertType(flattenedValuesTensor.getType()),
flattenedValuesTensor);
Value valuesDimsList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(rewriter.getType<Torch::IntType>()),
valuesDims);
valuesType = rewriter.getType<Torch::ValueTensorType>(
valuesShape, valuesType.getOptionalDtype());
values =
rewriter.create<AtenViewOp>(loc, valuesType, values, valuesDimsList);
// `TMTensor::ScatterOp` expects indices of element type i32.
Value indices = convertTensorToDtype(
rewriter, loc, indexTensor,
indices = convertTensorToDtype(
rewriter, loc, indices,
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
input = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(input.getType()), input);
values = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(values.getType()), values);
indices = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
@ -931,7 +828,8 @@ public:
// 3.) `input` is mapped to `original` in scatter op.
bool invalidInputTypeFound = false;
Value scatterOp = createTMTensorScatterOp(
rewriter, loc, values, indices, input, /*uniqueIndices=*/false,
rewriter, loc, values, indices, input, indicesMap,
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value valuesElement,
Value inputElement) {
Value yieldValue = valuesElement;
@ -1150,6 +1048,7 @@ public:
Value scatterOp = createTMTensorScatterOp(
rewriter, loc, /*updates=*/gradOutputFlattened,
/*indices=*/indicesCollapsed, /*original=*/outputTensor,
/*dimensionsMap=*/createDefaultDimMap(indicesCollapsed),
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value valuesElement,
Value inputElement) {
@ -1292,6 +1191,7 @@ public:
srcType.getElementType(), /*init_element=*/normalizationValue);
self = createTMTensorScatterOp(
rewriter, loc, normalizations, indices, self,
/*dimensionsMap=*/createDefaultDimMap(indices),
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value update, Value current) {
b.create<TMTensor::YieldOp>(loc, update);
@ -1299,6 +1199,7 @@ public:
if (reduceEnum == torch_upstream::ReductionType::MEAN) {
counts = createTMTensorScatterOp(
rewriter, loc, normalizations, indices, counts,
/*dimensionsMap=*/createDefaultDimMap(indices),
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value update, Value current) {
b.create<TMTensor::YieldOp>(loc, update);
@ -1309,7 +1210,7 @@ public:
// Create final operation
Value scatterOp = createTMTensorScatterOp(
rewriter, loc, updates, indices, self,
/*uniqueIndices=*/false,
/*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value update, Value current) {
Value result;
if (reduceEnum == torch_upstream::ReductionType::SUM ||
@ -1353,6 +1254,7 @@ public:
if (reduceEnum == torch_upstream::ReductionType::MEAN) {
counts = createTMTensorScatterOp(
rewriter, loc, updates, indices, counts,
/*dimensionsMap=*/createDefaultDimMap(indices),
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value update, Value current) {
Value result;

View File

@ -509,12 +509,32 @@ LogicalResult ScanOp::fold(FoldAdaptor adaptor,
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
static Type getComplexElementTypeOrSelf(Type ty) {
if (auto complex = dyn_cast_or_null<ComplexType>(ty))
return complex.getElementType();
return ty;
}
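// A dimension map is invalid if it has more entries than the rank of the
// original tensor, contains duplicate entries, or references an out-of-range
// dimension.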
static bool isInvalid(ArrayRef<int64_t> dimsPos, int64_t rank) {
// early exit.
if (static_cast<int64_t>(dimsPos.size()) > rank)
return true;
DenseSet<int64_t> uniqued;
for (int64_t dim : dimsPos)
uniqued.insert(dim);
if (static_cast<int64_t>(dimsPos.size()) != uniqued.size())
return true;
return llvm::any_of(
dimsPos, [rank](int64_t dimPos) { return dimPos < 0 || dimPos >= rank; });
}
LogicalResult ScatterOp::verify() {
Operation *op = getOperation();
if (getInputs().size() != 2) {
return emitOpError("expected two input operands");
return op->emitOpError("expected two input operands");
}
if (getOutputs().size() != 1) {
return emitOpError("expected one output operand");
return op->emitOpError("expected one output operand");
}
auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
return t1.getShape()[dim] == t2.getShape()[dim];
@ -526,10 +546,19 @@ LogicalResult ScatterOp::verify() {
return emitOpError("expected indices to be of rank 2 of i32 element type");
}
auto indexDepth = getIndexDepth();
if (indexDepth == ShapedType::kDynamic) {
if (ShapedType::isDynamic(indexDepth)) {
return emitOpError("expected index depth is static");
}
ArrayRef<int64_t> dimMap = getDimensionMap();
if (static_cast<int64_t>(dimMap.size()) != indexDepth) {
return op->emitOpError("invalid number of dimension map entries ");
}
auto originalType = getOriginalType();
if (isInvalid(dimMap, originalType.getRank()))
return op->emitOpError("dimension map is invalid");
// The first dimension of the indices should match the first dimension of the
// output. They indicate the number of updates.
auto updateType = getUpdateType();
@ -540,7 +569,6 @@ LogicalResult ScatterOp::verify() {
return emitOpError(
"mismatch in shape of indices and update value at dim#0");
}
auto originalType = getOriginalType();
if (updateType.getRank() - 1 > originalType.getRank()) {
return emitOpError(
"update value rank exceeds the rank of the original value");
@ -553,7 +581,7 @@ LogicalResult ScatterOp::verify() {
"index depth and update value does not cover rank of original value");
}
// Validate the non-indexed update dims covier the full slice size of the
// Validate the non-indexed update dims cover the full slice size of the
// original tensor.
int64_t fullSliceDims = originalType.getRank() - indexDepth;
for (auto it :
@ -562,10 +590,11 @@ LogicalResult ScatterOp::verify() {
updateType.getRank()))) {
int64_t originalDim = std::get<0>(it);
int64_t updateDim = std::get<1>(it);
if (updateType.getDimSize(updateDim) !=
originalType.getDimSize(originalDim)) {
return emitOpError("mismatch in shape of update value dim#")
<< updateDim << " and original value at dim#" << originalDim;
if (!originalType.isDynamicDim(originalDim) &&
updateType.getDimSize(updateDim) >
originalType.getDimSize(originalDim)) {
return op->emitOpError("shape of update value dim#")
<< updateDim << " exceeds original value at dim#" << originalDim;
}
}
@ -576,23 +605,25 @@ LogicalResult ScatterOp::verify() {
llvm::seq<unsigned>(1, updateType.getRank() - fullSliceDims))) {
int64_t originalDim = std::get<0>(it);
int64_t updateDim = std::get<1>(it);
if (updateType.getDimSize(updateDim) >
originalType.getDimSize(originalDim)) {
return emitOpError("indexed shape of update value dim#")
if (!originalType.isDynamicDim(originalDim) &&
updateType.getDimSize(updateDim) >
originalType.getDimSize(originalDim)) {
return op->emitOpError("indexed shape of update value dim#")
<< updateDim << " exceeds original value at dim#" << originalDim
<< " " << updateType.getDimSize(updateDim) << " "
<< originalType.getDimSize(originalDim);
}
}
Region &thisRegion = getRegion();
Block *body = &thisRegion.front();
Region &region = this->getRegion();
Block *body = &region.front();
if (body->getNumArguments() != 2) {
return emitOpError("expected region to have two arguments");
return op->emitOpError("expected region to have two arguments");
}
Type arg0Type = body->getArgument(0).getType();
Type arg1Type = body->getArgument(1).getType();
if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() ||
!getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) {
return emitOpError(
"expected region to have scalar argument of integer or float types");
}
@ -684,14 +715,16 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
starts[it.index() + offset] = it.value();
}
ArrayRef<int64_t> dimMap = getDimensionMap();
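// Route each loaded index value to the original-tensor dimension named by the
// dimension map instead of assuming an identity mapping.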
for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
Value cast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
Value ret = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
if (starts[i])
cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
starts[i] = cast;
auto dim = dimMap[i];
if (starts[dim])
ret = b.create<arith::AddIOp>(loc, ret, starts[dim]);
starts[dim] = ret;
}
Value init = b.create<memref::LoadOp>(loc, original(), starts);

View File

@ -75,6 +75,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
// (e.g. dimensions which must be constant in a ranked programming model)
// and those constants get somewhat obscured by TorchToArith.
pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());

View File

@ -1166,3 +1166,4 @@ class IndexPutImplIndexWithNoneModule(torch.nn.Module):
module_factory=lambda: IndexPutImplIndexWithNoneModule())
def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils):
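# The (6, 1) and (7,) index tensors broadcast to a (6, 7) index space over
# dims 2 and 3 of the (2, 3, 4, 5) input, so values has shape (2, 3, 6, 7).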
module.forward(tu.rand(2, 3, 4, 5), tu.randint(6, 1, high=4), tu.randint(7, high=5), tu.rand(2, 3, 6, 7))

View File

@ -64,7 +64,7 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
// CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
// CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32
@ -74,7 +74,7 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
func.func @scatter_update_scalar_1D(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
%updates: tensor<3xi32>) -> tensor<8xi32> {
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true)
ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
outs(%original : tensor<8xi32>) {
^bb0(%update: i32, %orig: i32): // no predecessors
@ -92,7 +92,7 @@ func.func @scatter_update_scalar_1D(
// CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
// CHECK: %[[CST1:.*]] = arith.constant 1 : i32
@ -104,7 +104,7 @@ func.func @scatter_update_scalar_1D(
func.func @scatter_add_scalar_1D(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
%updates: tensor<3xi32>) -> tensor<8xi32> {
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true)
ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
outs(%original : tensor<8xi32>) {
^bb0(%update: i32, %orig: i32): // no predecessors

View File

@ -105,7 +105,7 @@ func.func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
func.func @scatter_update_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) {
tm_tensor.scatter unique_indices(true)
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
outs(%original : memref<8xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@ -131,7 +131,7 @@ func.func @scatter_update_scalar_1D(
func.func @scatter_add_scalar_2D(
%original: memref<4x3xi32>, %indices: memref<3x2xi32>,
%updates: memref<3xi32>) {
tm_tensor.scatter unique_indices(true)
tm_tensor.scatter {dimension_map= array<i64: 0, 1>} unique_indices(true)
ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@ -162,7 +162,7 @@ func.func @scatter_add_scalar_2D(
func.func @scatter_update_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) {
tm_tensor.scatter unique_indices(true)
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@ -192,7 +192,7 @@ func.func @scatter_update_slice_2D(
func.func @scatter_add_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) {
tm_tensor.scatter unique_indices(true)
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
outs(%original : memref<8xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@ -221,7 +221,7 @@ func.func @scatter_add_scalar_1D(
func.func @scatter_add_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) {
tm_tensor.scatter unique_indices(true)
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@ -251,7 +251,7 @@ func.func @scatter_add_slice_2D(
func.func @scatter_update_scalar_dynamic_1D(
%original: memref<?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?xi32>) {
tm_tensor.scatter unique_indices(true)
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%updates, %indices : memref<?xi32>, memref<?x1xi32>)
outs(%original : memref<?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@ -277,7 +277,7 @@ func.func @scatter_update_scalar_dynamic_1D(
func.func @scatter_add_scalar_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x2xi32>,
%updates: memref<?xi32>) {
tm_tensor.scatter unique_indices(true)
tm_tensor.scatter {dimension_map= array<i64: 0, 1>} unique_indices(true)
ins(%updates, %indices : memref<?xi32>, memref<?x2xi32>)
outs(%original : memref<?x?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@ -308,7 +308,7 @@ func.func @scatter_add_scalar_dynamic_2D(
func.func @scatter_update_slice_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?x?xi32>) {
tm_tensor.scatter unique_indices(true)
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%updates, %indices : memref<?x?xi32>, memref<?x1xi32>)
outs(%original : memref<?x?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
@ -335,6 +335,7 @@ func.func @scatter_update_slice_dynamic_2D(
func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) {
tm_tensor.scatter
{dimension_map= array<i64: 0, 1, 2>}
unique_indices(true)
ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>)
outs(%arg0 : memref<2x64x12xf32>) {

View File

@ -4,7 +4,7 @@ func.func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0, 1, 2>} unique_indices(true)
ins(%update, %indices : memref<?x?xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
@ -20,7 +20,7 @@ func.func @scatter_mixed_tensor_memref(
%update : tensor<?x?xf32>, %indices : memref<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0, 1, 2>} unique_indices(true)
ins(%update, %indices : tensor<?x?xf32>, memref<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
@ -36,7 +36,7 @@ func.func @scatter_extra_outputs(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
// expected-error @+1 {{expected number of outputs to be same as the number of results}}
%0, %1 = tm_tensor.scatter unique_indices(true)
%0, %1 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
@ -52,7 +52,7 @@ func.func @scatter_mixed_tensor_memref(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : memref<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
outs(%original : memref<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
@ -68,7 +68,7 @@ func.func @scatter_output_type_mismatch(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<4x?xf32> {
// expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'tensor<4x?xf32>'}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
@ -84,7 +84,7 @@ func.func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
%original : memref<?x?xf32>) {
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
tm_tensor.scatter unique_indices(true)
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : memref<?x?xf32>, tensor<?x1xi32>)
outs(%original : memref<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
@ -100,7 +100,7 @@ func.func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : memref<?x1xi32>,
%original : tensor<?x?xf32>) {
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
tm_tensor.scatter unique_indices(true)
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : memref<?x?xf32>, memref<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
@ -116,7 +116,7 @@ func.func @scatter_dim_mismatch(
%update : tensor<?x?xf32>, %indices : tensor<48x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xf32>, tensor<48x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
@ -132,7 +132,7 @@ func.func @scatter_dim_mismatch(
%update : tensor<64x?xf32>, %indices : tensor<48x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<64x?xf32>, tensor<48x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
@ -148,7 +148,7 @@ func.func @scatter_dim_mismatch(
%update : tensor<?x?x?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{op update value rank exceeds the rank of the original value}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?x?x?xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
@ -162,16 +162,16 @@ func.func @scatter_dim_mismatch(
func.func @scatter_dim_mismatch(
%update : tensor<?x4xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of update value dim#1 and original value at dim#1}}
%0 = tm_tensor.scatter unique_indices(true)
%original : tensor<?x3xf32>) -> tensor<?x3xf32> {
// expected-error @+1 {{shape of update value dim#1 exceeds original value at dim#1}}
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x4xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
outs(%original : tensor<?x3xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
} -> tensor<?x3xf32>
return %0 : tensor<?x3xf32>
}
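// -----

// Illustrative sketch only (an assumed case, not a test added by this change):
// with ?x1 indices the index depth is 1, so a two-entry dimension_map now
// trips the new verifier check on the number of dimension map entries.
func.func @scatter_dim_map_count_mismatch(
%update : tensor<?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
%0 = tm_tensor.scatter {dimension_map= array<i64: 0, 1>} unique_indices(true)
ins(%update, %indices : tensor<?xi64>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64):
tm_tensor.yield %arg1 : i64
} -> tensor<?x?xi64>
return %0 : tensor<?x?xi64>
}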
// -----
@ -180,7 +180,7 @@ func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{expected region to have scalar argument of integer or float types}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi32>) {
^bb0(%arg1: index, %arg2: index):
@ -197,7 +197,7 @@ func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi32>) {
^bb0(%arg1: i64, %arg2: i32):
@ -214,7 +214,7 @@ func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi32>) {
^bb0(%arg1: i32, %arg2: i64):
@ -231,7 +231,7 @@ func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i32, %arg2: i64):
@ -248,7 +248,7 @@ func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{expected region to have two arguments}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3 : i64):
@ -264,7 +264,7 @@ func.func @scatter_region_type_mismatch(
func.func @scatter_yield_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64):
@ -281,7 +281,7 @@ func.func @scatter_yield_mismatch(
func.func @scatter_yield_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64):
@ -299,7 +299,7 @@ func.func @scatter_index_depth_dynamic(
%update : tensor<?x?xi64>, %indices : tensor<?x?xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{expected index depth is static}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?x?xi64>, tensor<?x?xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64):
@ -316,7 +316,7 @@ func.func @scatter_original_rank_mismatch(
%update : tensor<?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{op index depth and update value does not cover rank of original value}}
%0 = tm_tensor.scatter unique_indices(true)
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
ins(%update, %indices : tensor<?xi64>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64):