diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index ddff294ec..c8f960e8a 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -273,6 +273,10 @@ LTC_XFAIL_SET = { "IndexTensorMultiInputThreeIndexers_basic", "IndexTensorMultiInput_basic", "IndexTensorSelectDimModule_basic", + "IndexTensorMultiInputContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousDynamic_basic", + "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", "Matmul_dot", "Matmul_matvec", "MulIntModule_basic", diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index fd81ad5fd..c246593d2 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -583,10 +583,11 @@ public: int replacedIndexCount = indexTensorDims.size(); int64_t startIndex = contiguous ? firstIndexDim : 0; - // Currently we only support statically sized index tensors - // when there is more than one index tensor. - // TODO: Add support for dynamic size index tensors. This will probably - // require broadcasting the index tensors to a common shape. + // Currently we only support statically sized index tensors or dynamic size + // index tensors without overlapping dynamic dims when there is more than + // one index tensor. + // TODO: Add support for dynamic size index tensors with overlapping + // dynamic dims. 
SmallVector<Value> broadcastedIndexShape; if (indexTensors.size() > 1) { int maxRank = -1; @@ -602,12 +603,39 @@ public: for (auto i : llvm::seq(startIndex, startIndex + maxRank)) { auto resultDimSize = refinedResultShape[i]; if (ShapedType::isDynamic(resultDimSize)) { - return rewriter.notifyMatchFailure( - op, "unimplemented: index tensors must have static shape if " - "there is more than one index tensor"); + SmallVector<Value> dynamicDims; + int64_t staticDimSize = -1; + for (auto indexTensor : indexTensors) { + RankedTensorType indexTensorType = + indexTensor.getType().cast<RankedTensorType>(); + int64_t indexTensorRank = indexTensorType.getRank(); + if ((maxRank - indexTensorRank) > (i - startIndex)) + continue; + int64_t dim = i - startIndex - maxRank + indexTensorRank; + if (ShapedType::isDynamic(indexTensorType.getShape()[dim])) + dynamicDims.push_back(getDimOp(rewriter, loc, indexTensor, dim)); + else + staticDimSize = + std::max(staticDimSize, indexTensorType.getShape()[dim]); + } + if (dynamicDims.size() >= 2) + return rewriter.notifyMatchFailure( + op, + "unimplemented: index tensors with overlapping dynamic dims"); + if (staticDimSize > 1) { + Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize, + rewriter.getIndexType()); + auto equalToRunning = rewriter.create<arith::CmpIOp>( + loc, arith::CmpIPredicate::eq, cstStaticDimSize, + dynamicDims[0]); + rewriter.create<cf::AssertOp>(loc, equalToRunning, + "mismatched size for broadcast"); + } + broadcastedIndexShape.push_back(dynamicDims[0]); + } else { + broadcastedIndexShape.push_back(getConstant( + rewriter, loc, resultDimSize, rewriter.getIndexType())); } - broadcastedIndexShape.push_back( - getConstant(rewriter, loc, resultDimSize, rewriter.getIndexType())); } } else { // For a single indexing tensor we can simply use its (dynamic) sizes diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 018977e2c..8a0b05ab2 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ 
b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -1748,6 +1748,125 @@ def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils): # ============================================================================== +class IndexTensorMultiInputContiguousOneDimDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, 1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, ( + None, + index1, + index2, + )) + + +@register_test_case( + module_factory=lambda: IndexTensorMultiInputContiguousOneDimDynamic()) +def IndexTensorMultiInputContiguousOneDimDynamic_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)), + torch.randint(3, (3, ))) + + +# ============================================================================== + + +class IndexTensorMultiInputNonContiguousOneDimDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, 1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, ( + index1, + None, + index2, + )) + + +@register_test_case( + module_factory=lambda: IndexTensorMultiInputNonContiguousOneDimDynamic()) +def IndexTensorMultiInputNonContiguousOneDimDynamic_basic( + module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)), + torch.randint(3, (3, ))) + + +# ============================================================================== + + +class IndexTensorMultiInputNonContiguousDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, 2], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return 
torch.ops.aten.index(x, ( + index2, + None, + index1, + )) + + +@register_test_case( + module_factory=lambda: IndexTensorMultiInputNonContiguousDynamic()) +def IndexTensorMultiInputNonContiguousDynamic_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), torch.randint(2, (6, 2)), + torch.randint(3, (2, ))) + + +# ============================================================================== + + +class IndexTensorMultiInputNonContiguousMultipleStaticDims(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([4, 1], torch.int64, True), + ([1, 3], torch.int64, True), + ([-1, 3], torch.int64, True), + ]) + def forward(self, x, index1, index2, index3): + return torch.ops.aten.index(x, (index1, index2, index3)) + + +@register_test_case(module_factory=lambda: + IndexTensorMultiInputNonContiguousMultipleStaticDims()) +def IndexTensorMultiInputNonContiguousMultipleStaticDims_basic( + module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (4, 1)), + torch.randint(1, (1, 3)), torch.randint(1, (4, 3))) + + +# ============================================================================== + + class IndexTensorMultiInputNonContiguous(torch.nn.Module): def __init__(self): @@ -2591,7 +2710,7 @@ class AtenEmbeddingBagSumExample(torch.nn.Module): @export @annotate_args([ - None, + None, ([-1, -1], torch.float32, True), ([-1], torch.int64, True), ([-1], torch.int64, True), @@ -2613,7 +2732,7 @@ class Aten_EmbeddingBagExample(torch.nn.Module): @export @annotate_args([ - None, + None, ([-1, -1], torch.float32, True), ([-1], torch.int64, True), ([-1], torch.int64, True), @@ -2645,4 +2764,4 @@ class AtenToDeviceModule(torch.nn.Module): @register_test_case(module_factory=lambda: AtenToDeviceModule()) def AtenToDeviceModule_basic(module, tu: TestUtils): - module.forward(torch.randn(2, 4)) \ No newline at end of file + module.forward(torch.randn(2, 4)) diff --git 
a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index dd07e3848..fdb6742b1 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -233,3 +233,38 @@ func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtenso %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16> return %0 : !torch.vtensor<[?,?],f16> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.index.Tensor +// CHECK-SAME: (%[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,1],si64>, %[[ARG2:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32> +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[INDICES:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[NONE]], %[[ARG2]] : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>> +// CHECK: %[[INDEX1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,1],si64> -> tensor<?x1xi64> +// CHECK: %[[INDEX2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[?],si64> -> tensor<?xi64> +// CHECK: %[[CST0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM0:.*]] = tensor.dim %[[INDEX1]], %[[CST0]] : tensor<?x1xi64> +// CHECK: %[[CST0_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM1:.*]] = tensor.dim %[[INDEX2]], %[[CST0_0]] : tensor<?xi64> +// CHECK: %[[CST1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM2:.*]] = tensor.dim %[[T]], %[[CST1]] : tensor<?x?x?xf32> +// CHECK: %[[OUT_T:.*]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]] : tensor<?x?x?xf32> +// CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[INDEX1]], %[[INDEX2]] : tensor<?x1xi64>, tensor<?xi64>) outs(%[[OUT_T]] : tensor<?x?x?xf32>) { +// CHECK: ^bb0(%[[IN1:.*]]: i64, %[[IN2:.*]]: i64, %[[IN3:.*]]: f32): +// CHECK: %[[INDEX_1:.*]] = arith.index_cast 
%[[IN1]] : i64 to index +// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index +// CHECK: %[[INDEX_3:.*]] = arith.index_cast %[[IN2]] : i64 to index +// CHECK: %[[RESULT:.*]] = tensor.extract %[[T]][%[[INDEX_1]], %[[INDEX_2]], %[[INDEX_3]]] : tensor<?x?x?xf32> +// CHECK: linalg.yield %[[RESULT]] : f32 +// CHECK: } -> tensor<?x?x?xf32> +// CHECK: %[[OUT_CAST:.*]] = tensor.cast %[[OUT]] : tensor<?x?x?xf32> to tensor<?x?x?xf32> +// CHECK: %[[VALUE_OUT_CAST:.*]] = torch_c.from_builtin_tensor %[[OUT_CAST]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[VALUE_OUT_CAST]] : !torch.vtensor<[?,?,?],f32> +func.func @torch.aten.index.Tensor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,1],si64>, %arg2: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> { + %none = torch.constant.none + %1 = torch.prim.ListConstruct %arg1, %none, %arg2 : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>> + %2 = torch.aten.index.Tensor %arg0, %1 : !torch.vtensor<[?,?,?],f32>, !torch.list<optional<vtensor>> -> !torch.vtensor<[?,?,?],f32> + return %2 : !torch.vtensor<[?,?,?],f32> +}