[MLIR][TORCH] Add support for multiple indexing tensors for aten.index.Tensor (#1097)

- Includes a canonicalizer for `aten.add.t`needed for successfully lowering the shape function
 - Only offers support for statically sized index tensors when there is more than one
 - Dynamic shape support remains for single indexing tensors
pull/1120/head
Quinn Dawkins 2022-07-28 19:00:02 -04:00 committed by GitHub
parent b36a17c9d2
commit 11a8901078
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 357 additions and 53 deletions

View File

@ -6619,6 +6619,7 @@ def Torch_AtenAddTOp : Torch_Op<"aten.add.t", [
printDefaultTorchOp(printer, *this, 2, 1); printDefaultTorchOp(printer, *this, 2, 1);
} }
}]; }];
let hasCanonicalizer = 1;
} }
def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [ def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [

View File

@ -244,6 +244,21 @@ public:
}; };
} // namespace } // namespace
// IndexTensor for multiple input tensors broadcasts their shapes to a common
// shape and then replaces the indexed dims with the indices given by the
// indexing tensors:
// x[i_1, i_2, ..., i_M] = result
// result[...] = x[i_1[...], i_2[...], ..., i_M[...]]
//
// where the result shape is computed as follows:
// 1. broadcast i_1, i_2, ..., i_M to a common shape
// 2. if i_1, i_2, ..., i_M is not contiguous, transpose the broadcasted
// shape to the beginning of the result shape, while removing the
// unchanged dims (marked by None)
// 3. Otherwise replace the indexed dims with the broadcasted shape
//
// e.g. x: [2, 3]
// x[[4], [6, 1]] -> x[6, 4]
namespace { namespace {
class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> { class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
public: public:
@ -251,6 +266,7 @@ public:
LogicalResult LogicalResult
matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor, matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure(); return failure();
@ -266,78 +282,165 @@ public:
SmallVector<Value> indicesVal = SmallVector<Value> indicesVal =
getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple); getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple);
int indexTensorDim = -1; // Identify the indices with non-None index tensors and determine if they
// are contiguous within the input list.
SmallVector<int> indexTensorDims;
SmallVector<Value> indexTensors;
bool contiguous = true;
for (auto i : llvm::seq(0, (int)indicesVal.size())) { for (auto i : llvm::seq(0, (int)indicesVal.size())) {
Value index = indicesVal[i]; Value index = indicesVal[i];
if (!index || failed(checkNotNone(rewriter, op, index))) if (!index || failed(checkNotNone(rewriter, op, index)))
continue; continue;
if (indexTensorDim >= 0) { if (!indexTensorDims.empty() && indexTensorDims.back() != i - 1)
return rewriter.notifyMatchFailure( contiguous = false;
op, "unimplemented: only one index tensor allowed"); indexTensorDims.push_back(i);
} indexTensors.push_back(index);
indexTensorDim = i;
} }
if (indexTensorDim == -1) { if (indexTensors.empty()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: index tensor must not be None"); op, "aten.index.Tensor: index tensor must not be None");
} }
Value indexTensor = indicesVal[indexTensorDim];
RankedTensorType inputType = input.getType().cast<RankedTensorType>(); RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>();
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType()) ->convertType(op->getResult(0).getType())
.cast<RankedTensorType>(); .cast<RankedTensorType>();
Type elementType = resultType.getElementType(); Type elementType = resultType.getElementType();
int inputRank = inputType.getRank(); int inputRank = inputType.getRank();
int indexTensorRank = indexTensorType.getRank(); int resultRank = resultType.getRank();
int firstIndexDim = indexTensorDims[0];
int replacedIndexCount = indexTensorDims.size();
int64_t startIndex = contiguous ? firstIndexDim : 0;
// Currently we only support statically sized index tensors
// when there is more than one index tensor.
// TODO: Add support for dynamic size index tensors. This will probably
// require broadcasting the index tensors to a common shape.
SmallVector<Value> broadcastedIndexShape;
if (indexTensors.size() > 1) {
int maxRank = -1;
for (auto indexTensor : indexTensors) {
RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>();
maxRank = std::max(maxRank, (int)indexTensorType.getRank());
}
// Because we are assuming static shapes, we can get the shape of the
// broadcasted index tensors from the shape refinement pass
auto refinedResultShape = resultType.getShape();
for (auto i : llvm::seq(startIndex, startIndex + maxRank)) {
auto resultDimSize = refinedResultShape[i];
if (ShapedType::isDynamic(resultDimSize)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: index tensors must have static shape if "
"there is more than one index tensor");
}
broadcastedIndexShape.push_back(
getConstant(rewriter, loc, resultDimSize, rewriter.getIndexType()));
}
} else {
// For a single indexing tensor we can simply use its (dynamic) sizes
broadcastedIndexShape =
getTensorSizes(rewriter, loc, indexTensors.front());
}
// This result shape calculation assumes that there is only one // This result shape calculation assumes that there is only one
// index tensor of the input tensor. The calculation for arbitrary inputs is // index tensor, or all of the index tensors are statically shaped.
// much more complex. int broadcastRank = broadcastedIndexShape.size();
SmallVector<Value> resultShape;
for (auto i : llvm::seq(0, indexTensorDim)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
for (auto i : llvm::seq(0, indexTensorRank)) {
resultShape.push_back(getDimOp(rewriter, loc, indexTensor, i));
}
for (auto i : llvm::seq(indexTensorDim + 1, inputRank)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
int resultRank = resultShape.size();
SmallVector<Value> resultShape;
if (contiguous) {
for (auto i : llvm::seq(0, firstIndexDim)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
resultShape.append(broadcastedIndexShape);
for (auto i : llvm::seq((int)resultShape.size(), resultRank)) {
resultShape.push_back(getDimOp(rewriter, loc, input,
i - broadcastRank + replacedIndexCount));
}
} else {
resultShape.append(broadcastedIndexShape);
int j = 0;
for (auto i : llvm::seq(0, inputRank)) {
if (j < replacedIndexCount && i == indexTensorDims[j]) {
j++;
continue;
}
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
}
// Initialize the indexing maps for the generic op. Because we are assuming
// static shapes for the indexing tensors when there are more than 1, we can
// safely map all size 1 dims to 0 in the corresponding affine maps.
// TODO: For dynamic shapes, we have to either broadcast the index tensors
// to a common shape or introduce some form of control flow.
Value initTensor = Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType); rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
SmallVector<AffineExpr> indicesExpr, resultExpr; SmallVector<AffineMap> indexingMaps;
SmallVector<StringRef> iteratorTypes; SmallVector<StringRef> iteratorTypes;
for (auto i : llvm::seq(indexTensorDim, indexTensorDim + indexTensorRank)) for (auto indexTensor : indexTensors) {
indicesExpr.push_back(rewriter.getAffineDimExpr(i)); RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>();
auto indexTensorShape = indexTensorType.getShape();
int rank = indexTensorShape.size();
SmallVector<AffineExpr> indicesExpr;
for (auto dim : llvm::seq(0, rank)) {
if (indexTensorShape[dim] == 1) {
indicesExpr.push_back(rewriter.getAffineConstantExpr(0));
continue;
}
indicesExpr.push_back(
rewriter.getAffineDimExpr(startIndex + broadcastRank - rank + dim));
}
indexingMaps.push_back(
AffineMap::get(resultRank, 0, indicesExpr, op->getContext()));
}
SmallVector<AffineExpr> resultExpr;
for (auto i : llvm::seq(0, resultRank)) { for (auto i : llvm::seq(0, resultRank)) {
resultExpr.push_back(rewriter.getAffineDimExpr(i)); resultExpr.push_back(rewriter.getAffineDimExpr(i));
iteratorTypes.push_back(getParallelIteratorTypeName()); iteratorTypes.push_back(getParallelIteratorTypeName());
} }
auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr});
indexingMaps.push_back(
AffineMap::get(resultRank, 0, resultExpr, op->getContext()));
Value finalRes = Value finalRes =
rewriter rewriter
.create<linalg::GenericOp>( .create<linalg::GenericOp>(
loc, initTensor.getType(), indexTensor, initTensor, loc, initTensor.getType(), indexTensors, initTensor,
indexingMaps, iteratorTypes, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
Value index = castIntToIndex(b, loc, args[0]);
SmallVector<Value> extractionIndices; SmallVector<Value> extractionIndices;
int extra_dims = 0; if (contiguous) {
for (auto i : llvm::seq(0, inputRank)) { for (auto i : llvm::seq(0, firstIndexDim)) {
if (i == indexTensorDim) {
extractionIndices.push_back(index);
extra_dims += indexTensorRank - 1;
} else {
extractionIndices.push_back( extractionIndices.push_back(
b.create<linalg::IndexOp>(loc, i + extra_dims)); b.create<linalg::IndexOp>(loc, i));
}
for (auto i : llvm::seq(0, (int)indexTensorDims.size())) {
extractionIndices.push_back(
castIntToIndex(b, loc, args[i]));
}
for (auto i :
llvm::seq((int)extractionIndices.size(), inputRank)) {
extractionIndices.push_back(b.create<linalg::IndexOp>(
loc, i + broadcastRank - replacedIndexCount));
}
} else {
int indexCount = 0, unchanged = 0;
for (auto i : llvm::seq(0, inputRank)) {
if (indexCount < replacedIndexCount &&
i == indexTensorDims[indexCount]) {
extractionIndices.push_back(
castIntToIndex(b, loc, args[indexCount++]));
continue;
}
extractionIndices.push_back(b.create<linalg::IndexOp>(
loc, broadcastRank + unchanged));
unchanged++;
} }
} }
Value extractedElement = b.create<tensor::ExtractOp>( Value extractedElement = b.create<tensor::ExtractOp>(

View File

@ -1479,6 +1479,35 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
}); });
} }
//===----------------------------------------------------------------------===//
// AtenAddTOp
//===----------------------------------------------------------------------===//
void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) {
auto lhsListConstruct = op.a().getDefiningOp<Torch::PrimListConstructOp>();
if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct))
return failure();
auto rhsListConstruct = op.b().getDefiningOp<Torch::PrimListConstructOp>();
if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct))
return failure();
SmallVector<Value> concatenatedList;
for (auto a : lhsListConstruct.getOperands()) {
concatenatedList.push_back(a);
}
for (auto b : rhsListConstruct.getOperands()) {
concatenatedList.push_back(b);
}
rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(op, op.getType(),
concatenatedList);
return success();
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenEqIntListOp // AtenEqIntListOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -6590,30 +6590,30 @@ module {
%10 = torch.aten.len.t %arg1 : !torch.list<optional<list<int>>> -> !torch.int %10 = torch.aten.len.t %arg1 : !torch.list<optional<list<int>>> -> !torch.int
%11 = torch.prim.ListConstruct %int9223372036854775807, %10 : (!torch.int, !torch.int) -> !torch.list<int> %11 = torch.prim.ListConstruct %int9223372036854775807, %10 : (!torch.int, !torch.int) -> !torch.list<int>
%12 = torch.prim.min.self_int %11 : !torch.list<int> -> !torch.int %12 = torch.prim.min.self_int %11 : !torch.list<int> -> !torch.int
%13:3 = torch.prim.Loop %12, %true, init(%true, %int-1, %int-1) { %13:2 = torch.prim.Loop %12, %true, init(%true, %int-1) {
^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.int): ^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int):
%16 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<optional<list<int>>>, !torch.int -> !torch.optional<list<int>> %16 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<optional<list<int>>>, !torch.int -> !torch.optional<list<int>>
%17 = torch.aten.__isnot__ %16, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool %17 = torch.aten.__isnot__ %16, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%18:3 = torch.prim.If %17 -> (!torch.bool, !torch.int, !torch.int) { %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {
%19 = torch.aten.eq.int %arg4, %int-1 : !torch.int, !torch.int -> !torch.bool %19 = torch.aten.eq.int %arg4, %int-1 : !torch.int, !torch.int -> !torch.bool
%20:3 = torch.prim.If %19 -> (!torch.bool, !torch.int, !torch.int) { %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {
torch.prim.If.yield %arg3, %arg2, %arg2 : !torch.bool, !torch.int, !torch.int torch.prim.If.yield %arg3, %arg2 : !torch.bool, !torch.int
} else { } else {
%21 = torch.aten.sub.int %arg2, %arg5 : !torch.int, !torch.int -> !torch.int %21 = torch.aten.sub.int %arg2, %arg4 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool %22 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.bool) { %23 = torch.prim.If %22 -> (!torch.bool) {
torch.prim.If.yield %false : !torch.bool torch.prim.If.yield %false : !torch.bool
} else { } else {
torch.prim.If.yield %arg3 : !torch.bool torch.prim.If.yield %arg3 : !torch.bool
} }
torch.prim.If.yield %23, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int torch.prim.If.yield %23, %arg4 : !torch.bool, !torch.int
} }
torch.prim.If.yield %20#0, %20#1, %20#2 : !torch.bool, !torch.int, !torch.int torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int
} else { } else {
torch.prim.If.yield %arg3, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int torch.prim.If.yield %arg3, %arg4 : !torch.bool, !torch.int
} }
torch.prim.Loop.condition %true, iter(%18#0, %18#1, %18#2 : !torch.bool, !torch.int, !torch.int) torch.prim.Loop.condition %true, iter(%18#0, %18#1 : !torch.bool, !torch.int)
} : (!torch.int, !torch.bool, !torch.bool, !torch.int, !torch.int) -> (!torch.bool, !torch.int, !torch.int) } : (!torch.int, !torch.bool, !torch.bool, !torch.int) -> (!torch.bool, !torch.int)
%14 = torch.aten.__not__ %13#0 : !torch.bool -> !torch.bool %14 = torch.aten.__not__ %13#0 : !torch.bool -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.list<int>) { %15 = torch.prim.If %14 -> (!torch.list<int>) {
%16 = torch.aten.add.t %6, %4 : !torch.list<int>, !torch.list<int> -> !torch.list<int> %16 = torch.aten.add.t %6, %4 : !torch.list<int>, !torch.list<int> -> !torch.list<int>

