mirror of https://github.com/llvm/torch-mlir
[torch] Rework lowering to tm_tensor.scatter to stop serialization (#2940)
We collapsed and broadcast the scatter indices down to a single-element version. We should instead use `tm_tensor.scatter`'s support for multiple indices and its implicitly broadcast behavior. This avoids serializing the updates and avoids materializing a needlessly large indices tensor.
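As a rough illustration of the reworked form (adapted from the updated lit tests in this change; the function name and the add-combiner body here are illustrative sketches, while the op syntax and the new `dimension_map` attribute match the tests below), each column of the index tensor names the dimension of the original tensor it addresses, so multiple indices and implicit broadcast are handled directly rather than through a serialized single-index lowering:

func.func @scatter_add_example(
    %original: memref<4x3xi32>, %indices: memref<3x2xi32>,
    %updates: memref<3xi32>) {
  // Each of the 3 updates is scattered using a 2-entry index row: column 0 of
  // %indices addresses dim 0 of %original and column 1 addresses dim 1.
  tm_tensor.scatter {dimension_map = array<i64: 0, 1>} unique_indices(true)
    ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
    outs(%original : memref<4x3xi32>) {
  ^bb0(%update: i32, %orig: i32):
    // Accumulate the update into the original value.
    %sum = arith.addi %update, %orig : i32
    tm_tensor.yield %sum : i32
  }
  return
}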
parent d628b5fd06
commit e30a083aff
@@ -137,6 +137,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
   let arguments = (ins
       Variadic<AnyRankedTensorOrMemRefType>:$inputs,
       Variadic<AnyRankedTensorOrMemRefType>:$outputs,
+      DenseI64ArrayAttr:$dimension_map,
       DefaultValuedAttr<BoolAttr, "true">:$unique_indices
   );
   let results = (outs Variadic<AnyRankedTensor>:$results);
@@ -200,15 +200,30 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
                       scatterInputsVector[indexType.getRank()]);
 }
 
+static llvm::SmallVector<int64_t> createDefaultDimMap(Value indices) {
+  llvm::SmallVector<int64_t> dmap;
+  if (auto iTy = dyn_cast<BaseTensorType>(indices.getType()))
+    dmap.resize(iTy.getSizes()[1]);
+
+  if (auto iTy = dyn_cast<RankedTensorType>(indices.getType()))
+    dmap.resize(iTy.getDimSize(1));
+
+  for (int i = 0, s = dmap.size(); i < s; ++i)
+    dmap[i] = i;
+
+  return dmap;
+}
+
 static Value createTMTensorScatterOp(
     OpBuilder &b, Location loc, Value updates, Value indices, Value original,
-    bool uniqueIndices,
+    llvm::ArrayRef<int64_t> dimensionsMap, bool uniqueIndices,
     function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
+  auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap);
   auto originalTensorType = original.getType().cast<RankedTensorType>();
   Type originalElementType = originalTensorType.getElementType();
   auto scatterOp = b.create<TMTensor::ScatterOp>(
       loc, originalTensorType, ValueRange{updates, indices},
-      ValueRange{original}, uniqueIndices);
+      ValueRange{original}, dimensionsMapAttr, uniqueIndices);
 
   Region &scatterOpRegion = scatterOp.getRegion();
   auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
@@ -334,7 +349,7 @@ public:
                                       src, dim);
     Value scatterOp = createTMTensorScatterOp(
         rewriter, loc, updates, indices, self,
-        /*uniqueIndices=*/false,
+        /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false,
         [&](OpBuilder &b, Location loc, Value updatesElement,
             Value inputElement) {
           b.create<TMTensor::YieldOp>(loc, updatesElement);
@@ -455,7 +470,7 @@ public:
 
     Value scatterOp = createTMTensorScatterOp(
         rewriter, loc, updatesTensor, indices, bincountTensor,
-        /*uniqueIndices=*/false,
+        /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false,
         [&](OpBuilder &b, Location loc, Value _, Value bincountElem) {
           Value add = b.create<arith::AddIOp>(loc, bincountElem, constantOne);
           b.create<TMTensor::YieldOp>(loc, add);
@ -466,235 +481,200 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// """Create a map from each dimension of the input tensor to the
|
||||
// subspace that dimension corresponds to in the result shape one gets
|
||||
// from indexing the tensor with the optional index tensors.
|
||||
//
|
||||
// Note: Index tensors are first broadcasted to a common shape before
|
||||
// creating the mapping. So the index of every index tensor will map to
|
||||
// the same dimensions in the result shape.
|
||||
//
|
||||
// For example:
|
||||
// indices = [None, None, torch.randint(4, (6, 1)), torch.randint(5, (7,))]
|
||||
// indexBroadcastShapeValue = [6, 7]
|
||||
// map = {0: [0], 1: [1], 2: [2, 3], 3: [2, 3]}
|
||||
static SmallVector<SmallVector<int64_t>>
|
||||
getInputShapeToOutputShapeMap(SmallVector<Value> optionalIndices,
|
||||
SmallVector<Value> indexBroadcastShapeValue) {
|
||||
SmallVector<Value> indices;
|
||||
for (Value index : optionalIndices) {
|
||||
if (!index.getType().isa<Torch::NoneType>())
|
||||
indices.push_back(index);
|
||||
}
|
||||
namespace {
|
||||
|
||||
unsigned broadcastRank = indexBroadcastShapeValue.size();
|
||||
unsigned numIndexTensors = indices.size();
|
||||
int64_t indexOfFirstIndexTensor = -1;
|
||||
SmallVector<SmallVector<int64_t>> result;
|
||||
|
||||
for (unsigned i = 0; i < optionalIndices.size(); i++) {
|
||||
if (optionalIndices[i].getType().isa<Torch::NoneType>()) {
|
||||
unsigned val = i;
|
||||
if (indexOfFirstIndexTensor >= 0)
|
||||
val += broadcastRank - numIndexTensors;
|
||||
result.push_back({val});
|
||||
} else {
|
||||
if (indexOfFirstIndexTensor < 0)
|
||||
indexOfFirstIndexTensor = i;
|
||||
SmallVector<int64_t> outputIndices;
|
||||
for (unsigned j = indexOfFirstIndexTensor;
|
||||
j < (indexOfFirstIndexTensor + broadcastRank); j++)
|
||||
outputIndices.push_back(j);
|
||||
result.push_back(outputIndices);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::tuple<SmallVector<Value>, SmallVector<int64_t>>
|
||||
getIndicesFinalShape(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value input, SmallVector<Value> optionalIndices,
|
||||
SmallVector<int64_t> inputShapeInt,
|
||||
SmallVector<Value> inputShapeValue,
|
||||
SmallVector<int64_t> indexBroadcastShapeInt,
|
||||
SmallVector<Value> indexBroadcastShapeValue) {
|
||||
SmallVector<Value> result;
|
||||
SmallVector<int64_t> resultInt;
|
||||
bool handledIndexTensorSpace = false;
|
||||
|
||||
for (unsigned i = 0; i < inputShapeValue.size(); i++) {
|
||||
if (optionalIndices[i].getType().isa<Torch::NoneType>()) {
|
||||
result.push_back(inputShapeValue[i]);
|
||||
resultInt.push_back(inputShapeInt[i]);
|
||||
} else {
|
||||
if (!handledIndexTensorSpace) {
|
||||
handledIndexTensorSpace = true;
|
||||
for (unsigned j = 0; j < indexBroadcastShapeValue.size(); j++) {
|
||||
result.push_back(indexBroadcastShapeValue[j]);
|
||||
resultInt.push_back(indexBroadcastShapeInt[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_tuple(result, resultInt);
|
||||
}
|
||||
|
||||
static FailureOr<Value>
|
||||
getScatterIndices(Aten_IndexPutImplOp op, ConversionPatternRewriter &rewriter,
|
||||
Type indicesDtype, SmallVector<Value> optionalIndices,
|
||||
SmallVector<int64_t> indexBroadcastShapeInt,
|
||||
SmallVector<Value> indexBroadcastShapeValue) {
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
Value input = op.getSelf();
|
||||
|
||||
SmallVector<SmallVector<int64_t>> shapeMap =
|
||||
getInputShapeToOutputShapeMap(optionalIndices, indexBroadcastShapeValue);
|
||||
|
||||
SmallVector<int64_t> inputShapeInt{
|
||||
input.getType().cast<BaseTensorType>().getSizes()};
|
||||
int64_t inputRank = inputShapeInt.size();
|
||||
SmallVector<Value> inputShapeValue;
|
||||
for (unsigned i = 0; i < inputShapeInt.size(); i++) {
|
||||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i));
|
||||
inputShapeValue.push_back(
|
||||
rewriter.createOrFold<AtenSizeIntOp>(loc, input, dim));
|
||||
}
|
||||
|
||||
auto finalShapeResult = getIndicesFinalShape(
|
||||
rewriter, loc, input, optionalIndices, inputShapeInt, inputShapeValue,
|
||||
indexBroadcastShapeInt, indexBroadcastShapeValue);
|
||||
SmallVector<Value> finalShapeValue = std::get<0>(finalShapeResult);
|
||||
SmallVector<int64_t> finalShapeInt = std::get<1>(finalShapeResult);
|
||||
|
||||
Value torchCstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
|
||||
OpBuilder b) {
|
||||
llvm::SmallVector<Value> indices(indicesRef);
|
||||
// Declare commonly used constants up front:
|
||||
Value torchCstZero =
|
||||
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(0));
|
||||
Value torchCstOne =
|
||||
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(1));
|
||||
Value torchCstNegOne =
|
||||
b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(-1));
|
||||
|
||||
Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
||||
indexBroadcastShapeValue);
|
||||
|
||||
// Calculating index count.
|
||||
int64_t indexCount = 1;
|
||||
if (llvm::all_of(finalShapeInt,
|
||||
[](int64_t shape) { return shape != kUnknownSize; })) {
|
||||
for (int64_t i : finalShapeInt)
|
||||
indexCount *= i;
|
||||
} else {
|
||||
indexCount = kUnknownSize;
|
||||
// Determine the broadcast sizes and materialize missing implicit end
|
||||
// dimensions:
|
||||
int64_t indicesRank = 0;
|
||||
for (auto index : indices) {
|
||||
auto indexTy = cast<Torch::ValueTensorType>(index.getType());
|
||||
int64_t rank = indexTy.getSizes().size();
|
||||
indicesRank = std::max(rank, indicesRank);
|
||||
}
|
||||
|
||||
Value indexCountValue = finalShapeValue[0];
|
||||
for (unsigned i = 1; i < finalShapeValue.size(); i++)
|
||||
indexCountValue =
|
||||
rewriter.create<AtenMulIntOp>(loc, indexCountValue, finalShapeValue[i]);
|
||||
auto maxDim = [](int64_t dim0, int64_t dim1) {
|
||||
if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
|
||||
return Torch::kUnknownSize;
|
||||
return std::max(dim0, dim1);
|
||||
};
|
||||
|
||||
ValueTensorType flattenIndicesType =
|
||||
ValueTensorType::get(context, llvm::ArrayRef(indexCount), indicesDtype);
|
||||
Value flattenEndDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(finalShapeInt.size() - 1));
|
||||
llvm::SmallVector<Value> broadcastSizes(indicesRank, torchCstOne);
|
||||
llvm::SmallVector<int64_t> broadcastShape(indicesRank, 0);
|
||||
for (auto index : indices) {
|
||||
auto indexTy = cast<Torch::ValueTensorType>(index.getType());
|
||||
auto shape = indexTy.getSizes();
|
||||
int32_t rank = shape.size();
|
||||
|
||||
SmallVector<Value> broadcastedIndices;
|
||||
for (unsigned i = 0; i < optionalIndices.size(); i++) {
|
||||
Value broadcastedIndexTensor;
|
||||
if (optionalIndices[i].getType().isa<Torch::NoneType>()) {
|
||||
Value torchCstDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i));
|
||||
Value inputDim = rewriter.create<AtenSizeIntOp>(loc, input, torchCstDim);
|
||||
ValueTensorType tensorType = ValueTensorType::get(
|
||||
context, llvm::ArrayRef(inputShapeInt[i]), indicesDtype);
|
||||
broadcastedIndexTensor = rewriter.create<AtenArangeStartStepOp>(
|
||||
loc, tensorType, /*start=*/torchCstZero, /*end=*/inputDim,
|
||||
/*step=*/torchCstOne,
|
||||
/*dtype=*/torchCstNone,
|
||||
/*layout=*/torchCstNone,
|
||||
/*device=*/torchCstNone,
|
||||
/*pin_memory=*/torchCstNone);
|
||||
} else {
|
||||
ValueTensorType tensorType = ValueTensorType::get(
|
||||
context, llvm::ArrayRef(indexBroadcastShapeInt), indicesDtype);
|
||||
broadcastedIndexTensor = rewriter.create<AtenBroadcastToOp>(
|
||||
loc, tensorType, optionalIndices[i], indexBroadcastShapeTorchList);
|
||||
for (int32_t j = 0; j < rank; ++j) {
|
||||
Value dim = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(j));
|
||||
auto sizeOp = b.create<Torch::AtenSizeIntOp>(loc, index, dim);
|
||||
auto size = shape[j];
|
||||
|
||||
int32_t idx = broadcastShape.size() - rank + j;
|
||||
broadcastSizes[idx] =
|
||||
b.create<Torch::PrimMaxIntOp>(loc, sizeOp, broadcastSizes[idx]);
|
||||
broadcastShape[idx] = maxDim(size, broadcastShape[idx]);
|
||||
}
|
||||
|
||||
// spotlight_indices(final_shape, shape_map[i]):
|
||||
// Turn all values in `final_shape` to `1` except for those with index in
|
||||
// `indices`.
|
||||
// for j in range(len(final_shape)):
|
||||
// if j not in indices:
|
||||
// final_shape[j] = 1
|
||||
// This is equivalent to unsqueezing the index tensor at the dimension `j`
|
||||
// not in indices.
|
||||
for (unsigned j = 0; j < finalShapeInt.size(); j++) {
|
||||
if (llvm::find(shapeMap[i], j) == shapeMap[i].end()) {
|
||||
Value unsqueezeDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(j));
|
||||
auto unsqueezedInfo =
|
||||
unsqueezeTensor(rewriter, op, broadcastedIndexTensor,
|
||||
/*dim=*/unsqueezeDim);
|
||||
if (failed(unsqueezedInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "cannot generate unsqueeze tensor op");
|
||||
}
|
||||
broadcastedIndexTensor = *unsqueezedInfo;
|
||||
}
|
||||
}
|
||||
|
||||
// Performing broadcast to final shape.
|
||||
Value broadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
||||
finalShapeValue);
|
||||
ValueTensorType broadcastTensorType = ValueTensorType::get(
|
||||
context, llvm::ArrayRef(finalShapeInt), indicesDtype);
|
||||
broadcastedIndexTensor = rewriter.create<AtenBroadcastToOp>(
|
||||
loc, broadcastTensorType, broadcastedIndexTensor,
|
||||
broadcastShapeTorchList);
|
||||
|
||||
// Flattening the tensor.
|
||||
broadcastedIndexTensor = rewriter.create<AtenFlattenUsingIntsOp>(
|
||||
loc, flattenIndicesType, broadcastedIndexTensor, torchCstZero,
|
||||
flattenEndDim);
|
||||
|
||||
broadcastedIndices.push_back(broadcastedIndexTensor);
|
||||
}
|
||||
|
||||
// Stacking broadcasted indices.
|
||||
Value scatterIndices;
|
||||
// The operation torch.stack([a, b], dim=0) is decomposed into:
|
||||
// torch.cat([a.unsqueeze(dim=0), b.unsqueeze(dim=0)], dim=0)
|
||||
// Unsqueeze all tensors before concatenating.
|
||||
SmallVector<Value> unsqueezedIndexTensors;
|
||||
for (Value tensor : broadcastedIndices) {
|
||||
auto unsqueezedInfo =
|
||||
unsqueezeTensor(rewriter, op, tensor, /*dim=*/torchCstZero);
|
||||
if (failed(unsqueezedInfo)) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"cannot generate unsqueeze tensor op");
|
||||
}
|
||||
unsqueezedIndexTensors.push_back(*unsqueezedInfo);
|
||||
auto mulDim = [](int64_t dim0, int64_t dim1) {
|
||||
if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
|
||||
return Torch::kUnknownSize;
|
||||
return dim0 * dim1;
|
||||
};
|
||||
|
||||
int64_t scatterBatchCount = 1;
|
||||
for (auto dim : broadcastShape) {
|
||||
scatterBatchCount = mulDim(scatterBatchCount, dim);
|
||||
}
|
||||
|
||||
// Broadcast together and flatten to batch values:
|
||||
Value broadcastSizeList = b.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(b.getType<Torch::IntType>()), broadcastSizes);
|
||||
for (Value &index : indices) {
|
||||
auto indexTy = cast<Torch::ValueTensorType>(index.getType());
|
||||
auto expandTy = b.getType<Torch::ValueTensorType>(
|
||||
broadcastShape, indexTy.getOptionalDtype());
|
||||
index = b.create<Torch::AtenBroadcastToOp>(loc, expandTy, index,
|
||||
broadcastSizeList);
|
||||
|
||||
auto flattenTy = b.getType<Torch::ValueTensorType>(
|
||||
scatterBatchCount, indexTy.getOptionalDtype());
|
||||
index = b.create<Torch::AtenFlattenUsingIntsOp>(
|
||||
loc, flattenTy, index, torchCstZero, torchCstNegOne);
|
||||
}
|
||||
|
||||
// Unsqueeze so we have a 1 dim to concat along:
|
||||
for (Value &tensor : indices) {
|
||||
auto btt = cast<Torch::BaseTensorType>(tensor.getType());
|
||||
if (!btt.hasSizes())
|
||||
return nullptr;
|
||||
|
||||
llvm::SmallVector<int64_t> shape(btt.getSizes());
|
||||
shape.push_back(1);
|
||||
|
||||
auto unsqueezeTy = b.getType<Torch::ValueTensorType>(shape, btt.getDtype());
|
||||
Value unsqueezed =
|
||||
b.create<AtenUnsqueezeOp>(loc, unsqueezeTy, tensor, torchCstOne);
|
||||
tensor = unsqueezed;
|
||||
}
|
||||
|
||||
BaseTensorType unsqueezedTensorType =
|
||||
unsqueezedIndexTensors[0].getType().cast<BaseTensorType>();
|
||||
Value concatIndicesTorchList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(unsqueezedTensorType), unsqueezedIndexTensors);
|
||||
ValueTensorType concatIndicesType = ValueTensorType::get(
|
||||
context, llvm::ArrayRef({inputRank, indexCount}), indicesDtype);
|
||||
scatterIndices = rewriter.create<AtenCatOp>(
|
||||
loc, concatIndicesType, concatIndicesTorchList, torchCstZero);
|
||||
|
||||
ValueTensorType transposedIndicesType = ValueTensorType::get(
|
||||
context, llvm::ArrayRef({indexCount, inputRank}), indicesDtype);
|
||||
scatterIndices = rewriter.create<AtenTransposeIntOp>(
|
||||
loc, transposedIndicesType, scatterIndices, torchCstZero, torchCstOne);
|
||||
return scatterIndices;
|
||||
indices[0].getType().cast<BaseTensorType>();
|
||||
Value indicesTorchList = b.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(unsqueezedTensorType), indices);
|
||||
llvm::SmallVector<int64_t, 2> concatShape{
|
||||
unsqueezedTensorType.getSizes()[0], static_cast<int64_t>(indices.size())};
|
||||
ValueTensorType concatIndicesType = b.getType<ValueTensorType>(
|
||||
llvm::ArrayRef(concatShape), unsqueezedTensorType.getDtype());
|
||||
return b.create<AtenCatOp>(loc, concatIndicesType, indicesTorchList,
|
||||
torchCstOne);
|
||||
}
|
||||
|
||||
// Helper that collapses the batch dimensions together and moves it to the front
|
||||
// of the array.
|
||||
static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
|
||||
int64_t count, OpBuilder b) {
|
||||
if (batch == 0 && count == 1)
|
||||
return values;
|
||||
|
||||
auto valuesTy = cast<Torch::ValueTensorType>(values.getType());
|
||||
auto inShape = valuesTy.getSizes();
|
||||
|
||||
llvm::SmallVector<int64_t> outShape;
|
||||
llvm::SmallVector<Value> outDims;
|
||||
|
||||
// We need a length-1 dim at the start to transpose the batch to:
|
||||
if (batch != 0) {
|
||||
outDims.push_back(b.create<Torch::ConstantIntOp>(loc, 1));
|
||||
outShape.push_back(1);
|
||||
}
|
||||
|
||||
// Dimensions before the batch stay the same:
|
||||
for (int i = 0; i <= batch; i++) {
|
||||
auto k = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(i));
|
||||
auto dim = b.create<Torch::AtenSizeIntOp>(loc, values, k);
|
||||
outDims.push_back(dim);
|
||||
outShape.push_back(inShape[i]);
|
||||
}
|
||||
|
||||
auto mulI = [](int64_t dim0, int64_t dim1) {
|
||||
if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
|
||||
return Torch::kUnknownSize;
|
||||
return dim0 * dim1;
|
||||
};
|
||||
|
||||
// Determine the collapse size of the batch dimension:
|
||||
for (int i = 1; i < count; i++) {
|
||||
outShape.back() = mulI(outShape.back(), inShape[batch + i]);
|
||||
|
||||
auto k =
|
||||
b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(batch + i));
|
||||
auto dim = b.create<Torch::AtenSizeIntOp>(loc, values, k);
|
||||
outDims.back() = b.create<Torch::AtenMulIntOp>(loc, dim, outDims.back());
|
||||
}
|
||||
|
||||
// Add the dimensions after the batch dims:
|
||||
for (int i = batch + count, s = inShape.size(); i < s; ++i) {
|
||||
auto k = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(i));
|
||||
auto dim = b.create<Torch::AtenSizeIntOp>(loc, values, k);
|
||||
outDims.push_back(dim);
|
||||
outShape.push_back(inShape[i]);
|
||||
}
|
||||
|
||||
Value outDimsList = b.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(b.getType<Torch::IntType>()), outDims);
|
||||
|
||||
valuesTy =
|
||||
b.getType<Torch::ValueTensorType>(outShape, valuesTy.getOptionalDtype());
|
||||
values = b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
|
||||
|
||||
if (batch == 0)
|
||||
return values;
|
||||
|
||||
// Batch is already at the front, no need to transpose:
|
||||
std::swap(outDims[0], outDims[batch + 1]);
|
||||
std::swap(outShape[0], outShape[batch + 1]);
|
||||
|
||||
Value dim0 = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(0));
|
||||
Value dimB =
|
||||
b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(batch + 1));
|
||||
|
||||
valuesTy =
|
||||
b.getType<Torch::ValueTensorType>(outShape, valuesTy.getOptionalDtype());
|
||||
values =
|
||||
b.create<Torch::AtenTransposeIntOp>(loc, valuesTy, values, dim0, dimB);
|
||||
|
||||
outDims.clear();
|
||||
outShape.clear();
|
||||
auto transposeShape = valuesTy.getSizes();
|
||||
int64_t transposeRank = transposeShape.size();
|
||||
for (int i = 0; i < transposeRank; ++i) {
|
||||
if (i == batch + 1)
|
||||
continue;
|
||||
Value k = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(i));
|
||||
outDims.push_back(b.create<AtenSizeIntOp>(loc, values, k));
|
||||
outShape.push_back(transposeShape[i]);
|
||||
}
|
||||
|
||||
valuesTy =
|
||||
b.getType<Torch::ValueTensorType>(outShape, valuesTy.getOptionalDtype());
|
||||
outDimsList = b.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(b.getType<Torch::IntType>()), outDims);
|
||||
return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ConvertAten_IndexPutImplOp
|
||||
: public OpConversionPattern<Aten_IndexPutImplOp> {
|
||||
public:
|
||||
|
@ -706,11 +686,11 @@ public:
|
|||
return failure();
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
Value input = adaptor.getSelf();
|
||||
Value values = adaptor.getValues();
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType valuesType = values.getType().cast<RankedTensorType>();
|
||||
int64_t inputRank = inputType.getRank();
|
||||
Value input = op.getSelf();
|
||||
Value values = op.getValues();
|
||||
auto inputType = cast<ValueTensorType>(input.getType());
|
||||
auto valuesType = cast<ValueTensorType>(values.getType());
|
||||
int64_t inputRank = inputType.getSizes().size();
|
||||
auto valuesTensorType = op.getValues().getType().cast<BaseTensorType>();
|
||||
auto resultType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
|
@ -737,190 +717,107 @@ public:
|
|||
op, "Expected accumulate to be constant bool.");
|
||||
|
||||
// The element type of the `input` and `values` should be same.
|
||||
if (inputType.getElementType() != valuesType.getElementType())
|
||||
if (inputType.getDtype() != valuesType.getDtype())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Input element type should be same as the values element type.");
|
||||
|
||||
SmallVector<Value> optionalIndicesList;
|
||||
getListConstructElements(op.getIndices(), optionalIndicesList);
|
||||
int64_t optionalIndicesCount = optionalIndicesList.size();
|
||||
// The size of the list of the index tensors should not be greater than the
|
||||
// input rank.
|
||||
if ((int64_t)optionalIndicesList.size() > inputRank)
|
||||
if (optionalIndicesCount > inputRank)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Indices list size should not be greater than the input rank.");
|
||||
|
||||
Value torchCstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
unsigned sizeOptionalIndicesList = optionalIndicesList.size();
|
||||
SmallVector<int64_t> nonNoneIndexTensorDim;
|
||||
unsigned numNonNoneIndices;
|
||||
|
||||
if (sizeOptionalIndicesList == 0)
|
||||
if (optionalIndicesCount == 0)
|
||||
return rewriter.notifyMatchFailure(op, "Indices list must not be empty.");
|
||||
|
||||
for (unsigned i = 0; i < optionalIndicesList.size(); i++) {
|
||||
if (!optionalIndicesList[i].getType().isa<Torch::NoneType>()) {
|
||||
nonNoneIndexTensorDim.push_back(i);
|
||||
}
|
||||
// Filter to available indices and get the indicesMap:
|
||||
SmallVector<Value> indicesList;
|
||||
SmallVector<int64_t> indicesMap;
|
||||
int64_t numBatchDims = 0;
|
||||
for (int i = 0, s = optionalIndicesList.size(); i < s; ++i) {
|
||||
if (isa<Torch::NoneType>(optionalIndicesList[i].getType()))
|
||||
continue;
|
||||
indicesList.push_back(optionalIndicesList[i]);
|
||||
indicesMap.push_back(i);
|
||||
|
||||
auto indexTy = cast<ValueTensorType>(indicesList.back().getType());
|
||||
numBatchDims = std::max(static_cast<int64_t>(indexTy.getSizes().size()),
|
||||
numBatchDims);
|
||||
}
|
||||
|
||||
numNonNoneIndices = nonNoneIndexTensorDim.size();
|
||||
if (numNonNoneIndices > 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: non none index tensors less than or equal to 2 "
|
||||
"supported only");
|
||||
} else if (numNonNoneIndices == 2 &&
|
||||
nonNoneIndexTensorDim[0] != nonNoneIndexTensorDim[1] - 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: case of 2 non none index tensors is supported "
|
||||
"only when both the tensors are along consecutive dimensions");
|
||||
}
|
||||
// Value broadcasting semantics require batch dimensions to be up front if
|
||||
// the indices are not sequential, otherwise they are sequentially at their
|
||||
// location:
|
||||
int64_t batchDim = 0;
|
||||
for (int s = optionalIndicesList.size(); batchDim < s; ++batchDim)
|
||||
if (!isa<Torch::NoneType>(optionalIndicesList[batchDim].getType()))
|
||||
break;
|
||||
|
||||
// Padding the indices list with none values.
|
||||
if (sizeOptionalIndicesList < inputRank) {
|
||||
for (unsigned i = 0; i < (inputRank - sizeOptionalIndicesList); i++)
|
||||
optionalIndicesList.push_back(torchCstNone);
|
||||
}
|
||||
int64_t nextNone = batchDim;
|
||||
for (int s = optionalIndicesList.size(); nextNone < s; ++nextNone)
|
||||
if (isa<Torch::NoneType>(optionalIndicesList[nextNone].getType()))
|
||||
break;
|
||||
|
||||
SmallVector<int64_t> indexBroadcastShapeInt{
|
||||
optionalIndicesList[nonNoneIndexTensorDim[0]]
|
||||
.getType()
|
||||
.cast<BaseTensorType>()
|
||||
.getSizes()};
|
||||
SmallVector<Value> indexBroadcastShapeValue;
|
||||
if (numNonNoneIndices == 2) {
|
||||
computeBroadcastShape(rewriter, loc,
|
||||
optionalIndicesList[nonNoneIndexTensorDim[0]],
|
||||
optionalIndicesList[nonNoneIndexTensorDim[1]],
|
||||
indexBroadcastShapeInt, indexBroadcastShapeValue);
|
||||
} else {
|
||||
// It means there's only one index tensor and broadcast shape is same as
|
||||
// that index tensor' shape.
|
||||
for (unsigned i = 0; i < indexBroadcastShapeInt.size(); i++) {
|
||||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i));
|
||||
indexBroadcastShapeValue.push_back(rewriter.createOrFold<AtenSizeIntOp>(
|
||||
loc, optionalIndicesList[nonNoneIndexTensorDim[0]], dim));
|
||||
}
|
||||
}
|
||||
for (int s = optionalIndicesList.size(); nextNone < s; ++nextNone)
|
||||
if (!isa<Torch::NoneType>(optionalIndicesList[nextNone].getType()))
|
||||
batchDim = 0;
|
||||
|
||||
Type indicesDtype = optionalIndicesList[nonNoneIndexTensorDim[0]]
|
||||
.getType()
|
||||
.cast<BaseTensorType>()
|
||||
.getDtype();
|
||||
// Indices are extended, catted, and collapsed into a [batch, depth] tensor:
|
||||
Value indices = combinePutIndices(loc, indicesList, rewriter);
|
||||
|
||||
// This implementation is done to get the scatter indices:
|
||||
// Move batch dimensions to the front and collapse into a single dim:
|
||||
values =
|
||||
collapseAndMoveBatchDims(loc, values, batchDim, numBatchDims, rewriter);
|
||||
valuesType = cast<Torch::ValueTensorType>(values.getType());
|
||||
|
||||
// def get_broadcast_shape(tensors):
|
||||
// return list(torch.broadcast_tensors(*tensors)[0].shape)
|
||||
|
||||
// def get_input_shape_to_output_shape_map(optional_index_tensors:
|
||||
// list[Optional[torch.Tensor]]):
|
||||
// index_tensors = list(filter(lambda x: x is not None,
|
||||
// optional_index_tensors)) broadcast_rank =
|
||||
// len(get_broadcast_shape(index_tensors)) num_of_index_tensors =
|
||||
// len(index_tensors) index_of_first_index_tensor: Optional[int] = None
|
||||
// result = {}
|
||||
// for i, index in enumerate(optional_index_tensors):
|
||||
// if index is None:
|
||||
// val = i
|
||||
// if index_of_first_index_tensor is not None:
|
||||
// val += broadcast_rank - num_of_index_tensors
|
||||
// result[i] = [val]
|
||||
// else:
|
||||
// if index_of_first_index_tensor is None:
|
||||
// index_of_first_index_tensor = i
|
||||
// output_indices = list(range(index_of_first_index_tensor,
|
||||
// index_of_first_index_tensor +
|
||||
// broadcast_rank))
|
||||
// result[i] = output_indices
|
||||
// return result
|
||||
|
||||
// def spotlight_indices(shape, indices: list[int]):
|
||||
// """Turn all values in `shape` to `1` except for those with index in
|
||||
// `indices`.""" shape = shape.copy() for i in range(len(shape)):
|
||||
// if i not in indices:
|
||||
// shape[i] = 1
|
||||
// return shape
|
||||
|
||||
// def get_final_shape(input, optional_index_tensors:
|
||||
// list[Optional[torch.Tensor]]):
|
||||
// index_tensors = list(filter(lambda x: x is not None,
|
||||
// optional_index_tensors)) index_tensors_broadcast_shape =
|
||||
// get_broadcast_shape(index_tensors) result = []
|
||||
// handled_index_tensor_space = False
|
||||
// for e, i in enumerate(input.shape):
|
||||
// if optional_index_tensors[e] is None:
|
||||
// result.append(i)
|
||||
// else:
|
||||
// if not handled_index_tensor_space:
|
||||
// handled_index_tensor_space = True
|
||||
// result += index_tensors_broadcast_shape
|
||||
// return result
|
||||
|
||||
// def get_scatter_indices(input, optional_index_tensors:
|
||||
// list[Optional[torch.Tensor]]):
|
||||
// assert len(input.size()) == len(optional_index_tensors), "Pad indices
|
||||
// with None" shape_map =
|
||||
// get_input_shape_to_output_shape_map(optional_index_tensors)
|
||||
// index_tensors = list(filter(lambda x: x is not None,
|
||||
// optional_index_tensors)) index_tensors_broadcast_shape =
|
||||
// get_broadcast_shape(index_tensors) final_shape =
|
||||
// get_final_shape(input, optional_index_tensors)
|
||||
|
||||
// broadcasted_index_tensors = []
|
||||
// for e, optional_index_tensor in enumerate(optional_index_tensors):
|
||||
// if optional_index_tensor is None:
|
||||
// tensor_to_broadcast = torch.arange(0, input.size(e))
|
||||
// else:
|
||||
// tensor_to_broadcast =
|
||||
// optional_index_tensor.broadcast_to(index_tensors_broadcast_shape)
|
||||
|
||||
// broadcasted_index_tensor = \
|
||||
// tensor_to_broadcast.reshape(spotlight_indices(final_shape, shape_map[e]))\
|
||||
// .broadcast_to(final_shape)\
|
||||
// .flatten()
|
||||
// broadcasted_index_tensors.append(broadcasted_index_tensor)
|
||||
|
||||
// return torch.stack(broadcasted_index_tensors, dim=0).t()
|
||||
|
||||
auto scatterIndicesInfo =
|
||||
getScatterIndices(op, rewriter, indicesDtype, optionalIndicesList,
|
||||
indexBroadcastShapeInt, indexBroadcastShapeValue);
|
||||
if (failed(scatterIndicesInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "cannot generate scatter indices for index put op");
|
||||
}
|
||||
Value indexTensor = *scatterIndicesInfo;
|
||||
|
||||
// Flattening the values tensor.
|
||||
Value torchCstZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
// Materialize out the length-1 dimensions:
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
Value flattenedValuesTensorLastDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc,
|
||||
rewriter.getI64IntegerAttr(valuesTensorType.getSizes().size() - 1));
|
||||
SmallVector<int64_t> valuesShapeInt{valuesTensorType.getSizes()};
|
||||
int64_t valuesCount = 1;
|
||||
if (llvm::all_of(valuesShapeInt,
|
||||
[](int64_t shape) { return shape != kUnknownSize; })) {
|
||||
for (int64_t i : valuesShapeInt)
|
||||
valuesCount *= i;
|
||||
} else {
|
||||
valuesCount = kUnknownSize;
|
||||
Value one = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
llvm::SmallVector<int64_t> valuesShape{valuesType.getSizes().front()};
|
||||
llvm::SmallVector<Value> valuesDims;
|
||||
valuesDims.push_back(
|
||||
rewriter.create<Torch::AtenSizeIntOp>(loc, values, zero));
|
||||
|
||||
int vDim = 1;
|
||||
for (int i = 0, s = inputType.getSizes().size(); i < s; ++i) {
|
||||
if (i < optionalIndicesCount &&
|
||||
!isa<Torch::NoneType>(optionalIndicesList[i].getType())) {
|
||||
valuesDims.push_back(one);
|
||||
valuesShape.push_back(1);
|
||||
continue;
|
||||
}
|
||||
|
||||
Value k = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(vDim));
|
||||
valuesDims.push_back(
|
||||
rewriter.create<Torch::AtenSizeIntOp>(loc, values, k));
|
||||
valuesShape.push_back(inputType.getSizes()[i]);
|
||||
vDim++;
|
||||
}
|
||||
auto flattenedValuesTensorType = ValueTensorType::get(
|
||||
context, llvm::ArrayRef(valuesCount), valuesTensorType.getDtype());
|
||||
Value flattenedValuesTensor = rewriter.create<AtenFlattenUsingIntsOp>(
|
||||
loc, flattenedValuesTensorType, op.getValues(), torchCstZero,
|
||||
flattenedValuesTensorLastDim);
|
||||
values = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc,
|
||||
typeConverter->convertType(flattenedValuesTensor.getType()),
|
||||
flattenedValuesTensor);
|
||||
|
||||
Value valuesDimsList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(rewriter.getType<Torch::IntType>()),
|
||||
valuesDims);
|
||||
|
||||
valuesType = rewriter.getType<Torch::ValueTensorType>(
|
||||
valuesShape, valuesType.getOptionalDtype());
|
||||
values =
|
||||
rewriter.create<AtenViewOp>(loc, valuesType, values, valuesDimsList);
|
||||
|
||||
// `TMTensor::ScatterOp` expects indices of element type i32.
|
||||
Value indices = convertTensorToDtype(
|
||||
rewriter, loc, indexTensor,
|
||||
indices = convertTensorToDtype(
|
||||
rewriter, loc, indices,
|
||||
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
|
||||
|
||||
input = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(input.getType()), input);
|
||||
values = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(values.getType()), values);
|
||||
indices = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
|
||||
|
||||
|
@@ -931,7 +828,8 @@ public:
     // 3.) `input` is mapped to `original` in scatter op.
     bool invalidInputTypeFound = false;
     Value scatterOp = createTMTensorScatterOp(
-        rewriter, loc, values, indices, input, /*uniqueIndices=*/false,
+        rewriter, loc, values, indices, input, indicesMap,
+        /*uniqueIndices=*/false,
         [&](OpBuilder &b, Location loc, Value valuesElement,
             Value inputElement) {
           Value yieldValue = valuesElement;
@ -1150,6 +1048,7 @@ public:
|
|||
Value scatterOp = createTMTensorScatterOp(
|
||||
rewriter, loc, /*updates=*/gradOutputFlattened,
|
||||
/*indices=*/indicesCollapsed, /*original=*/outputTensor,
|
||||
/*dimensionsMap=*/createDefaultDimMap(indicesCollapsed),
|
||||
/*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value valuesElement,
|
||||
Value inputElement) {
|
||||
|
@ -1292,6 +1191,7 @@ public:
|
|||
srcType.getElementType(), /*init_element=*/normalizationValue);
|
||||
self = createTMTensorScatterOp(
|
||||
rewriter, loc, normalizations, indices, self,
|
||||
/*dimensionsMap=*/createDefaultDimMap(indices),
|
||||
/*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value update, Value current) {
|
||||
b.create<TMTensor::YieldOp>(loc, update);
|
||||
|
@ -1299,6 +1199,7 @@ public:
|
|||
if (reduceEnum == torch_upstream::ReductionType::MEAN) {
|
||||
counts = createTMTensorScatterOp(
|
||||
rewriter, loc, normalizations, indices, counts,
|
||||
/*dimensionsMap=*/createDefaultDimMap(indices),
|
||||
/*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value update, Value current) {
|
||||
b.create<TMTensor::YieldOp>(loc, update);
|
||||
|
@ -1309,7 +1210,7 @@ public:
|
|||
// Create final operation
|
||||
Value scatterOp = createTMTensorScatterOp(
|
||||
rewriter, loc, updates, indices, self,
|
||||
/*uniqueIndices=*/false,
|
||||
/*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value update, Value current) {
|
||||
Value result;
|
||||
if (reduceEnum == torch_upstream::ReductionType::SUM ||
|
||||
|
@ -1353,6 +1254,7 @@ public:
|
|||
if (reduceEnum == torch_upstream::ReductionType::MEAN) {
|
||||
counts = createTMTensorScatterOp(
|
||||
rewriter, loc, updates, indices, counts,
|
||||
/*dimensionsMap=*/createDefaultDimMap(indices),
|
||||
/*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value update, Value current) {
|
||||
Value result;
|
||||
|
|
|
@@ -509,12 +509,32 @@ LogicalResult ScanOp::fold(FoldAdaptor adaptor,
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
+static Type getComplexElementTypeOrSelf(Type ty) {
+  if (auto complex = dyn_cast_or_null<ComplexType>(ty))
+    return complex.getElementType();
+  return ty;
+}
+
+static bool isInvalid(ArrayRef<int64_t> dimsPos, int64_t rank) {
+  // early exit.
+  if (static_cast<int64_t>(dimsPos.size()) > rank)
+    return true;
+  DenseSet<int64_t> uniqued;
+  for (int64_t dim : dimsPos)
+    uniqued.insert(dim);
+  if (static_cast<int64_t>(dimsPos.size()) != uniqued.size())
+    return true;
+  return llvm::any_of(
+      dimsPos, [rank](int64_t dimPos) { return dimPos < 0 || dimPos >= rank; });
+}
+
 LogicalResult ScatterOp::verify() {
+  Operation *op = getOperation();
   if (getInputs().size() != 2) {
-    return emitOpError("expected two input operands");
+    return op->emitOpError("expected two input operands");
   }
   if (getOutputs().size() != 1) {
-    return emitOpError("expected one output operand");
+    return op->emitOpError("expected one output operand");
   }
   auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
     return t1.getShape()[dim] == t2.getShape()[dim];
@@ -526,10 +546,19 @@ LogicalResult ScatterOp::verify() {
     return emitOpError("expected indices to be of rank 2 of i32 element type");
   }
   auto indexDepth = getIndexDepth();
-  if (indexDepth == ShapedType::kDynamic) {
+  if (ShapedType::isDynamic(indexDepth)) {
     return emitOpError("expected index depth is static");
   }
 
+  ArrayRef<int64_t> dimMap = getDimensionMap();
+  if (static_cast<int64_t>(dimMap.size()) != indexDepth) {
+    return op->emitOpError("invalid number of dimension map entries ");
+  }
+
+  auto originalType = getOriginalType();
+  if (isInvalid(dimMap, originalType.getRank()))
+    return op->emitOpError("dimension map is invalid");
+
   // The first dimension of the indices should match the first dimension of the
   // output. They indicate to the number of updates.
   auto updateType = getUpdateType();
@ -540,7 +569,6 @@ LogicalResult ScatterOp::verify() {
|
|||
return emitOpError(
|
||||
"mismatch in shape of indices and update value at dim#0");
|
||||
}
|
||||
auto originalType = getOriginalType();
|
||||
if (updateType.getRank() - 1 > originalType.getRank()) {
|
||||
return emitOpError(
|
||||
"update value rank exceeds the rank of the original value");
|
||||
|
@ -553,7 +581,7 @@ LogicalResult ScatterOp::verify() {
|
|||
"index depth and update value does not cover rank of original value");
|
||||
}
|
||||
|
||||
// Validate the non-indexed update dims covier the full slice size of the
|
||||
// Validate the non-indexed update dims cover the full slice size of the
|
||||
// original tensor.
|
||||
int64_t fullSliceDims = originalType.getRank() - indexDepth;
|
||||
for (auto it :
|
||||
|
@ -562,10 +590,11 @@ LogicalResult ScatterOp::verify() {
|
|||
updateType.getRank()))) {
|
||||
int64_t originalDim = std::get<0>(it);
|
||||
int64_t updateDim = std::get<1>(it);
|
||||
if (updateType.getDimSize(updateDim) !=
|
||||
originalType.getDimSize(originalDim)) {
|
||||
return emitOpError("mismatch in shape of update value dim#")
|
||||
<< updateDim << " and original value at dim#" << originalDim;
|
||||
if (!originalType.isDynamicDim(originalDim) &&
|
||||
updateType.getDimSize(updateDim) >
|
||||
originalType.getDimSize(originalDim)) {
|
||||
return op->emitOpError("shape of update value dim#")
|
||||
<< updateDim << " exceeds original value at dim#" << originalDim;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -576,23 +605,25 @@ LogicalResult ScatterOp::verify() {
|
|||
llvm::seq<unsigned>(1, updateType.getRank() - fullSliceDims))) {
|
||||
int64_t originalDim = std::get<0>(it);
|
||||
int64_t updateDim = std::get<1>(it);
|
||||
if (updateType.getDimSize(updateDim) >
|
||||
originalType.getDimSize(originalDim)) {
|
||||
return emitOpError("indexed shape of update value dim#")
|
||||
if (!originalType.isDynamicDim(originalDim) &&
|
||||
updateType.getDimSize(updateDim) >
|
||||
originalType.getDimSize(originalDim)) {
|
||||
return op->emitOpError("indexed shape of update value dim#")
|
||||
<< updateDim << " exceeds original value at dim#" << originalDim
|
||||
<< " " << updateType.getDimSize(updateDim) << " "
|
||||
<< originalType.getDimSize(originalDim);
|
||||
}
|
||||
}
|
||||
|
||||
Region &thisRegion = getRegion();
|
||||
Block *body = &thisRegion.front();
|
||||
Region ®ion = this->getRegion();
|
||||
Block *body = ®ion.front();
|
||||
if (body->getNumArguments() != 2) {
|
||||
return emitOpError("expected region to have two arguments");
|
||||
return op->emitOpError("expected region to have two arguments");
|
||||
}
|
||||
Type arg0Type = body->getArgument(0).getType();
|
||||
Type arg1Type = body->getArgument(1).getType();
|
||||
if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
|
||||
if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() ||
|
||||
!getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) {
|
||||
return emitOpError(
|
||||
"expected region to have scalar argument of integer or float types");
|
||||
}
|
||||
|
@ -684,14 +715,16 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
|
|||
starts[it.index() + offset] = it.value();
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> dimMap = getDimensionMap();
|
||||
for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
|
||||
loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
|
||||
Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
|
||||
Value cast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
|
||||
Value ret = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
|
||||
|
||||
if (starts[i])
|
||||
cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
|
||||
starts[i] = cast;
|
||||
auto dim = dimMap[i];
|
||||
if (starts[dim])
|
||||
ret = b.create<arith::AddIOp>(loc, ret, starts[dim]);
|
||||
starts[dim] = ret;
|
||||
}
|
||||
|
||||
Value init = b.create<memref::LoadOp>(loc, original(), starts);
|
||||
|
|
|
@@ -75,6 +75,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
   // (e.g. dimensions which must be constant in a ranked programming model)
   // and those constants get somewhat obscured by TorchToArith.
   pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
@@ -1166,3 +1166,4 @@ class IndexPutImplIndexWithNoneModule(torch.nn.Module):
     module_factory=lambda: IndexPutImplIndexWithNoneModule())
 def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4, 5), tu.randint(6, 1, high=4), tu.randint(7, high=5), tu.rand(2, 3, 6, 7))
@@ -64,7 +64,7 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
 // CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
 // CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
 // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
-// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
+// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
 // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
 // CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
 // CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32
@ -74,7 +74,7 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
|
|||
func.func @scatter_update_scalar_1D(
|
||||
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
|
||||
%updates: tensor<3xi32>) -> tensor<8xi32> {
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true)
|
||||
ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
|
||||
outs(%original : tensor<8xi32>) {
|
||||
^bb0(%update: i32, %orig: i32): // no predecessors
|
||||
|
@ -92,7 +92,7 @@ func.func @scatter_update_scalar_1D(
|
|||
// CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
|
||||
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
|
||||
// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
||||
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
||||
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
||||
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
|
||||
// CHECK: %[[CST1:.*]] = arith.constant 1 : i32
|
||||
|
@ -104,7 +104,7 @@ func.func @scatter_update_scalar_1D(
|
|||
func.func @scatter_add_scalar_1D(
|
||||
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
|
||||
%updates: tensor<3xi32>) -> tensor<8xi32> {
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true)
|
||||
ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
|
||||
outs(%original : tensor<8xi32>) {
|
||||
^bb0(%update: i32, %orig: i32): // no predecessors
|
||||
|
|
|
@ -105,7 +105,7 @@ func.func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
|
|||
func.func @scatter_update_scalar_1D(
|
||||
%original: memref<8xi32>, %indices: memref<3x1xi32>,
|
||||
%updates: memref<3xi32>) {
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
|
||||
outs(%original : memref<8xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -131,7 +131,7 @@ func.func @scatter_update_scalar_1D(
|
|||
func.func @scatter_add_scalar_2D(
|
||||
%original: memref<4x3xi32>, %indices: memref<3x2xi32>,
|
||||
%updates: memref<3xi32>) {
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
tm_tensor.scatter {dimension_map= array<i64: 0, 1>} unique_indices(true)
|
||||
ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
|
||||
outs(%original : memref<4x3xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -162,7 +162,7 @@ func.func @scatter_add_scalar_2D(
|
|||
func.func @scatter_update_slice_2D(
|
||||
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
|
||||
%updates: memref<2x3xi32>) {
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
|
||||
outs(%original : memref<4x3xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -192,7 +192,7 @@ func.func @scatter_update_slice_2D(
|
|||
func.func @scatter_add_scalar_1D(
|
||||
%original: memref<8xi32>, %indices: memref<3x1xi32>,
|
||||
%updates: memref<3xi32>) {
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
|
||||
outs(%original : memref<8xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -221,7 +221,7 @@ func.func @scatter_add_scalar_1D(
|
|||
func.func @scatter_add_slice_2D(
|
||||
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
|
||||
%updates: memref<2x3xi32>) {
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
|
||||
outs(%original : memref<4x3xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -251,7 +251,7 @@ func.func @scatter_add_slice_2D(
|
|||
func.func @scatter_update_scalar_dynamic_1D(
|
||||
%original: memref<?xi32>, %indices: memref<?x1xi32>,
|
||||
%updates: memref<?xi32>) {
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%updates, %indices : memref<?xi32>, memref<?x1xi32>)
|
||||
outs(%original : memref<?xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -277,7 +277,7 @@ func.func @scatter_update_scalar_dynamic_1D(
|
|||
func.func @scatter_add_scalar_dynamic_2D(
|
||||
%original: memref<?x?xi32>, %indices: memref<?x2xi32>,
|
||||
%updates: memref<?xi32>) {
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
tm_tensor.scatter {dimension_map= array<i64: 0, 1>} unique_indices(true)
|
||||
ins(%updates, %indices : memref<?xi32>, memref<?x2xi32>)
|
||||
outs(%original : memref<?x?xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -308,7 +308,7 @@ func.func @scatter_add_scalar_dynamic_2D(
|
|||
func.func @scatter_update_slice_dynamic_2D(
|
||||
%original: memref<?x?xi32>, %indices: memref<?x1xi32>,
|
||||
%updates: memref<?x?xi32>) {
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%updates, %indices : memref<?x?xi32>, memref<?x1xi32>)
|
||||
outs(%original : memref<?x?xi32>) {
|
||||
^bb0(%arg0: i32, %arg1: i32): // no predecessors
|
||||
|
@ -335,6 +335,7 @@ func.func @scatter_update_slice_dynamic_2D(
|
|||
|
||||
func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) {
|
||||
tm_tensor.scatter
|
||||
{dimension_map= array<i64: 0, 1, 2>}
|
||||
unique_indices(true)
|
||||
ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>)
|
||||
outs(%arg0 : memref<2x64x12xf32>) {
|
||||
|
|
|
@ -4,7 +4,7 @@ func.func @scatter_mixed_tensor_memref(
|
|||
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0, 1, 2>} unique_indices(true)
|
||||
ins(%update, %indices : memref<?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
|
@ -20,7 +20,7 @@ func.func @scatter_mixed_tensor_memref(
|
|||
%update : tensor<?x?xf32>, %indices : memref<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0, 1, 2>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xf32>, memref<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
|
@ -36,7 +36,7 @@ func.func @scatter_extra_outputs(
|
|||
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||
// expected-error @+1 {{expected number of outputs to be same as the number of results}}
|
||||
%0, %1 = tm_tensor.scatter unique_indices(true)
|
||||
%0, %1 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
|
@ -52,7 +52,7 @@ func.func @scatter_mixed_tensor_memref(
|
|||
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : memref<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : memref<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
|
@ -68,7 +68,7 @@ func.func @scatter_output_type_mismatch(
|
|||
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<4x?xf32> {
|
||||
// expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'tensor<4x?xf32>'}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
|
@ -84,7 +84,7 @@ func.func @scatter_mixed_tensor_memref(
|
|||
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : memref<?x?xf32>) {
|
||||
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : memref<?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : memref<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
|
@ -100,7 +100,7 @@ func.func @scatter_mixed_tensor_memref(
|
|||
%update : memref<?x?xf32>, %indices : memref<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) {
|
||||
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
|
||||
tm_tensor.scatter unique_indices(true)
|
||||
tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : memref<?x?xf32>, memref<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
|
@ -116,7 +116,7 @@ func.func @scatter_dim_mismatch(
|
|||
%update : tensor<?x?xf32>, %indices : tensor<48x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xf32>, tensor<48x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
|
@ -132,7 +132,7 @@ func.func @scatter_dim_mismatch(
|
|||
%update : tensor<64x?xf32>, %indices : tensor<48x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<64x?xf32>, tensor<48x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
|
@ -148,7 +148,7 @@ func.func @scatter_dim_mismatch(
|
|||
%update : tensor<?x?x?x?xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{op update value rank exceeds the rank of the original value}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?x?x?xf32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
|
@ -162,16 +162,16 @@ func.func @scatter_dim_mismatch(
|
|||
|
||||
func.func @scatter_dim_mismatch(
|
||||
%update : tensor<?x4xf32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{mismatch in shape of update value dim#1 and original value at dim#1}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%original : tensor<?x3xf32>) -> tensor<?x3xf32> {
|
||||
// expected-error @+1 {{shape of update value dim#1 exceeds original value at dim#1}}
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x4xf32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xf32>) {
|
||||
outs(%original : tensor<?x3xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
%1 = arith.addf %arg1, %arg2 : f32
|
||||
tm_tensor.yield %1 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
} -> tensor<?x3xf32>
|
||||
return %0 : tensor<?x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -180,7 +180,7 @@ func.func @scatter_region_type_mismatch(
|
|||
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
// expected-error @+1 {{expected region to have scalar argument of integer or float types}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi32>) {
|
||||
^bb0(%arg1: index, %arg2: index):
|
||||
|
@ -197,7 +197,7 @@ func.func @scatter_region_type_mismatch(
|
|||
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
// expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi32>) {
|
||||
^bb0(%arg1: i64, %arg2: i32):
|
||||
|
@ -214,7 +214,7 @@ func.func @scatter_region_type_mismatch(
|
|||
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
// expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi32>) {
|
||||
^bb0(%arg1: i32, %arg2: i64):
|
||||
|
@ -231,7 +231,7 @@ func.func @scatter_region_type_mismatch(
|
|||
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
// expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i32, %arg2: i64):
|
||||
|
@ -248,7 +248,7 @@ func.func @scatter_region_type_mismatch(
|
|||
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
// expected-error @+1 {{expected region to have two arguments}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i64, %arg2: i64, %arg3 : i64):
|
||||
|
@ -264,7 +264,7 @@ func.func @scatter_region_type_mismatch(
|
|||
func.func @scatter_yield_mismatch(
|
||||
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i64, %arg2: i64):
|
||||
|
@ -281,7 +281,7 @@ func.func @scatter_yield_mismatch(
|
|||
func.func @scatter_yield_mismatch(
|
||||
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i64, %arg2: i64):
|
||||
|
@ -299,7 +299,7 @@ func.func @scatter_index_depth_dynamic(
|
|||
%update : tensor<?x?xi64>, %indices : tensor<?x?xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
// expected-error @+1 {{expected index depth is static}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?x?xi64>, tensor<?x?xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i64, %arg2: i64):
|
||||
|
@ -316,7 +316,7 @@ func.func @scatter_original_rank_mismatch(
|
|||
%update : tensor<?xi64>, %indices : tensor<?x1xi32>,
|
||||
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
|
||||
// expected-error @+1 {{op index depth and update value does not cover rank of original value}}
|
||||
%0 = tm_tensor.scatter unique_indices(true)
|
||||
%0 = tm_tensor.scatter {dimension_map= array<i64: 0>} unique_indices(true)
|
||||
ins(%update, %indices : tensor<?xi64>, tensor<?x1xi32>)
|
||||
outs(%original : tensor<?x?xi64>) {
|
||||
^bb0(%arg1: i64, %arg2: i64):
|
||||
|
|