diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 385c5e6ec..60f3f3422 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3797,13 +3797,126 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
       indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
     }
 
-    // Right now only support multiple indexes with same shape
-    // TODO for different shape multiple indexes, add broadcast_to for small
-    // shape
+    auto getRankExtendedShape =
+        [](SmallVector<int64_t> inputShape,
+           SmallVector<int64_t> maxRank1DimShape) -> SmallVector<int64_t> {
+      SmallVector<int64_t> rankExtendedShape(maxRank1DimShape);
+      auto inputRank = inputShape.size();
+      auto maxRank = maxRank1DimShape.size();
+      auto startIdx = maxRank - inputRank;
+      for (size_t i = startIdx; i < maxRank; i++) {
+        rankExtendedShape[i] = inputShape[i - startIdx];
+      }
+      return rankExtendedShape;
+    };
+
+    bool hasDiffShapedIndexes = false;
     for (auto indexShapeOneDim : indexesShape) {
       if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: Only support multi indexes with same shape");
+        hasDiffShapedIndexes = true;
+        break;
+      }
+    }
+
+    if (hasDiffShapedIndexes) {
+      int64_t maxRank = 1;
+      for (auto idxRank : indexesRank) {
+        if (idxRank > maxRank)
+          maxRank = idxRank;
+      }
+      // Tensor shape of max rank, each dim being 1
+      SmallVector<int64_t> maxRank1DimShape;
+      for (int i = 0; i < maxRank; i++)
+        maxRank1DimShape.push_back(1);
+      // Tensor shape of max rank, each dim being the max dim.
+      SmallVector<int64_t> maxRankMaxDimShape(maxRank1DimShape);
+
+      auto updateMaxRankMaxDimShape =
+          [&](SmallVector<int64_t> broadcastedShape) -> LogicalResult {
+        for (size_t i = 0; i < maxRankMaxDimShape.size(); i++) {
+          // check for malformed index tensors
+          if (broadcastedShape[i] != 1 && maxRankMaxDimShape[i] != 1 &&
+              maxRankMaxDimShape[i] != broadcastedShape[i]) {
+            return failure();
+          }
+          if (broadcastedShape[i] > maxRankMaxDimShape[i])
+            maxRankMaxDimShape[i] = broadcastedShape[i];
+        }
+        return success();
+      };
+
+      for (size_t i = 0; i < indexesRank.size(); i++) {
+        // Reshape all index tensors to same maxRank
+        auto idxRank = indexesRank[i];
+        auto unreshapedIdxTensor = indicesTfConcatTensors[i];
+        SmallVector<int64_t> broadcastedShape =
+            getRankExtendedShape(indexesShape[i], maxRank1DimShape);
+
+        if (idxRank < maxRank) {
+          auto idxType =
+              dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
+          // indicesTfConcatTensors has a trailing [1] dim for the final concat.
+          auto broadcastedShapeTf(broadcastedShape);
+          broadcastedShapeTf.push_back(1);
+          auto reshapeOutputTy = RankedTensorType::get(
+              broadcastedShapeTf, idxType.getElementType());
+          // Update the tensor array with the max rank-extended form
+          indicesTfConcatTensors[i] = rewriter.create<tosa::ReshapeOp>(
+              op->getLoc(), reshapeOutputTy, unreshapedIdxTensor,
+              rewriter.getDenseI64ArrayAttr(broadcastedShapeTf));
+        }
+
+        // Construct the max rank broadcasted form of all index tensors with
+        // each index tensor.
+        if (updateMaxRankMaxDimShape(broadcastedShape).failed()) {
+          return rewriter.notifyMatchFailure(
+              op, "Malformed index tensors that have mismatched dim shapes");
+        }
+
+        // Every index now has the same rank but not yet same shape until
+        // tosa.tile below.
+        indexesShape[i] = broadcastedShape;
+        indexesRank[i] = maxRank;
+      }
+
+      auto getTileOpShape = [&](SmallVector<int64_t> indexShape,
+                                SmallVector<int64_t> &tileOpShape) -> bool {
+        bool needsTiling = false;
+        for (size_t i = 0; i < indexShape.size(); i++) {
+          if (1 == indexShape[i]) {
+            tileOpShape.push_back(maxRankMaxDimShape[i]);
+            needsTiling = true;
+          } else {
+            tileOpShape.push_back(1);
+          }
+        }
+        return needsTiling;
+      };
+
+      // Use tosa.tile to broadcast in multiple dims so all index tensors have
+      // the same shape. This materializes new tensors.
+      for (size_t i = 0; i < indexesRank.size(); i++) {
+        SmallVector<int64_t> tileOpShape;
+        bool needsTiling = getTileOpShape(indexesShape[i], tileOpShape);
+
+        if (needsTiling) {
+          auto idxType =
+              dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
+          // indicesTfConcatTensors has a trailing [1] dim for the final concat.
+          auto maxRankMaxDimShapeTf(maxRankMaxDimShape);
+          maxRankMaxDimShapeTf.push_back(1);
+          auto tileOpShapeTf(tileOpShape);
+          tileOpShapeTf.push_back(1);
+          auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf,
+                                                    idxType.getElementType());
+          auto reshapedIdxTensor = indicesTfConcatTensors[i];
+          indicesTfConcatTensors[i] = rewriter.create<tosa::TileOp>(
+              op->getLoc(), tileOutputTy, reshapedIdxTensor,
+              rewriter.getDenseI64ArrayAttr(tileOpShapeTf));
+        }
+
+        // Every index tensor now has the same rank and shape
+        indexesShape[i] = maxRankMaxDimShape;
       }
     }
 
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e2ad3310e..fb215a303 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -30,6 +30,7 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
     # this is added to check the torch.onnx.export -> import_onnx -> torch path
     "DeformConv2D_basic",
     "ReduceAnyDimFloatModule_basic",
+    "UnfoldModule_basic",
 }
 
 LINALG_CRASHING_SET = {
@@ -1983,6 +1984,8 @@ TOSA_PASS_SET = {
     "TorchPrimLoopForLikeTensorArgModule_basic",
     "RenormModuleFloat32NegativeDim_basic",
     "RenormModuleFloat32_basic",
+    "IndexTensorStaticContiguousWithNoneModule_basic",
+    "IndexTensorStaticNonContiguousWithNoneModule_basic",
 }
 
 MAKE_FX_TOSA_PASS_SET = (
@@ -2750,6 +2753,7 @@ ONNX_XFAIL_SET = {
     "ReduceAnyFloatModule_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
+    "UnfoldModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3189,7 +3193,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "IndexSelectWholeTensorModule_basic",
     "IndexTensorDyanmicInputContiguousWithNoneModule_basic",
     "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
-    "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
     "IndexTensorMultiInputContiguousCenter_basic",
     "IndexTensorMultiInputContiguousOneDimDynamic_basic",
     "IndexTensorMultiInputNonContiguousDynamic_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 552f51af1..082223631 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -5646,3 +5646,27 @@ def AtenKthvalueFloat64DynamicDimsModule_basic(module, tu: TestUtils):
     module.forward(
         torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
     )
+
+
+# ==============================================================================
+
+
+class UnfoldModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.unfold = torch.nn.Unfold(kernel_size=(2, 3))
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, input):
+        return self.unfold(input)
+
+
+@register_test_case(module_factory=lambda: UnfoldModule())
+def UnfoldModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 3, 4))
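
Note (illustrative sketch, not part of the patch): the reshape-then-tile sequence added in TorchToTosa.cpp mirrors PyTorch's advanced-indexing semantics, where differently shaped index tensors are broadcast against each other before gathering. The tensor names and shapes below are arbitrary.

    import torch

    x = torch.rand(4, 5, 6)
    rows = torch.tensor([[0], [2], [3]])  # shape (3, 1)
    cols = torch.tensor([1, 3, 4])        # shape (3,)
    # rows and cols broadcast to a common shape (3, 3); the result gathers
    # x[rows[i, 0], cols[j], :] and has shape (3, 3, 6).
    out = x[rows, cols]
    assert out.shape == (3, 3, 6)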