View File

@ -418,6 +418,7 @@ class SimplifyShapeCalculationsPass
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context); Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
AtenSizeOp::getCanonicalizationPatterns(patterns, context); AtenSizeOp::getCanonicalizationPatterns(patterns, context);
AtenLenTOp::getCanonicalizationPatterns(patterns, context); AtenLenTOp::getCanonicalizationPatterns(patterns, context);
AtenAddTOp::getCanonicalizationPatterns(patterns, context);
// TODO: Debug visitation order to make this more efficient. // TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice. // A single linear scan should suffice.

View File

@ -1016,6 +1016,7 @@ def atenpad(self: List[int], pad: List[int], mode: str = "constant", value: O
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value. Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions. Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions. Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions.
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors. Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions. Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions. ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
@ -1037,15 +1038,13 @@ def atenindexTensor(self: List[int], indices: List[Optional[List[int]]]) -
if len(unused_dim_sizes) == 0: if len(unused_dim_sizes) == 0:
return broadcasted_shape return broadcasted_shape
prev_index_tensor_location = -1
first_index_tensor_location = -1 first_index_tensor_location = -1
index_tensors_are_together = True index_tensors_are_together = True
for e, index_tensor_shape in enumerate(indices): for e, index_tensor_shape in enumerate(indices):
if index_tensor_shape is not None: if index_tensor_shape is not None:
if first_index_tensor_location == -1: if first_index_tensor_location == -1:
first_index_tensor_location = e first_index_tensor_location = e
prev_index_tensor_location = e elif e - first_index_tensor_location != 1:
elif e - prev_index_tensor_location != 1:
index_tensors_are_together = False index_tensors_are_together = False
if not index_tensors_are_together: if not index_tensors_are_together:

View File

@ -489,7 +489,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
# List ops. # List ops.
emit("aten::cat : (Tensor[], int) -> (Tensor)") emit("aten::cat : (Tensor[], int) -> (Tensor)")
emit("aten::append.t : (t[], t) -> (t[])") emit("aten::append.t : (t[], t) -> (t[])")
emit("aten::add.t : (t[], t[]) -> (t[])") emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True) emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
emit("aten::list.t : (t[]) -> (t[])") emit("aten::list.t : (t[]) -> (t[])")
emit("aten::slice.t : (t[], int?, int?, int) -> (t[])") emit("aten::slice.t : (t[], int?, int?, int) -> (t[])")

