From e30a083affb65c301066eda3df7112c06f4291da Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 27 Feb 2024 11:46:57 -0800 Subject: [PATCH] [torch] Rework lowering to tm_tensor.scatter to stop serialization (#2940) We collapsed and broadcasted scatter indices to a single-element version. We should instead use `tm_tensor.scatter`'s support for multiple indices and its implicitly broadcasted behavior. This avoids the serialization and the materialization of a needlessly large indices tensor. --- .../Dialect/TMTensor/IR/TMTensorOps.td | 1 + .../TorchToTMTensor/TorchToTMTensor.cpp | 670 ++++++++---------- lib/Dialect/TMTensor/IR/TMTensorOps.cpp | 73 +- .../TorchConversion/Transforms/Passes.cpp | 1 + .../torch_mlir_e2e_test/test_suite/scatter.py | 1 + test/Dialect/TMTensor/bufferize.mlir | 8 +- test/Dialect/TMTensor/convert_to_loops.mlir | 17 +- test/Dialect/TMTensor/invalid.mlir | 50 +- 8 files changed, 380 insertions(+), 441 deletions(-) diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td index 50dc0c1a0..12a74faa4 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td @@ -137,6 +137,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", let arguments = (ins Variadic:$inputs, Variadic:$outputs, + DenseI64ArrayAttr:$dimension_map, DefaultValuedAttr:$unique_indices ); let results = (outs Variadic:$results); diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 4aa82420c..ac6d731bf 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -200,15 +200,30 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter, scatterInputsVector[indexType.getRank()]); } +static llvm::SmallVector createDefaultDimMap(Value indices) { + llvm::SmallVector dmap; + if (auto iTy = dyn_cast(indices.getType())) + dmap.resize(iTy.getSizes()[1]); + + if (auto iTy = dyn_cast(indices.getType())) + dmap.resize(iTy.getDimSize(1)); + + for (int i = 0, s = dmap.size(); i < s; ++i) + dmap[i] = i; + + return dmap; +} + static Value createTMTensorScatterOp( OpBuilder &b, Location loc, Value updates, Value indices, Value original, - bool uniqueIndices, + llvm::ArrayRef dimensionsMap, bool uniqueIndices, function_ref bodyBuild) { + auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap); auto originalTensorType = original.getType().cast(); Type originalElementType = originalTensorType.getElementType(); auto scatterOp = b.create( loc, originalTensorType, ValueRange{updates, indices}, - ValueRange{original}, uniqueIndices); + ValueRange{original}, dimensionsMapAttr, uniqueIndices); Region &scatterOpRegion = scatterOp.getRegion(); auto &scatterOpBlock = scatterOpRegion.emplaceBlock(); @@ -334,7 +349,7 @@ public: src, dim); Value scatterOp = createTMTensorScatterOp( rewriter, loc, updates, indices, self, - /*uniqueIndices=*/false, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value updatesElement, Value inputElement) { b.create(loc, updatesElement); @@ -455,7 +470,7 @@ public: Value scatterOp = createTMTensorScatterOp( rewriter, loc, updatesTensor, indices, bincountTensor, - /*uniqueIndices=*/false, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value _, Value bincountElem) { Value
add = b.create(loc, bincountElem, constantOne); b.create(loc, add); @@ -466,235 +481,200 @@ public: }; } // namespace -// """Create a map from each dimension of the input tensor to the -// subspace that dimension corresponds to in the result shape one gets -// from indexing the tensor with the optional index tensors. -// -// Note: Index tensors are first broadcasted to a common shape before -// creating the mapping. So the index of every index tensor will map to -// the same dimensions in the result shape. -// -// For example: -// indices = [None, None, torch.randint(4, (6, 1)), torch.randint(5, (7,))] -// indexBroadcastShapeValue = [6, 7] -// map = {0: [0], 1: [1], 2: [2, 3], 3: [2, 3]} -static SmallVector> -getInputShapeToOutputShapeMap(SmallVector optionalIndices, - SmallVector indexBroadcastShapeValue) { - SmallVector indices; - for (Value index : optionalIndices) { - if (!index.getType().isa()) - indices.push_back(index); - } +namespace { - unsigned broadcastRank = indexBroadcastShapeValue.size(); - unsigned numIndexTensors = indices.size(); - int64_t indexOfFirstIndexTensor = -1; - SmallVector> result; - - for (unsigned i = 0; i < optionalIndices.size(); i++) { - if (optionalIndices[i].getType().isa()) { - unsigned val = i; - if (indexOfFirstIndexTensor >= 0) - val += broadcastRank - numIndexTensors; - result.push_back({val}); - } else { - if (indexOfFirstIndexTensor < 0) - indexOfFirstIndexTensor = i; - SmallVector outputIndices; - for (unsigned j = indexOfFirstIndexTensor; - j < (indexOfFirstIndexTensor + broadcastRank); j++) - outputIndices.push_back(j); - result.push_back(outputIndices); - } - } - return result; -} - -static std::tuple, SmallVector> -getIndicesFinalShape(ConversionPatternRewriter &rewriter, Location loc, - Value input, SmallVector optionalIndices, - SmallVector inputShapeInt, - SmallVector inputShapeValue, - SmallVector indexBroadcastShapeInt, - SmallVector indexBroadcastShapeValue) { - SmallVector result; - SmallVector resultInt; - bool handledIndexTensorSpace = false; - - for (unsigned i = 0; i < inputShapeValue.size(); i++) { - if (optionalIndices[i].getType().isa()) { - result.push_back(inputShapeValue[i]); - resultInt.push_back(inputShapeInt[i]); - } else { - if (!handledIndexTensorSpace) { - handledIndexTensorSpace = true; - for (unsigned j = 0; j < indexBroadcastShapeValue.size(); j++) { - result.push_back(indexBroadcastShapeValue[j]); - resultInt.push_back(indexBroadcastShapeInt[j]); - } - } - } - } - return std::make_tuple(result, resultInt); -} - -static FailureOr -getScatterIndices(Aten_IndexPutImplOp op, ConversionPatternRewriter &rewriter, - Type indicesDtype, SmallVector optionalIndices, - SmallVector indexBroadcastShapeInt, - SmallVector indexBroadcastShapeValue) { - Location loc = op.getLoc(); - MLIRContext *context = op->getContext(); - Value input = op.getSelf(); - - SmallVector> shapeMap = - getInputShapeToOutputShapeMap(optionalIndices, indexBroadcastShapeValue); - - SmallVector inputShapeInt{ - input.getType().cast().getSizes()}; - int64_t inputRank = inputShapeInt.size(); - SmallVector inputShapeValue; - for (unsigned i = 0; i < inputShapeInt.size(); i++) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - inputShapeValue.push_back( - rewriter.createOrFold(loc, input, dim)); - } - - auto finalShapeResult = getIndicesFinalShape( - rewriter, loc, input, optionalIndices, inputShapeInt, inputShapeValue, - indexBroadcastShapeInt, indexBroadcastShapeValue); - SmallVector finalShapeValue = std::get<0>(finalShapeResult); - 
SmallVector finalShapeInt = std::get<1>(finalShapeResult); - - Value torchCstNone = rewriter.create(loc); +Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, + OpBuilder b) { + llvm::SmallVector indices(indicesRef); + // Declare commonly used constants up front: Value torchCstZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + b.create(loc, b.getI64IntegerAttr(0)); Value torchCstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + b.create(loc, b.getI64IntegerAttr(1)); + Value torchCstNegOne = + b.create(loc, b.getI64IntegerAttr(-1)); - Value indexBroadcastShapeTorchList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), - indexBroadcastShapeValue); - - // Calculating index count. - int64_t indexCount = 1; - if (llvm::all_of(finalShapeInt, - [](int64_t shape) { return shape != kUnknownSize; })) { - for (int64_t i : finalShapeInt) - indexCount *= i; - } else { - indexCount = kUnknownSize; + // Determine the broadcast sizes and materialize missing implicit end + // dimensions: + int64_t indicesRank = 0; + for (auto index : indices) { + auto indexTy = cast(index.getType()); + int64_t rank = indexTy.getSizes().size(); + indicesRank = std::max(rank, indicesRank); } - Value indexCountValue = finalShapeValue[0]; - for (unsigned i = 1; i < finalShapeValue.size(); i++) - indexCountValue = - rewriter.create(loc, indexCountValue, finalShapeValue[i]); + auto maxDim = [](int64_t dim0, int64_t dim1) { + if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize) + return Torch::kUnknownSize; + return std::max(dim0, dim1); + }; - ValueTensorType flattenIndicesType = - ValueTensorType::get(context, llvm::ArrayRef(indexCount), indicesDtype); - Value flattenEndDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(finalShapeInt.size() - 1)); + llvm::SmallVector broadcastSizes(indicesRank, torchCstOne); + llvm::SmallVector broadcastShape(indicesRank, 0); + for (auto index : indices) { + auto indexTy = cast(index.getType()); + auto shape = indexTy.getSizes(); + int32_t rank = shape.size(); - SmallVector broadcastedIndices; - for (unsigned i = 0; i < optionalIndices.size(); i++) { - Value broadcastedIndexTensor; - if (optionalIndices[i].getType().isa()) { - Value torchCstDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - Value inputDim = rewriter.create(loc, input, torchCstDim); - ValueTensorType tensorType = ValueTensorType::get( - context, llvm::ArrayRef(inputShapeInt[i]), indicesDtype); - broadcastedIndexTensor = rewriter.create( - loc, tensorType, /*start=*/torchCstZero, /*end=*/inputDim, - /*step=*/torchCstOne, - /*dtype=*/torchCstNone, - /*layout=*/torchCstNone, - /*device=*/torchCstNone, - /*pin_memory=*/torchCstNone); - } else { - ValueTensorType tensorType = ValueTensorType::get( - context, llvm::ArrayRef(indexBroadcastShapeInt), indicesDtype); - broadcastedIndexTensor = rewriter.create( - loc, tensorType, optionalIndices[i], indexBroadcastShapeTorchList); + for (int32_t j = 0; j < rank; ++j) { + Value dim = b.create(loc, b.getI64IntegerAttr(j)); + auto sizeOp = b.create(loc, index, dim); + auto size = shape[j]; + + int32_t idx = broadcastShape.size() - rank + j; + broadcastSizes[idx] = + b.create(loc, sizeOp, broadcastSizes[idx]); + broadcastShape[idx] = maxDim(size, broadcastShape[idx]); } - - // spotlight_indices(final_shape, shape_map[i]): - // Turn all values in `final_shape` to `1` except for those with index in - // `indices`. 
- // for j in range(len(final_shape)): - // if j not in indices: - // final_shape[j] = 1 - // This is equivalent to unsqueezing the index tensor at the dimension `j` - // not in indices. - for (unsigned j = 0; j < finalShapeInt.size(); j++) { - if (llvm::find(shapeMap[i], j) == shapeMap[i].end()) { - Value unsqueezeDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(j)); - auto unsqueezedInfo = - unsqueezeTensor(rewriter, op, broadcastedIndexTensor, - /*dim=*/unsqueezeDim); - if (failed(unsqueezedInfo)) { - return rewriter.notifyMatchFailure( - op, "cannot generate unsqueeze tensor op"); - } - broadcastedIndexTensor = *unsqueezedInfo; - } - } - - // Performing broadcast to final shape. - Value broadcastShapeTorchList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), - finalShapeValue); - ValueTensorType broadcastTensorType = ValueTensorType::get( - context, llvm::ArrayRef(finalShapeInt), indicesDtype); - broadcastedIndexTensor = rewriter.create( - loc, broadcastTensorType, broadcastedIndexTensor, - broadcastShapeTorchList); - - // Flattening the tensor. - broadcastedIndexTensor = rewriter.create( - loc, flattenIndicesType, broadcastedIndexTensor, torchCstZero, - flattenEndDim); - - broadcastedIndices.push_back(broadcastedIndexTensor); } - // Stacking broadcasted indices. - Value scatterIndices; - // The operation torch.stack([a, b], dim=0) is decomposed into: - // torch.cat([a.unsqueeze(dim=0), b.unsqueeze(dim=0)], dim=0) - // Unsqueeze all tensors before concatenating. - SmallVector unsqueezedIndexTensors; - for (Value tensor : broadcastedIndices) { - auto unsqueezedInfo = - unsqueezeTensor(rewriter, op, tensor, /*dim=*/torchCstZero); - if (failed(unsqueezedInfo)) { - return rewriter.notifyMatchFailure(op, - "cannot generate unsqueeze tensor op"); - } - unsqueezedIndexTensors.push_back(*unsqueezedInfo); + auto mulDim = [](int64_t dim0, int64_t dim1) { + if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize) + return Torch::kUnknownSize; + return dim0 * dim1; + }; + + int64_t scatterBatchCount = 1; + for (auto dim : broadcastShape) { + scatterBatchCount = mulDim(scatterBatchCount, dim); + } + + // Broadcast together and flatten to batch values: + Value broadcastSizeList = b.create( + loc, Torch::ListType::get(b.getType()), broadcastSizes); + for (Value &index : indices) { + auto indexTy = cast(index.getType()); + auto expandTy = b.getType( + broadcastShape, indexTy.getOptionalDtype()); + index = b.create(loc, expandTy, index, + broadcastSizeList); + + auto flattenTy = b.getType( + scatterBatchCount, indexTy.getOptionalDtype()); + index = b.create( + loc, flattenTy, index, torchCstZero, torchCstNegOne); + } + + // Unsqueeze so we have a 1 dim to concat along: + for (Value &tensor : indices) { + auto btt = cast(tensor.getType()); + if (!btt.hasSizes()) + return nullptr; + + llvm::SmallVector shape(btt.getSizes()); + shape.push_back(1); + + auto unsqueezeTy = b.getType(shape, btt.getDtype()); + Value unsqueezed = + b.create(loc, unsqueezeTy, tensor, torchCstOne); + tensor = unsqueezed; } BaseTensorType unsqueezedTensorType = - unsqueezedIndexTensors[0].getType().cast(); - Value concatIndicesTorchList = rewriter.create( - loc, Torch::ListType::get(unsqueezedTensorType), unsqueezedIndexTensors); - ValueTensorType concatIndicesType = ValueTensorType::get( - context, llvm::ArrayRef({inputRank, indexCount}), indicesDtype); - scatterIndices = rewriter.create( - loc, concatIndicesType, concatIndicesTorchList, torchCstZero); - - ValueTensorType 
transposedIndicesType = ValueTensorType::get( - context, llvm::ArrayRef({indexCount, inputRank}), indicesDtype); - scatterIndices = rewriter.create( - loc, transposedIndicesType, scatterIndices, torchCstZero, torchCstOne); - return scatterIndices; + indices[0].getType().cast(); + Value indicesTorchList = b.create( + loc, Torch::ListType::get(unsqueezedTensorType), indices); + llvm::SmallVector concatShape{ + unsqueezedTensorType.getSizes()[0], static_cast(indices.size())}; + ValueTensorType concatIndicesType = b.getType( + llvm::ArrayRef(concatShape), unsqueezedTensorType.getDtype()); + return b.create(loc, concatIndicesType, indicesTorchList, + torchCstOne); +} + +// Helper that collapses the batch dimensions together and moves it to the front +// of the array. +static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch, + int64_t count, OpBuilder b) { + if (batch == 0 && count == 1) + return values; + + auto valuesTy = cast(values.getType()); + auto inShape = valuesTy.getSizes(); + + llvm::SmallVector outShape; + llvm::SmallVector outDims; + + // We need a length-1 dim at the start to transpose the batch to: + if (batch != 0) { + outDims.push_back(b.create(loc, 1)); + outShape.push_back(1); + } + + // Dimensions before the batch stay the same: + for (int i = 0; i <= batch; i++) { + auto k = b.create(loc, b.getI64IntegerAttr(i)); + auto dim = b.create(loc, values, k); + outDims.push_back(dim); + outShape.push_back(inShape[i]); + } + + auto mulI = [](int64_t dim0, int64_t dim1) { + if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize) + return Torch::kUnknownSize; + return dim0 * dim1; + }; + + // Determine the collapse size of the batch dimension: + for (int i = 1; i < count; i++) { + outShape.back() = mulI(outShape.back(), inShape[batch + i]); + + auto k = + b.create(loc, b.getI64IntegerAttr(batch + i)); + auto dim = b.create(loc, values, k); + outDims.back() = b.create(loc, dim, outDims.back()); + } + + // Add the dimensions after the batch dims: + for (int i = batch + count, s = inShape.size(); i < s; ++i) { + auto k = b.create(loc, b.getI64IntegerAttr(i)); + auto dim = b.create(loc, values, k); + outDims.push_back(dim); + outShape.push_back(inShape[i]); + } + + Value outDimsList = b.create( + loc, Torch::ListType::get(b.getType()), outDims); + + valuesTy = + b.getType(outShape, valuesTy.getOptionalDtype()); + values = b.create(loc, valuesTy, values, outDimsList); + + if (batch == 0) + return values; + + // Batch is already at the front, no need to transpose: + std::swap(outDims[0], outDims[batch + 1]); + std::swap(outShape[0], outShape[batch + 1]); + + Value dim0 = b.create(loc, b.getI64IntegerAttr(0)); + Value dimB = + b.create(loc, b.getI64IntegerAttr(batch + 1)); + + valuesTy = + b.getType(outShape, valuesTy.getOptionalDtype()); + values = + b.create(loc, valuesTy, values, dim0, dimB); + + outDims.clear(); + outShape.clear(); + auto transposeShape = valuesTy.getSizes(); + int64_t transposeRank = transposeShape.size(); + for (int i = 0; i < transposeRank; ++i) { + if (i == batch + 1) + continue; + Value k = b.create(loc, b.getI64IntegerAttr(i)); + outDims.push_back(b.create(loc, values, k)); + outShape.push_back(transposeShape[i]); + } + + valuesTy = + b.getType(outShape, valuesTy.getOptionalDtype()); + outDimsList = b.create( + loc, Torch::ListType::get(b.getType()), outDims); + return b.create(loc, valuesTy, values, outDimsList); } -namespace { class ConvertAten_IndexPutImplOp : public OpConversionPattern { public: @@ -706,11 +686,11 @@ public: 
return failure(); Location loc = op.getLoc(); MLIRContext *context = op->getContext(); - Value input = adaptor.getSelf(); - Value values = adaptor.getValues(); - RankedTensorType inputType = input.getType().cast(); - RankedTensorType valuesType = values.getType().cast(); - int64_t inputRank = inputType.getRank(); + Value input = op.getSelf(); + Value values = op.getValues(); + auto inputType = cast(input.getType()); + auto valuesType = cast(values.getType()); + int64_t inputRank = inputType.getSizes().size(); auto valuesTensorType = op.getValues().getType().cast(); auto resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); @@ -737,190 +717,107 @@ public: op, "Expected accumulate to be constant bool."); // The element type of the `input` and `values` should be same. - if (inputType.getElementType() != valuesType.getElementType()) + if (inputType.getDtype() != valuesType.getDtype()) return rewriter.notifyMatchFailure( op, "Input element type should be same as the values element type."); SmallVector optionalIndicesList; getListConstructElements(op.getIndices(), optionalIndicesList); + int64_t optionalIndicesCount = optionalIndicesList.size(); // The size of the list of the index tensors should not be greater than the // input rank. - if ((int64_t)optionalIndicesList.size() > inputRank) + if (optionalIndicesCount > inputRank) return rewriter.notifyMatchFailure( op, "Indices list size should not be greater than the input rank."); - Value torchCstNone = rewriter.create(loc); - unsigned sizeOptionalIndicesList = optionalIndicesList.size(); - SmallVector nonNoneIndexTensorDim; - unsigned numNonNoneIndices; - - if (sizeOptionalIndicesList == 0) + if (optionalIndicesCount == 0) return rewriter.notifyMatchFailure(op, "Indices list must not be empty."); - for (unsigned i = 0; i < optionalIndicesList.size(); i++) { - if (!optionalIndicesList[i].getType().isa()) { - nonNoneIndexTensorDim.push_back(i); - } + // Filter to available indices and get the indicesMap: + SmallVector indicesList; + SmallVector indicesMap; + int64_t numBatchDims = 0; + for (int i = 0, s = optionalIndicesList.size(); i < s; ++i) { + if (isa(optionalIndicesList[i].getType())) + continue; + indicesList.push_back(optionalIndicesList[i]); + indicesMap.push_back(i); + + auto indexTy = cast(indicesList.back().getType()); + numBatchDims = std::max(static_cast(indexTy.getSizes().size()), + numBatchDims); } - numNonNoneIndices = nonNoneIndexTensorDim.size(); - if (numNonNoneIndices > 2) { - return rewriter.notifyMatchFailure( - op, "unimplemented: non none index tensors less than or equal to 2 " - "supported only"); - } else if (numNonNoneIndices == 2 && - nonNoneIndexTensorDim[0] != nonNoneIndexTensorDim[1] - 1) { - return rewriter.notifyMatchFailure( - op, "unimplemented: case of 2 non none index tensors is supported " - "only when both the tensors are along consecutive dimensions"); - } + // Value broadcasting semantics require batch dimensions to be up front if + // the indices are not sequential, otherwise they are sequentially at their + // location: + int64_t batchDim = 0; + for (int s = optionalIndicesList.size(); batchDim < s; ++batchDim) + if (!isa(optionalIndicesList[batchDim].getType())) + break; - // Padding the indices list with none values. 
- if (sizeOptionalIndicesList < inputRank) { - for (unsigned i = 0; i < (inputRank - sizeOptionalIndicesList); i++) - optionalIndicesList.push_back(torchCstNone); - } + int64_t nextNone = batchDim; + for (int s = optionalIndicesList.size(); nextNone < s; ++nextNone) + if (isa(optionalIndicesList[nextNone].getType())) + break; - SmallVector indexBroadcastShapeInt{ - optionalIndicesList[nonNoneIndexTensorDim[0]] - .getType() - .cast() - .getSizes()}; - SmallVector indexBroadcastShapeValue; - if (numNonNoneIndices == 2) { - computeBroadcastShape(rewriter, loc, - optionalIndicesList[nonNoneIndexTensorDim[0]], - optionalIndicesList[nonNoneIndexTensorDim[1]], - indexBroadcastShapeInt, indexBroadcastShapeValue); - } else { - // It means there's only one index tensor and broadcast shape is same as - // that index tensor' shape. - for (unsigned i = 0; i < indexBroadcastShapeInt.size(); i++) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - indexBroadcastShapeValue.push_back(rewriter.createOrFold( - loc, optionalIndicesList[nonNoneIndexTensorDim[0]], dim)); - } - } + for (int s = optionalIndicesList.size(); nextNone < s; ++nextNone) + if (!isa(optionalIndicesList[nextNone].getType())) + batchDim = 0; - Type indicesDtype = optionalIndicesList[nonNoneIndexTensorDim[0]] - .getType() - .cast() - .getDtype(); + // Indices are extended, catted, and collapsed into a [batch, depth] tensor: + Value indices = combinePutIndices(loc, indicesList, rewriter); - // This implementation is done to get the scatter indices: + // Bove batch dimensions to the front and collapse into a single dim: + values = + collapseAndMoveBatchDims(loc, values, batchDim, numBatchDims, rewriter); + valuesType = cast(values.getType()); - // def get_broadcast_shape(tensors): - // return list(torch.broadcast_tensors(*tensors)[0].shape) - - // def get_input_shape_to_output_shape_map(optional_index_tensors: - // list[Optional[torch.Tensor]]): - // index_tensors = list(filter(lambda x: x is not None, - // optional_index_tensors)) broadcast_rank = - // len(get_broadcast_shape(index_tensors)) num_of_index_tensors = - // len(index_tensors) index_of_first_index_tensor: Optional[int] = None - // result = {} - // for i, index in enumerate(optional_index_tensors): - // if index is None: - // val = i - // if index_of_first_index_tensor is not None: - // val += broadcast_rank - num_of_index_tensors - // result[i] = [val] - // else: - // if index_of_first_index_tensor is None: - // index_of_first_index_tensor = i - // output_indices = list(range(index_of_first_index_tensor, - // index_of_first_index_tensor + - // broadcast_rank)) - // result[i] = output_indices - // return result - - // def spotlight_indices(shape, indices: list[int]): - // """Turn all values in `shape` to `1` except for those with index in - // `indices`.""" shape = shape.copy() for i in range(len(shape)): - // if i not in indices: - // shape[i] = 1 - // return shape - - // def get_final_shape(input, optional_index_tensors: - // list[Optional[torch.Tensor]]): - // index_tensors = list(filter(lambda x: x is not None, - // optional_index_tensors)) index_tensors_broadcast_shape = - // get_broadcast_shape(index_tensors) result = [] - // handled_index_tensor_space = False - // for e, i in enumerate(input.shape): - // if optional_index_tensors[e] is None: - // result.append(i) - // else: - // if not handled_index_tensor_space: - // handled_index_tensor_space = True - // result += index_tensors_broadcast_shape - // return result - - // def get_scatter_indices(input, 
optional_index_tensors: - // list[Optional[torch.Tensor]]): - // assert len(input.size()) == len(optional_index_tensors), "Pad indices - // with None" shape_map = - // get_input_shape_to_output_shape_map(optional_index_tensors) - // index_tensors = list(filter(lambda x: x is not None, - // optional_index_tensors)) index_tensors_broadcast_shape = - // get_broadcast_shape(index_tensors) final_shape = - // get_final_shape(input, optional_index_tensors) - - // broadcasted_index_tensors = [] - // for e, optional_index_tensor in enumerate(optional_index_tensors): - // if optional_index_tensor is None: - // tensor_to_broadcast = torch.arange(0, input.size(e)) - // else: - // tensor_to_broadcast = - // optional_index_tensor.broadcast_to(index_tensors_broadcast_shape) - - // broadcasted_index_tensor = \ - // tensor_to_broadcast.reshape(spotlight_indices(final_shape, shape_map[e]))\ - // .broadcast_to(final_shape)\ - // .flatten() - // broadcasted_index_tensors.append(broadcasted_index_tensor) - - // return torch.stack(broadcasted_index_tensors, dim=0).t() - - auto scatterIndicesInfo = - getScatterIndices(op, rewriter, indicesDtype, optionalIndicesList, - indexBroadcastShapeInt, indexBroadcastShapeValue); - if (failed(scatterIndicesInfo)) { - return rewriter.notifyMatchFailure( - op, "cannot generate scatter indices for index put op"); - } - Value indexTensor = *scatterIndicesInfo; - - // Flattening the values tensor. - Value torchCstZero = rewriter.create( + // Materialize out the length-1 dimensions: + Value zero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); - Value flattenedValuesTensorLastDim = rewriter.create( - loc, - rewriter.getI64IntegerAttr(valuesTensorType.getSizes().size() - 1)); - SmallVector valuesShapeInt{valuesTensorType.getSizes()}; - int64_t valuesCount = 1; - if (llvm::all_of(valuesShapeInt, - [](int64_t shape) { return shape != kUnknownSize; })) { - for (int64_t i : valuesShapeInt) - valuesCount *= i; - } else { - valuesCount = kUnknownSize; + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + llvm::SmallVector valuesShape{valuesType.getSizes().front()}; + llvm::SmallVector valuesDims; + valuesDims.push_back( + rewriter.create(loc, values, zero)); + + int vDim = 1; + for (int i = 0, s = inputType.getSizes().size(); i < s; ++i) { + if (i < optionalIndicesCount && + !isa(optionalIndicesList[i].getType())) { + valuesDims.push_back(one); + valuesShape.push_back(1); + continue; + } + + Value k = rewriter.create( + loc, rewriter.getI64IntegerAttr(vDim)); + valuesDims.push_back( + rewriter.create(loc, values, k)); + valuesShape.push_back(inputType.getSizes()[i]); + vDim++; } - auto flattenedValuesTensorType = ValueTensorType::get( - context, llvm::ArrayRef(valuesCount), valuesTensorType.getDtype()); - Value flattenedValuesTensor = rewriter.create( - loc, flattenedValuesTensorType, op.getValues(), torchCstZero, - flattenedValuesTensorLastDim); - values = typeConverter->materializeTargetConversion( - rewriter, loc, - typeConverter->convertType(flattenedValuesTensor.getType()), - flattenedValuesTensor); + + Value valuesDimsList = rewriter.create( + loc, Torch::ListType::get(rewriter.getType()), + valuesDims); + + valuesType = rewriter.getType( + valuesShape, valuesType.getOptionalDtype()); + values = + rewriter.create(loc, valuesType, values, valuesDimsList); // `TMTensor::ScatterOp` expects indices of element type i32. 
- Value indices = convertTensorToDtype( - rewriter, loc, indexTensor, + indices = convertTensorToDtype( + rewriter, loc, indices, mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed)); + + input = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(input.getType()), input); + values = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(values.getType()), values); indices = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(indices.getType()), indices); @@ -931,7 +828,8 @@ public: // 3.) `input` is mapped to `original` in scatter op. bool invalidInputTypeFound = false; Value scatterOp = createTMTensorScatterOp( - rewriter, loc, values, indices, input, /*uniqueIndices=*/false, + rewriter, loc, values, indices, input, indicesMap, + /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value valuesElement, Value inputElement) { Value yieldValue = valuesElement; @@ -1150,6 +1048,7 @@ public: Value scatterOp = createTMTensorScatterOp( rewriter, loc, /*updates=*/gradOutputFlattened, /*indices=*/indicesCollapsed, /*original=*/outputTensor, + /*dimensionsMap=*/createDefaultDimMap(indicesCollapsed), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value valuesElement, Value inputElement) { @@ -1292,6 +1191,7 @@ public: srcType.getElementType(), /*init_element=*/normalizationValue); self = createTMTensorScatterOp( rewriter, loc, normalizations, indices, self, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { b.create(loc, update); @@ -1299,6 +1199,7 @@ public: if (reduceEnum == torch_upstream::ReductionType::MEAN) { counts = createTMTensorScatterOp( rewriter, loc, normalizations, indices, counts, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { b.create(loc, update); @@ -1309,7 +1210,7 @@ public: // Create final operation Value scatterOp = createTMTensorScatterOp( rewriter, loc, updates, indices, self, - /*uniqueIndices=*/false, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { Value result; if (reduceEnum == torch_upstream::ReductionType::SUM || @@ -1353,6 +1254,7 @@ public: if (reduceEnum == torch_upstream::ReductionType::MEAN) { counts = createTMTensorScatterOp( rewriter, loc, updates, indices, counts, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { Value result; diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index 0b827893c..7b8a17682 100644 --- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -509,12 +509,32 @@ LogicalResult ScanOp::fold(FoldAdaptor adaptor, //===----------------------------------------------------------------------===// // ScatterOp //===----------------------------------------------------------------------===// +static Type getComplexElementTypeOrSelf(Type ty) { + if (auto complex = dyn_cast_or_null(ty)) + return complex.getElementType(); + return ty; +} + +static bool isInvalid(ArrayRef dimsPos, int64_t rank) { + // early exit. 
+ if (static_cast(dimsPos.size()) > rank) + return true; + DenseSet uniqued; + for (int64_t dim : dimsPos) + uniqued.insert(dim); + if (static_cast(dimsPos.size()) != uniqued.size()) + return true; + return llvm::any_of( + dimsPos, [rank](int64_t dimPos) { return dimPos < 0 || dimPos >= rank; }); +} + LogicalResult ScatterOp::verify() { + Operation *op = getOperation(); if (getInputs().size() != 2) { - return emitOpError("expected two input operands"); + return op->emitOpError("expected two input operands"); } if (getOutputs().size() != 1) { - return emitOpError("expected one output operand"); + return op->emitOpError("expected one output operand"); } auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) { return t1.getShape()[dim] == t2.getShape()[dim]; @@ -526,10 +546,19 @@ LogicalResult ScatterOp::verify() { return emitOpError("expected indices to be of rank 2 of i32 element type"); } auto indexDepth = getIndexDepth(); - if (indexDepth == ShapedType::kDynamic) { + if (ShapedType::isDynamic(indexDepth)) { return emitOpError("expected index depth is static"); } + ArrayRef dimMap = getDimensionMap(); + if (static_cast(dimMap.size()) != indexDepth) { + return op->emitOpError("invalid number of dimension map entries "); + } + + auto originalType = getOriginalType(); + if (isInvalid(dimMap, originalType.getRank())) + return op->emitOpError("dimension map is invalid"); + // The first dimension of the indices should match the first dimension of the // output. They indicate to the number of updates. auto updateType = getUpdateType(); @@ -540,7 +569,6 @@ LogicalResult ScatterOp::verify() { return emitOpError( "mismatch in shape of indices and update value at dim#0"); } - auto originalType = getOriginalType(); if (updateType.getRank() - 1 > originalType.getRank()) { return emitOpError( "update value rank exceeds the rank of the original value"); @@ -553,7 +581,7 @@ LogicalResult ScatterOp::verify() { "index depth and update value does not cover rank of original value"); } - // Validate the non-indexed update dims covier the full slice size of the + // Validate the non-indexed update dims cover the full slice size of the // original tensor. 
int64_t fullSliceDims = originalType.getRank() - indexDepth; for (auto it : @@ -562,10 +590,11 @@ LogicalResult ScatterOp::verify() { updateType.getRank()))) { int64_t originalDim = std::get<0>(it); int64_t updateDim = std::get<1>(it); - if (updateType.getDimSize(updateDim) != - originalType.getDimSize(originalDim)) { - return emitOpError("mismatch in shape of update value dim#") - << updateDim << " and original value at dim#" << originalDim; + if (!originalType.isDynamicDim(originalDim) && + updateType.getDimSize(updateDim) > + originalType.getDimSize(originalDim)) { + return op->emitOpError("shape of update value dim#") + << updateDim << " exceeds original value at dim#" << originalDim; } } @@ -576,23 +605,25 @@ LogicalResult ScatterOp::verify() { llvm::seq(1, updateType.getRank() - fullSliceDims))) { int64_t originalDim = std::get<0>(it); int64_t updateDim = std::get<1>(it); - if (updateType.getDimSize(updateDim) > - originalType.getDimSize(originalDim)) { - return emitOpError("indexed shape of update value dim#") + if (!originalType.isDynamicDim(originalDim) && + updateType.getDimSize(updateDim) > + originalType.getDimSize(originalDim)) { + return op->emitOpError("indexed shape of update value dim#") << updateDim << " exceeds original value at dim#" << originalDim << " " << updateType.getDimSize(updateDim) << " " << originalType.getDimSize(originalDim); } } - Region &thisRegion = getRegion(); - Block *body = &thisRegion.front(); + Region ®ion = this->getRegion(); + Block *body = ®ion.front(); if (body->getNumArguments() != 2) { - return emitOpError("expected region to have two arguments"); + return op->emitOpError("expected region to have two arguments"); } Type arg0Type = body->getArgument(0).getType(); Type arg1Type = body->getArgument(1).getType(); - if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) { + if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() || + !getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) { return emitOpError( "expected region to have scalar argument of integer or float types"); } @@ -684,14 +715,16 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, starts[it.index() + offset] = it.value(); } + ArrayRef dimMap = getDimensionMap(); for (auto i : llvm::seq(0, indexDepth)) { loadIndices.back() = b.create(loc, i); Value idx = b.create(loc, indices(), loadIndices); - Value cast = b.create(loc, b.getIndexType(), idx); + Value ret = b.create(loc, b.getIndexType(), idx); - if (starts[i]) - cast = b.create(loc, cast, starts[i]); - starts[i] = cast; + auto dim = dimMap[i]; + if (starts[dim]) + ret = b.create(loc, ret, starts[dim]); + starts[dim] = ret; } Value init = b.create(loc, original(), starts); diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 9ff447371..7ac95ab6c 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -75,6 +75,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( // (e.g. dimensions which must be constant in a ranked programming model) // and those constants get somewhat obscured by TorchToArith. 
pm.addNestedPass(createConvertTorchToTMTensorPass()); + pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createConvertTorchToLinalgPass()); pm.addNestedPass(createConvertTorchToSCFPass()); pm.addNestedPass(createConvertTorchToArithPass()); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py index 8a84961a5..89b8b10eb 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -1166,3 +1166,4 @@ class IndexPutImplIndexWithNoneModule(torch.nn.Module): module_factory=lambda: IndexPutImplIndexWithNoneModule()) def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 5), tu.randint(6, 1, high=4), tu.randint(7, high=5), tu.rand(2, 3, 6, 7)) + diff --git a/test/Dialect/TMTensor/bufferize.mlir b/test/Dialect/TMTensor/bufferize.mlir index 3e60814fa..f36a2f521 100644 --- a/test/Dialect/TMTensor/bufferize.mlir +++ b/test/Dialect/TMTensor/bufferize.mlir @@ -64,7 +64,7 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: // CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> // CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> -// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] +// CHECK: tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { // CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32): // CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32 @@ -74,7 +74,7 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: func.func @scatter_update_scalar_1D( %original: tensor<8xi32>, %indices: tensor<3x1xi32>, %updates: tensor<3xi32>) -> tensor<8xi32> { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>) outs(%original : tensor<8xi32>) { ^bb0(%update: i32, %orig: i32): // no predecessors @@ -92,7 +92,7 @@ func.func @scatter_update_scalar_1D( // CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> // CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> -// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] +// CHECK: tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { // CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32): // CHECK: %[[CST1:.*]] = arith.constant 1 : i32 @@ -104,7 +104,7 @@ func.func @scatter_update_scalar_1D( func.func @scatter_add_scalar_1D( %original: tensor<8xi32>, %indices: tensor<3x1xi32>, %updates: tensor<3xi32>) -> tensor<8xi32> { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>) outs(%original : tensor<8xi32>) { ^bb0(%update: i32, %orig: i32): // no predecessors diff --git 
a/test/Dialect/TMTensor/convert_to_loops.mlir b/test/Dialect/TMTensor/convert_to_loops.mlir index e9c160f99..7901cf505 100644 --- a/test/Dialect/TMTensor/convert_to_loops.mlir +++ b/test/Dialect/TMTensor/convert_to_loops.mlir @@ -105,7 +105,7 @@ func.func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) { func.func @scatter_update_scalar_1D( %original: memref<8xi32>, %indices: memref<3x1xi32>, %updates: memref<3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>) outs(%original : memref<8xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -131,7 +131,7 @@ func.func @scatter_update_scalar_1D( func.func @scatter_add_scalar_2D( %original: memref<4x3xi32>, %indices: memref<3x2xi32>, %updates: memref<3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>) outs(%original : memref<4x3xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -162,7 +162,7 @@ func.func @scatter_add_scalar_2D( func.func @scatter_update_slice_2D( %original: memref<4x3xi32>, %indices: memref<2x1xi32>, %updates: memref<2x3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>) outs(%original : memref<4x3xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -192,7 +192,7 @@ func.func @scatter_update_slice_2D( func.func @scatter_add_scalar_1D( %original: memref<8xi32>, %indices: memref<3x1xi32>, %updates: memref<3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>) outs(%original : memref<8xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -221,7 +221,7 @@ func.func @scatter_add_scalar_1D( func.func @scatter_add_slice_2D( %original: memref<4x3xi32>, %indices: memref<2x1xi32>, %updates: memref<2x3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>) outs(%original : memref<4x3xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -251,7 +251,7 @@ func.func @scatter_add_slice_2D( func.func @scatter_update_scalar_dynamic_1D( %original: memref, %indices: memref, %updates: memref) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref, memref) outs(%original : memref) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -277,7 +277,7 @@ func.func @scatter_update_scalar_dynamic_1D( func.func @scatter_add_scalar_dynamic_2D( %original: memref, %indices: memref, %updates: memref) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref, memref) outs(%original : memref) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -308,7 +308,7 @@ func.func @scatter_add_scalar_dynamic_2D( func.func @scatter_update_slice_dynamic_2D( %original: memref, %indices: memref, %updates: memref) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref, memref) outs(%original : memref) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -335,6 +335,7 @@ func.func 
@scatter_update_slice_dynamic_2D( func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) { tm_tensor.scatter + {dimension_map= array} unique_indices(true) ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>) outs(%arg0 : memref<2x64x12xf32>) { diff --git a/test/Dialect/TMTensor/invalid.mlir b/test/Dialect/TMTensor/invalid.mlir index bfcd1adb8..6653d944a 100644 --- a/test/Dialect/TMTensor/invalid.mlir +++ b/test/Dialect/TMTensor/invalid.mlir @@ -4,7 +4,7 @@ func.func @scatter_mixed_tensor_memref( %update : memref, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : memref, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -20,7 +20,7 @@ func.func @scatter_mixed_tensor_memref( %update : tensor, %indices : memref, %original : tensor) -> tensor { // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, memref) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -36,7 +36,7 @@ func.func @scatter_extra_outputs( %update : tensor, %indices : tensor, %original : tensor) -> (tensor, tensor) { // expected-error @+1 {{expected number of outputs to be same as the number of results}} - %0, %1 = tm_tensor.scatter unique_indices(true) + %0, %1 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -52,7 +52,7 @@ func.func @scatter_mixed_tensor_memref( %update : tensor, %indices : tensor, %original : memref) -> tensor { // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : memref) { ^bb0(%arg1: f32, %arg2: f32): @@ -68,7 +68,7 @@ func.func @scatter_output_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor<4x?xf32> { // expected-error @+1 {{expected type of `outs` operand #0 'tensor' to be same as result type 'tensor<4x?xf32>'}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -84,7 +84,7 @@ func.func @scatter_mixed_tensor_memref( %update : memref, %indices : tensor, %original : memref) { // expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}} - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : memref, tensor) outs(%original : memref) { ^bb0(%arg1: f32, %arg2: f32): @@ -100,7 +100,7 @@ func.func @scatter_mixed_tensor_memref( %update : memref, %indices : memref, %original : tensor) { // expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}} - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : memref, memref) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -116,7 +116,7 @@ func.func @scatter_dim_mismatch( %update : 
tensor, %indices : tensor<48x1xi32>, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor<48x1xi32>) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -132,7 +132,7 @@ func.func @scatter_dim_mismatch( %update : tensor<64x?xf32>, %indices : tensor<48x1xi32>, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor<64x?xf32>, tensor<48x1xi32>) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -148,7 +148,7 @@ func.func @scatter_dim_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{op update value rank exceeds the rank of the original value}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -162,16 +162,16 @@ func.func @scatter_dim_mismatch( func.func @scatter_dim_mismatch( %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{mismatch in shape of update value dim#1 and original value at dim#1}} - %0 = tm_tensor.scatter unique_indices(true) + %original : tensor) -> tensor { + // expected-error @+1 {{shape of update value dim#1 exceeds original value at dim#1}} + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { + outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): %1 = arith.addf %arg1, %arg2 : f32 tm_tensor.yield %1 : f32 - } -> tensor - return %0 : tensor + } -> tensor + return %0 : tensor } // ----- @@ -180,7 +180,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected region to have scalar argument of integer or float types}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: index, %arg2: index): @@ -197,7 +197,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i32): @@ -214,7 +214,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i32, %arg2: i64): @@ -231,7 +231,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}} - %0 = 
tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i32, %arg2: i64): @@ -248,7 +248,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected region to have two arguments}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64, %arg3 : i64): @@ -264,7 +264,7 @@ func.func @scatter_region_type_mismatch( func.func @scatter_yield_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -281,7 +281,7 @@ func.func @scatter_yield_mismatch( func.func @scatter_yield_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -299,7 +299,7 @@ func.func @scatter_index_depth_dynamic( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected index depth is static}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -316,7 +316,7 @@ func.func @scatter_original_rank_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{op index depth and update value does not cover rank of original value}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64):