mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for multiple indexing tensors for aten.index.Tensor (#1097)
- Includes a canonicalizer for `aten.add.t`needed for successfully lowering the shape function - Only offers support for statically sized index tensors when there is more than one - Dynamic shape support remains for single indexing tensorspull/1120/head
parent
b36a17c9d2
commit
11a8901078
|
@ -6619,6 +6619,7 @@ def Torch_AtenAddTOp : Torch_Op<"aten.add.t", [
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [
|
def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [
|
||||||
|
|
|
@ -244,6 +244,21 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// IndexTensor for multiple input tensors broadcasts their shapes to a common
|
||||||
|
// shape and then replaces the indexed dims with the indices given by the
|
||||||
|
// indexing tensors:
|
||||||
|
// x[i_1, i_2, ..., i_M] = result
|
||||||
|
// result[...] = x[i_1[...], i_2[...], ..., i_M[...]]
|
||||||
|
//
|
||||||
|
// where the result shape is computed as follows:
|
||||||
|
// 1. broadcast i_1, i_2, ..., i_M to a common shape
|
||||||
|
// 2. if i_1, i_2, ..., i_M is not contiguous, transpose the broadcasted
|
||||||
|
// shape to the beginning of the result shape, while removing the
|
||||||
|
// unchanged dims (marked by None)
|
||||||
|
// 3. Otherwise replace the indexed dims with the broadcasted shape
|
||||||
|
//
|
||||||
|
// e.g. x: [2, 3]
|
||||||
|
// x[[4], [6, 1]] -> x[6, 4]
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
|
class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -251,6 +266,7 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor,
|
matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -266,78 +282,165 @@ public:
|
||||||
SmallVector<Value> indicesVal =
|
SmallVector<Value> indicesVal =
|
||||||
getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple);
|
getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple);
|
||||||
|
|
||||||
int indexTensorDim = -1;
|
// Identify the indices with non-None index tensors and determine if they
|
||||||
|
// are contiguous within the input list.
|
||||||
|
SmallVector<int> indexTensorDims;
|
||||||
|
SmallVector<Value> indexTensors;
|
||||||
|
bool contiguous = true;
|
||||||
for (auto i : llvm::seq(0, (int)indicesVal.size())) {
|
for (auto i : llvm::seq(0, (int)indicesVal.size())) {
|
||||||
Value index = indicesVal[i];
|
Value index = indicesVal[i];
|
||||||
if (!index || failed(checkNotNone(rewriter, op, index)))
|
if (!index || failed(checkNotNone(rewriter, op, index)))
|
||||||
continue;
|
continue;
|
||||||
if (indexTensorDim >= 0) {
|
if (!indexTensorDims.empty() && indexTensorDims.back() != i - 1)
|
||||||
return rewriter.notifyMatchFailure(
|
contiguous = false;
|
||||||
op, "unimplemented: only one index tensor allowed");
|
indexTensorDims.push_back(i);
|
||||||
}
|
indexTensors.push_back(index);
|
||||||
indexTensorDim = i;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (indexTensorDim == -1) {
|
if (indexTensors.empty()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: index tensor must not be None");
|
op, "aten.index.Tensor: index tensor must not be None");
|
||||||
}
|
}
|
||||||
|
|
||||||
Value indexTensor = indicesVal[indexTensorDim];
|
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||||
RankedTensorType indexTensorType =
|
|
||||||
indexTensor.getType().cast<RankedTensorType>();
|
|
||||||
RankedTensorType resultType = getTypeConverter()
|
RankedTensorType resultType = getTypeConverter()
|
||||||
->convertType(op->getResult(0).getType())
|
->convertType(op->getResult(0).getType())
|
||||||
.cast<RankedTensorType>();
|
.cast<RankedTensorType>();
|
||||||
Type elementType = resultType.getElementType();
|
Type elementType = resultType.getElementType();
|
||||||
int inputRank = inputType.getRank();
|
int inputRank = inputType.getRank();
|
||||||
int indexTensorRank = indexTensorType.getRank();
|
int resultRank = resultType.getRank();
|
||||||
|
int firstIndexDim = indexTensorDims[0];
|
||||||
|
int replacedIndexCount = indexTensorDims.size();
|
||||||
|
int64_t startIndex = contiguous ? firstIndexDim : 0;
|
||||||
|
|
||||||
|
// Currently we only support statically sized index tensors
|
||||||
|
// when there is more than one index tensor.
|
||||||
|
// TODO: Add support for dynamic size index tensors. This will probably
|
||||||
|
// require broadcasting the index tensors to a common shape.
|
||||||
|
SmallVector<Value> broadcastedIndexShape;
|
||||||
|
if (indexTensors.size() > 1) {
|
||||||
|
int maxRank = -1;
|
||||||
|
for (auto indexTensor : indexTensors) {
|
||||||
|
RankedTensorType indexTensorType =
|
||||||
|
indexTensor.getType().cast<RankedTensorType>();
|
||||||
|
maxRank = std::max(maxRank, (int)indexTensorType.getRank());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Because we are assuming static shapes, we can get the shape of the
|
||||||
|
// broadcasted index tensors from the shape refinement pass
|
||||||
|
auto refinedResultShape = resultType.getShape();
|
||||||
|
for (auto i : llvm::seq(startIndex, startIndex + maxRank)) {
|
||||||
|
auto resultDimSize = refinedResultShape[i];
|
||||||
|
if (ShapedType::isDynamic(resultDimSize)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: index tensors must have static shape if "
|
||||||
|
"there is more than one index tensor");
|
||||||
|
}
|
||||||
|
broadcastedIndexShape.push_back(
|
||||||
|
getConstant(rewriter, loc, resultDimSize, rewriter.getIndexType()));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For a single indexing tensor we can simply use its (dynamic) sizes
|
||||||
|
broadcastedIndexShape =
|
||||||
|
getTensorSizes(rewriter, loc, indexTensors.front());
|
||||||
|
}
|
||||||
|
|
||||||
// This result shape calculation assumes that there is only one
|
// This result shape calculation assumes that there is only one
|
||||||
// index tensor of the input tensor. The calculation for arbitrary inputs is
|
// index tensor, or all of the index tensors are statically shaped.
|
||||||
// much more complex.
|
int broadcastRank = broadcastedIndexShape.size();
|
||||||
SmallVector<Value> resultShape;
|
|
||||||
for (auto i : llvm::seq(0, indexTensorDim)) {
|
|
||||||
resultShape.push_back(getDimOp(rewriter, loc, input, i));
|
|
||||||
}
|
|
||||||
for (auto i : llvm::seq(0, indexTensorRank)) {
|
|
||||||
resultShape.push_back(getDimOp(rewriter, loc, indexTensor, i));
|
|
||||||
}
|
|
||||||
for (auto i : llvm::seq(indexTensorDim + 1, inputRank)) {
|
|
||||||
resultShape.push_back(getDimOp(rewriter, loc, input, i));
|
|
||||||
}
|
|
||||||
int resultRank = resultShape.size();
|
|
||||||
|
|
||||||
|
SmallVector<Value> resultShape;
|
||||||
|
if (contiguous) {
|
||||||
|
for (auto i : llvm::seq(0, firstIndexDim)) {
|
||||||
|
resultShape.push_back(getDimOp(rewriter, loc, input, i));
|
||||||
|
}
|
||||||
|
resultShape.append(broadcastedIndexShape);
|
||||||
|
for (auto i : llvm::seq((int)resultShape.size(), resultRank)) {
|
||||||
|
resultShape.push_back(getDimOp(rewriter, loc, input,
|
||||||
|
i - broadcastRank + replacedIndexCount));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
resultShape.append(broadcastedIndexShape);
|
||||||
|
int j = 0;
|
||||||
|
for (auto i : llvm::seq(0, inputRank)) {
|
||||||
|
if (j < replacedIndexCount && i == indexTensorDims[j]) {
|
||||||
|
j++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
resultShape.push_back(getDimOp(rewriter, loc, input, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the indexing maps for the generic op. Because we are assuming
|
||||||
|
// static shapes for the indexing tensors when there are more than 1, we can
|
||||||
|
// safely map all size 1 dims to 0 in the corresponding affine maps.
|
||||||
|
// TODO: For dynamic shapes, we have to either broadcast the index tensors
|
||||||
|
// to a common shape or introduce some form of control flow.
|
||||||
Value initTensor =
|
Value initTensor =
|
||||||
rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
|
rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
|
||||||
SmallVector<AffineExpr> indicesExpr, resultExpr;
|
SmallVector<AffineMap> indexingMaps;
|
||||||
SmallVector<StringRef> iteratorTypes;
|
SmallVector<StringRef> iteratorTypes;
|
||||||
|
|
||||||
for (auto i : llvm::seq(indexTensorDim, indexTensorDim + indexTensorRank))
|
for (auto indexTensor : indexTensors) {
|
||||||
indicesExpr.push_back(rewriter.getAffineDimExpr(i));
|
RankedTensorType indexTensorType =
|
||||||
|
indexTensor.getType().cast<RankedTensorType>();
|
||||||
|
auto indexTensorShape = indexTensorType.getShape();
|
||||||
|
int rank = indexTensorShape.size();
|
||||||
|
SmallVector<AffineExpr> indicesExpr;
|
||||||
|
for (auto dim : llvm::seq(0, rank)) {
|
||||||
|
if (indexTensorShape[dim] == 1) {
|
||||||
|
indicesExpr.push_back(rewriter.getAffineConstantExpr(0));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
indicesExpr.push_back(
|
||||||
|
rewriter.getAffineDimExpr(startIndex + broadcastRank - rank + dim));
|
||||||
|
}
|
||||||
|
indexingMaps.push_back(
|
||||||
|
AffineMap::get(resultRank, 0, indicesExpr, op->getContext()));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<AffineExpr> resultExpr;
|
||||||
for (auto i : llvm::seq(0, resultRank)) {
|
for (auto i : llvm::seq(0, resultRank)) {
|
||||||
resultExpr.push_back(rewriter.getAffineDimExpr(i));
|
resultExpr.push_back(rewriter.getAffineDimExpr(i));
|
||||||
iteratorTypes.push_back(getParallelIteratorTypeName());
|
iteratorTypes.push_back(getParallelIteratorTypeName());
|
||||||
}
|
}
|
||||||
auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr});
|
|
||||||
|
indexingMaps.push_back(
|
||||||
|
AffineMap::get(resultRank, 0, resultExpr, op->getContext()));
|
||||||
|
|
||||||
Value finalRes =
|
Value finalRes =
|
||||||
rewriter
|
rewriter
|
||||||
.create<linalg::GenericOp>(
|
.create<linalg::GenericOp>(
|
||||||
loc, initTensor.getType(), indexTensor, initTensor,
|
loc, initTensor.getType(), indexTensors, initTensor,
|
||||||
indexingMaps, iteratorTypes,
|
indexingMaps, iteratorTypes,
|
||||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
Value index = castIntToIndex(b, loc, args[0]);
|
|
||||||
SmallVector<Value> extractionIndices;
|
SmallVector<Value> extractionIndices;
|
||||||
int extra_dims = 0;
|
if (contiguous) {
|
||||||
for (auto i : llvm::seq(0, inputRank)) {
|
for (auto i : llvm::seq(0, firstIndexDim)) {
|
||||||
if (i == indexTensorDim) {
|
|
||||||
extractionIndices.push_back(index);
|
|
||||||
extra_dims += indexTensorRank - 1;
|
|
||||||
} else {
|
|
||||||
extractionIndices.push_back(
|
extractionIndices.push_back(
|
||||||
b.create<linalg::IndexOp>(loc, i + extra_dims));
|
b.create<linalg::IndexOp>(loc, i));
|
||||||
|
}
|
||||||
|
for (auto i : llvm::seq(0, (int)indexTensorDims.size())) {
|
||||||
|
extractionIndices.push_back(
|
||||||
|
castIntToIndex(b, loc, args[i]));
|
||||||
|
}
|
||||||
|
for (auto i :
|
||||||
|
llvm::seq((int)extractionIndices.size(), inputRank)) {
|
||||||
|
extractionIndices.push_back(b.create<linalg::IndexOp>(
|
||||||
|
loc, i + broadcastRank - replacedIndexCount));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int indexCount = 0, unchanged = 0;
|
||||||
|
for (auto i : llvm::seq(0, inputRank)) {
|
||||||
|
if (indexCount < replacedIndexCount &&
|
||||||
|
i == indexTensorDims[indexCount]) {
|
||||||
|
extractionIndices.push_back(
|
||||||
|
castIntToIndex(b, loc, args[indexCount++]));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
extractionIndices.push_back(b.create<linalg::IndexOp>(
|
||||||
|
loc, broadcastRank + unchanged));
|
||||||
|
unchanged++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Value extractedElement = b.create<tensor::ExtractOp>(
|
Value extractedElement = b.create<tensor::ExtractOp>(
|
||||||
|
|
|
@ -1479,6 +1479,35 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenAddTOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
|
MLIRContext *context) {
|
||||||
|
patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) {
|
||||||
|
auto lhsListConstruct = op.a().getDefiningOp<Torch::PrimListConstructOp>();
|
||||||
|
if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto rhsListConstruct = op.b().getDefiningOp<Torch::PrimListConstructOp>();
|
||||||
|
if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<Value> concatenatedList;
|
||||||
|
for (auto a : lhsListConstruct.getOperands()) {
|
||||||
|
concatenatedList.push_back(a);
|
||||||
|
}
|
||||||
|
for (auto b : rhsListConstruct.getOperands()) {
|
||||||
|
concatenatedList.push_back(b);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(op, op.getType(),
|
||||||
|
concatenatedList);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenEqIntListOp
|
// AtenEqIntListOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -6590,30 +6590,30 @@ module {
|
||||||
%10 = torch.aten.len.t %arg1 : !torch.list<optional<list<int>>> -> !torch.int
|
%10 = torch.aten.len.t %arg1 : !torch.list<optional<list<int>>> -> !torch.int
|
||||||
%11 = torch.prim.ListConstruct %int9223372036854775807, %10 : (!torch.int, !torch.int) -> !torch.list<int>
|
%11 = torch.prim.ListConstruct %int9223372036854775807, %10 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
%12 = torch.prim.min.self_int %11 : !torch.list<int> -> !torch.int
|
%12 = torch.prim.min.self_int %11 : !torch.list<int> -> !torch.int
|
||||||
%13:3 = torch.prim.Loop %12, %true, init(%true, %int-1, %int-1) {
|
%13:2 = torch.prim.Loop %12, %true, init(%true, %int-1) {
|
||||||
^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.int):
|
^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int):
|
||||||
%16 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<optional<list<int>>>, !torch.int -> !torch.optional<list<int>>
|
%16 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<optional<list<int>>>, !torch.int -> !torch.optional<list<int>>
|
||||||
%17 = torch.aten.__isnot__ %16, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
%17 = torch.aten.__isnot__ %16, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||||
%18:3 = torch.prim.If %17 -> (!torch.bool, !torch.int, !torch.int) {
|
%18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {
|
||||||
%19 = torch.aten.eq.int %arg4, %int-1 : !torch.int, !torch.int -> !torch.bool
|
%19 = torch.aten.eq.int %arg4, %int-1 : !torch.int, !torch.int -> !torch.bool
|
||||||
%20:3 = torch.prim.If %19 -> (!torch.bool, !torch.int, !torch.int) {
|
%20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {
|
||||||
torch.prim.If.yield %arg3, %arg2, %arg2 : !torch.bool, !torch.int, !torch.int
|
torch.prim.If.yield %arg3, %arg2 : !torch.bool, !torch.int
|
||||||
} else {
|
} else {
|
||||||
%21 = torch.aten.sub.int %arg2, %arg5 : !torch.int, !torch.int -> !torch.int
|
%21 = torch.aten.sub.int %arg2, %arg4 : !torch.int, !torch.int -> !torch.int
|
||||||
%22 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool
|
%22 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool
|
||||||
%23 = torch.prim.If %22 -> (!torch.bool) {
|
%23 = torch.prim.If %22 -> (!torch.bool) {
|
||||||
torch.prim.If.yield %false : !torch.bool
|
torch.prim.If.yield %false : !torch.bool
|
||||||
} else {
|
} else {
|
||||||
torch.prim.If.yield %arg3 : !torch.bool
|
torch.prim.If.yield %arg3 : !torch.bool
|
||||||
}
|
}
|
||||||
torch.prim.If.yield %23, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int
|
torch.prim.If.yield %23, %arg4 : !torch.bool, !torch.int
|
||||||
}
|
}
|
||||||
torch.prim.If.yield %20#0, %20#1, %20#2 : !torch.bool, !torch.int, !torch.int
|
torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int
|
||||||
} else {
|
} else {
|
||||||
torch.prim.If.yield %arg3, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int
|
torch.prim.If.yield %arg3, %arg4 : !torch.bool, !torch.int
|
||||||
}
|
}
|
||||||
torch.prim.Loop.condition %true, iter(%18#0, %18#1, %18#2 : !torch.bool, !torch.int, !torch.int)
|
torch.prim.Loop.condition %true, iter(%18#0, %18#1 : !torch.bool, !torch.int)
|
||||||
} : (!torch.int, !torch.bool, !torch.bool, !torch.int, !torch.int) -> (!torch.bool, !torch.int, !torch.int)
|
} : (!torch.int, !torch.bool, !torch.bool, !torch.int) -> (!torch.bool, !torch.int)
|
||||||
%14 = torch.aten.__not__ %13#0 : !torch.bool -> !torch.bool
|
%14 = torch.aten.__not__ %13#0 : !torch.bool -> !torch.bool
|
||||||
%15 = torch.prim.If %14 -> (!torch.list<int>) {
|
%15 = torch.prim.If %14 -> (!torch.list<int>) {
|
||||||
%16 = torch.aten.add.t %6, %4 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
|
%16 = torch.aten.add.t %6, %4 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
|
||||||
|
|
|
@ -418,6 +418,7 @@ class SimplifyShapeCalculationsPass
|
||||||
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
|
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
|
||||||
AtenSizeOp::getCanonicalizationPatterns(patterns, context);
|
AtenSizeOp::getCanonicalizationPatterns(patterns, context);
|
||||||
AtenLenTOp::getCanonicalizationPatterns(patterns, context);
|
AtenLenTOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
AtenAddTOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
|
||||||
// TODO: Debug visitation order to make this more efficient.
|
// TODO: Debug visitation order to make this more efficient.
|
||||||
// A single linear scan should suffice.
|
// A single linear scan should suffice.
|
||||||
|
|
|
@ -1016,6 +1016,7 @@ def aten〇pad(self: List[int], pad: List[int], mode: str = "constant", value: O
|
||||||
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
|
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
|
||||||
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
|
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
|
||||||
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
|
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
|
||||||
|
Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions.
|
||||||
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
|
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
|
||||||
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
|
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
|
||||||
ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
|
ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
|
||||||
|
@ -1037,15 +1038,13 @@ def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) -
|
||||||
if len(unused_dim_sizes) == 0:
|
if len(unused_dim_sizes) == 0:
|
||||||
return broadcasted_shape
|
return broadcasted_shape
|
||||||
|
|
||||||
prev_index_tensor_location = -1
|
|
||||||
first_index_tensor_location = -1
|
first_index_tensor_location = -1
|
||||||
index_tensors_are_together = True
|
index_tensors_are_together = True
|
||||||
for e, index_tensor_shape in enumerate(indices):
|
for e, index_tensor_shape in enumerate(indices):
|
||||||
if index_tensor_shape is not None:
|
if index_tensor_shape is not None:
|
||||||
if first_index_tensor_location == -1:
|
if first_index_tensor_location == -1:
|
||||||
first_index_tensor_location = e
|
first_index_tensor_location = e
|
||||||
prev_index_tensor_location = e
|
elif e - first_index_tensor_location != 1:
|
||||||
elif e - prev_index_tensor_location != 1:
|
|
||||||
index_tensors_are_together = False
|
index_tensors_are_together = False
|
||||||
|
|
||||||
if not index_tensors_are_together:
|
if not index_tensors_are_together:
|
||||||
|
|
|
@ -489,7 +489,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
# List ops.
|
# List ops.
|
||||||
emit("aten::cat : (Tensor[], int) -> (Tensor)")
|
emit("aten::cat : (Tensor[], int) -> (Tensor)")
|
||||||
emit("aten::append.t : (t[], t) -> (t[])")
|
emit("aten::append.t : (t[], t) -> (t[])")
|
||||||
emit("aten::add.t : (t[], t[]) -> (t[])")
|
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
|
||||||
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
|
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
|
||||||
emit("aten::list.t : (t[]) -> (t[])")
|
emit("aten::list.t : (t[]) -> (t[])")
|
||||||
emit("aten::slice.t : (t[], int?, int?, int) -> (t[])")
|
emit("aten::slice.t : (t[], int?, int?, int) -> (t[])")
|
||||||
|
|
|
@ -1700,6 +1700,130 @@ def IndexTensorSelectDimModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class IndexTensorMultiInput(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([3, 3], torch.int64, True),
|
||||||
|
([3], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, index1, index2):
|
||||||
|
return torch.ops.aten.index(x, (index1, index2,))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: IndexTensorMultiInput())
|
||||||
|
def IndexTensorMultiInput_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 4, 3), torch.randint(3, (3, 3)), torch.randint(3, (3,)))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class IndexTensorMultiInputOneDim(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([6, 1], torch.int64, True),
|
||||||
|
([3], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, index1, index2):
|
||||||
|
return torch.ops.aten.index(x, (index1, index2,))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: IndexTensorMultiInputOneDim())
|
||||||
|
def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)), torch.randint(3, (3,)))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class IndexTensorMultiInputNonContiguous(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
([4, 2], torch.int64, True),
|
||||||
|
([4, 2], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, index1, index2):
|
||||||
|
return torch.ops.aten.index(x, (index1, None, index2))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: IndexTensorMultiInputNonContiguous())
|
||||||
|
def IndexTensorMultiInputNonContiguous_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (4, 2)), torch.randint(1, (4, 2,)))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class IndexTensorMultiInputThreeIndexers(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1, -1, -1], torch.float32, True),
|
||||||
|
([8, 4, 2], torch.int64, True),
|
||||||
|
([8, 1, 1], torch.int64, True),
|
||||||
|
([4, 2], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, index1, index2, index3):
|
||||||
|
return torch.ops.aten.index(x, (None, None, index1, None, index2, index3))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: IndexTensorMultiInputThreeIndexers())
|
||||||
|
def IndexTensorMultiInputThreeIndexers_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1, 2, 4, 4, 5, 3),
|
||||||
|
torch.randint(3, (8, 4, 2,)),
|
||||||
|
torch.randint(4, (8, 1, 1,)),
|
||||||
|
torch.randint(2, (4, 2,)))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class IndexTensorMultiInputContiguousCenter(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
([2, 2], torch.int64, True),
|
||||||
|
([2], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, index1, index2):
|
||||||
|
return torch.ops.aten.index(x, (None, index1, index2, None))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: IndexTensorMultiInputContiguousCenter())
|
||||||
|
def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (2, 2)), torch.randint(2, [2]))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class SquareModule(torch.nn.Module):
|
class SquareModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -635,6 +635,53 @@ func.func @torch.aten.__getitem__.t$invalid_index() -> !torch.int {
|
||||||
return %1 : !torch.int
|
return %1 : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Not canonicalized because of mutated lhs list
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.add.t$no_canonicalize_lhs_mutated()
|
||||||
|
func.func @torch.aten.add.t$no_canonicalize_lhs_mutated() -> !torch.list<int> {
|
||||||
|
%int4 = torch.constant.int 4
|
||||||
|
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
%2 = torch.aten.append.t %0, %int4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
|
// CHECK: torch.aten.add.t
|
||||||
|
%3 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
|
||||||
|
return %3 : !torch.list<int>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not canonicalized because of mutated rhs list
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.add.t$no_canonicalize_rhs_mutated()
|
||||||
|
func.func @torch.aten.add.t$no_canonicalize_rhs_mutated() -> !torch.list<int> {
|
||||||
|
%int4 = torch.constant.int 4
|
||||||
|
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
%2 = torch.aten.append.t %1, %int4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
|
// CHECK: torch.aten.add.t
|
||||||
|
%3 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
|
||||||
|
return %3 : !torch.list<int>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.add.t$concat(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.int,
|
||||||
|
// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.list<int> {
|
||||||
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: return %[[LIST]] : !torch.list<int>
|
||||||
|
func.func @torch.aten.add.t$concat(%arg0: !torch.int, %arg1: !torch.int) -> !torch.list<int> {
|
||||||
|
%0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
|
||||||
|
%2 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
|
||||||
|
return %2 : !torch.list<int>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.add.t$concat_empty(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.int) -> !torch.list<int> {
|
||||||
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG0]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: return %[[LIST]] : !torch.list<int>
|
||||||
|
func.func @torch.aten.add.t$concat_empty(%arg0: !torch.int) -> !torch.list<int> {
|
||||||
|
%0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
%2 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
|
||||||
|
return %2 : !torch.list<int>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.eq.int_list$fold$literals_of_different_sizes
|
// CHECK-LABEL: func.func @torch.aten.eq.int_list$fold$literals_of_different_sizes
|
||||||
// CHECK: %[[RET:.*]] = torch.constant.bool false
|
// CHECK: %[[RET:.*]] = torch.constant.bool false
|
||||||
// CHECK: return %[[RET]] : !torch.bool
|
// CHECK: return %[[RET]] : !torch.bool
|
||||||
|
|
Loading…
Reference in New Issue