View File

@ -1700,6 +1700,130 @@ def IndexTensorSelectDimModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class IndexTensorMultiInput(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([3, 3], torch.int64, True),
([3], torch.int64, True),
])
def forward(self, x, index1, index2):
return torch.ops.aten.index(x, (index1, index2,))
@register_test_case(module_factory=lambda: IndexTensorMultiInput())
def IndexTensorMultiInput_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3), torch.randint(3, (3, 3)), torch.randint(3, (3,)))
# ==============================================================================
class IndexTensorMultiInputOneDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([6, 1], torch.int64, True),
([3], torch.int64, True),
])
def forward(self, x, index1, index2):
return torch.ops.aten.index(x, (index1, index2,))
@register_test_case(module_factory=lambda: IndexTensorMultiInputOneDim())
def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)), torch.randint(3, (3,)))
# ==============================================================================
class IndexTensorMultiInputNonContiguous(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([4, 2], torch.int64, True),
([4, 2], torch.int64, True),
])
def forward(self, x, index1, index2):
return torch.ops.aten.index(x, (index1, None, index2))
@register_test_case(module_factory=lambda: IndexTensorMultiInputNonContiguous())
def IndexTensorMultiInputNonContiguous_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (4, 2)), torch.randint(1, (4, 2,)))
# ==============================================================================
class IndexTensorMultiInputThreeIndexers(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1, -1, -1], torch.float32, True),
([8, 4, 2], torch.int64, True),
([8, 1, 1], torch.int64, True),
([4, 2], torch.int64, True),
])
def forward(self, x, index1, index2, index3):
return torch.ops.aten.index(x, (None, None, index1, None, index2, index3))
@register_test_case(module_factory=lambda: IndexTensorMultiInputThreeIndexers())
def IndexTensorMultiInputThreeIndexers_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 2, 4, 4, 5, 3),
torch.randint(3, (8, 4, 2,)),
torch.randint(4, (8, 1, 1,)),
torch.randint(2, (4, 2,)))
# ==============================================================================
class IndexTensorMultiInputContiguousCenter(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([2, 2], torch.int64, True),
([2], torch.int64, True),
])
def forward(self, x, index1, index2):
return torch.ops.aten.index(x, (None, index1, index2, None))
@register_test_case(module_factory=lambda: IndexTensorMultiInputContiguousCenter())
def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (2, 2)), torch.randint(2, [2]))
# ==============================================================================
class SquareModule(torch.nn.Module): class SquareModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -635,6 +635,53 @@ func.func @torch.aten.__getitem__.t$invalid_index() -> !torch.int {
return %1 : !torch.int return %1 : !torch.int
} }
// Not canonicalized because of mutated lhs list
// CHECK-LABEL: func.func @torch.aten.add.t$no_canonicalize_lhs_mutated()
func.func @torch.aten.add.t$no_canonicalize_lhs_mutated() -> !torch.list<int> {
%int4 = torch.constant.int 4
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.append.t %0, %int4 : !torch.list<int>, !torch.int -> !torch.list<int>
// CHECK: torch.aten.add.t
%3 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
return %3 : !torch.list<int>
}
// Not canonicalized because of mutated rhs list
// CHECK-LABEL: func.func @torch.aten.add.t$no_canonicalize_rhs_mutated()
func.func @torch.aten.add.t$no_canonicalize_rhs_mutated() -> !torch.list<int> {
%int4 = torch.constant.int 4
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.append.t %1, %int4 : !torch.list<int>, !torch.int -> !torch.list<int>
// CHECK: torch.aten.add.t
%3 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
return %3 : !torch.list<int>
}
// CHECK-LABEL: func.func @torch.aten.add.t$concat(
// CHECK-SAME: %[[ARG0:.*]]: !torch.int,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.list<int> {
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: return %[[LIST]] : !torch.list<int>
func.func @torch.aten.add.t$concat(%arg0: !torch.int, %arg1: !torch.int) -> !torch.list<int> {
%0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
%2 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
return %2 : !torch.list<int>
}
// CHECK-LABEL: func.func @torch.aten.add.t$concat_empty(
// CHECK-SAME: %[[ARG0:.*]]: !torch.int) -> !torch.list<int> {
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG0]] : (!torch.int) -> !torch.list<int>
// CHECK: return %[[LIST]] : !torch.list<int>
func.func @torch.aten.add.t$concat_empty(%arg0: !torch.int) -> !torch.list<int> {
%0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
return %2 : !torch.list<int>
}
// CHECK-LABEL: func.func @torch.aten.eq.int_list$fold$literals_of_different_sizes // CHECK-LABEL: func.func @torch.aten.eq.int_list$fold$literals_of_different_sizes
// CHECK: %[[RET:.*]] = torch.constant.bool false // CHECK: %[[RET:.*]] = torch.constant.bool false
// CHECK: return %[[RET]] : !torch.bool // CHECK: return %[[RET]] : !torch.bool