diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 3cf821944..40367138b 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -675,12 +675,12 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch, return b.create(loc, valuesTy, values, outDimsList); } -class ConvertAten_IndexPutImplOp - : public OpConversionPattern { +class ConvertAtenIndexPutHackedTwinOp + : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(Aten_IndexPutImplOp op, OpAdaptor adaptor, + matchAndRewrite(AtenIndexPutHackedTwinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); @@ -699,17 +699,6 @@ public: return rewriter.notifyMatchFailure( op, "unimplemented: the values tensor type must have sizes."); - // The unsafe should be either `False` or `none`. - if (!op.getUnsafe().getType().isa()) { - bool unsafe; - if (!matchPattern(op.getUnsafe(), m_TorchConstantBool(&unsafe))) - return rewriter.notifyMatchFailure( - op, "unimplemented: unsafe must be a constant"); - else if (unsafe) - return rewriter.notifyMatchFailure( - op, "unimplemented: unsafe is expected to be false"); - } - // The accumulate should be a torch constant of boolean type. bool accumulate; if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) @@ -1621,8 +1610,8 @@ public: RewritePatternSet patterns(context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 010e7fce0..168515051 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3575,8 +3575,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - Aten_IndexPutImplOp op, OpAdaptor adaptor, +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIndexPutHackedTwinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // a = torch.tensor([[0, 1, 2, 3]]) // a[..., 1:] = torch.tensor([4, 5, 6]) @@ -5331,7 +5331,7 @@ public: INSERT_ATENOP_PATTERN(AtenSliceTensorOp); INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenGatherOp); - INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp); + INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); INSERT_ATENOP_PATTERN(AtenAbsOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 73a67e14b..9fad15e13 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5523,23 +5523,6 @@ public: }; } // namespace -namespace { -// Decompose `aten.indexPut` op into `valsem.aten.indexPutImpl` op. -class DecomposeAtenIndexPutOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenIndexPutOp op, - PatternRewriter &rewriter) const override { - Value cstFalse = rewriter.create(op.getLoc(), false); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), - op.getAccumulate(), - /*unsafe=*/cstFalse); - return success(); - } -}; -} // namespace - namespace { class DecomposeAtenExpandAsOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -5635,44 +5618,6 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern { }; } // namespace -namespace { -// Decompose `aten.indexPut.hackedTwin` op into `valsem.aten.indexPutImpl` -// op. -class DecomposeAtenIndexPutHackedTwinOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenIndexPutHackedTwinOp op, - PatternRewriter &rewriter) const override { - Value cstFalse = rewriter.create(op.getLoc(), false); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), - op.getAccumulate(), - /*unsafe=*/cstFalse); - return success(); - } -}; -} // namespace - -namespace { -// Decompose `aten._unsafe_indexPut.hackedTwin` op into `aten._index_put_impl` -// op. -class DecomposeAten_UnsafeIndexPutHackedTwinOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(Aten_UnsafeIndexPutHackedTwinOp op, - PatternRewriter &rewriter) const override { - Value cstFalse = rewriter.create(op.getLoc(), false); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), - op.getAccumulate(), - /*unsafe=*/cstFalse); - return success(); - } -}; -} // namespace - namespace { // Decompose `aten.pad` op into `aten.constantPadNd` op. class DecomposeAtenPadOp : public OpRewritePattern { @@ -7375,65 +7320,138 @@ public: }; } // namespace -// AtenIndexTensorOp +// Torch ops related to indexing tensors, e.g., AtenIndexTensor, AtenIndexPut. namespace { -// The goal of this pattern is to eliminate none index in aten.Index.Tensor's -// `indices` param for the ease of various backend. The detailed steps are: -// 1. reorder input tensor so that the non-none index appears at adjacent -// positions. -// 2. manually generate index tensor with some ops like iota, to replace the -// none index in `indices` -// 3. replace the old aten.Index.Tensor with a new -// aten.Index.Tensor_hacked_twin. + +// unsqueeze is more easily optimized than a generic view, and we prefer to +// enjoy ops with more structure than less in compositions. +static FailureOr unsqueezeTensorAtTrailingDim(Operation *op, + PatternRewriter &rewriter, + Value input, int count) { + Location loc = op->getLoc(); + Value constMinusOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(-1)); + Value result = input; + while (count--) { + auto unsqzTensorInfo = + unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne); + if (failed(unsqzTensorInfo)) { + return failure(); + } + + result = *unsqzTensorInfo; + } + return result; +} + +static Value createIndexToReplaceNone(Operation *op, PatternRewriter &rewriter, + Value input, int dimInt, + int64_t dimSize) { + Location loc = op->getLoc(); + MLIRContext *context = op->getContext(); + Value none = rewriter.create(loc); + auto int64Dtype = getDtypeIntValueForType( + rewriter, loc, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); + + auto resultType = ValueTensorType::get( + context, {dimSize}, + rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); + auto dim = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimInt)); + auto end = rewriter.create(loc, input, dim); + auto v = rewriter.create( + loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); + return v; +} + +static FailureOr createNewIndices(Operation *op, + PatternRewriter &rewriter, Value input, + llvm::ArrayRef oldIndices, + llvm::ArrayRef newToOldDimMap, + llvm::ArrayRef oldIndexUsed) { + Location loc = op->getLoc(); + MLIRContext *context = op->getContext(); + + auto inputType = input.getType().cast(); + if (!inputType.hasSizes()) { + return failure(); + } + auto inputSizes = inputType.getSizes(); + int64_t inputRank = inputSizes.size(); + + int64_t maxIndexRank = 0; + for (auto index : oldIndices) { + auto indexType = index.getType().dyn_cast(); + if (!indexType) // None index + continue; + if (!indexType.hasSizes()) + return failure(); + int64_t indexRank = indexType.getSizes().size(); + maxIndexRank = maxIndexRank > indexRank ? maxIndexRank : indexRank; + } + + // manually generate new indices. + SmallVector listElements(inputRank); + + int64_t noneIndexCnt = 0; + int64_t i; + // handle trailing none indices. + for (i = inputRank - 1; i >= 0; --i) { + int64_t oldI = newToOldDimMap[i]; + if (oldIndexUsed[oldI]) + break; + Value v = createIndexToReplaceNone(op, rewriter, input, i, inputSizes[i]); + auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, v, noneIndexCnt); + if (failed(vInfo)) { + return failure(); + } + listElements[i] = *vInfo; + noneIndexCnt++; + } + // handle non-none index in between. + for (; i >= 0; --i) { + int64_t oldI = newToOldDimMap[i]; + if (!oldIndexUsed[oldI]) + break; + auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, oldIndices[oldI], + noneIndexCnt); + if (failed(vInfo)) { + return failure(); + } + listElements[i] = *vInfo; + } + + // handle possible leading none indices. + for (; i >= 0; --i) { + int64_t oldI = newToOldDimMap[i]; + if (oldIndexUsed[oldI]) { + return failure(); + } + Value v = createIndexToReplaceNone(op, rewriter, input, i, inputSizes[i]); + auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, v, + noneIndexCnt + maxIndexRank); + if (failed(vInfo)) { + return failure(); + } + listElements[i] = *vInfo; + noneIndexCnt++; + } + + auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr); + Value newIndexList = rewriter.create( + loc, Torch::ListType::get(listElemType), listElements); + + return newIndexList; +} + +// The goal of this pattern is to eliminate `None` index in aten.Index.Tensor's +// `indices` param and transform it to aten.index.Tensor_hacked_twin, for the +// ease of various backend. class DecomposeAtenIndexTensorOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - // TODO: It might be better to use aten.view op instead of mulitple - // aten.unsqueeze. But currently, torch-to-linalg pass has limited support for - // view on dynamic shapes, such as [?] -> [?,1,1,1]. Using aten.view op will - // cause relevant e2e tests fail. - static FailureOr - unsqueezeTensorAtTrailingDim(Operation *op, PatternRewriter &rewriter, - Value input, int count) { - Location loc = op->getLoc(); - Value constMinusOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(-1)); - Value result = input; - while (count--) { - auto unsqzTensorInfo = - unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne); - if (failed(unsqzTensorInfo)) { - return failure(); - } - - result = *unsqzTensorInfo; - } - return result; - } - - static Value createIndexToReplaceNone(Operation *op, - PatternRewriter &rewriter, Value input, - int dimInt, int64_t dimSize) { - Location loc = op->getLoc(); - MLIRContext *context = op->getContext(); - Value none = rewriter.create(loc); - auto int64Dtype = getDtypeIntValueForType( - rewriter, loc, - rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); - - auto resultType = ValueTensorType::get( - context, {dimSize}, - rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); - auto dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimInt)); - auto end = rewriter.create(loc, input, dim); - auto v = rewriter.create( - loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); - return v; - } - LogicalResult matchAndRewrite(AtenIndexTensorOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); @@ -7451,12 +7469,6 @@ public: } auto inputSizes = inputType.getSizes(); int64_t inputRank = inputSizes.size(); - auto outputType = cast(op.getType()); - if (!outputType.hasSizes()) { - return rewriter.notifyMatchFailure( - op, "only output with shape information is supported"); - } - auto outputRank = outputType.getSizes().size(); auto isTensor = [](Value v) { return v.getType().isa(); @@ -7464,19 +7476,15 @@ public: // directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin if (llvm::all_of(indices, isTensor)) { - if (indices.size() == 0) { - return rewriter.notifyMatchFailure( - op, "the indices is empty, it should be folded as a nop"); - } // By default, we regard the first index type as the list element type. auto indexElemType = indices[0] .getType() .template cast() .getWithSizesAndDtype(std::nullopt, nullptr); - auto newIndex = rewriter.create( + auto newIndices = rewriter.create( loc, Torch::ListType::get(indexElemType), indices); - rewriter.replaceOpWithNewOp(op, op.getType(), - input, newIndex); + rewriter.replaceOpWithNewOp( + op, op.getType(), input, newIndices); return success(); } @@ -7484,6 +7492,7 @@ public: llvm::to_vector(llvm::map_range(indices, isTensor)); for (int64_t i = indices.size(); i < inputRank; ++i) indexUsed.emplace_back(false); + bool indexIsConsecutive = true; int64_t firstUsedIndex = -1; for (size_t i = 0; i < indices.size(); ++i) { @@ -7495,17 +7504,15 @@ public: } } - // use aten.permute to reorder the input Value newInput; - // `dims` stores the mapping from new index to the old index of input - // tensor. - SmallVector dims; + SmallVector newToOldDimMap; + // permute input to make the non-none indices consecutive. if (!indexIsConsecutive) { SmallVector dimValues; SmallVector permutedSizes; for (int i = 0; i < inputRank; i++) { if (indexUsed[i]) { - dims.emplace_back(i); + newToOldDimMap.emplace_back(i); dimValues.emplace_back(rewriter.create( loc, rewriter.getI64IntegerAttr(i))); permutedSizes.emplace_back(inputSizes[i]); @@ -7513,7 +7520,7 @@ public: } for (int i = 0; i < inputRank; i++) { if (!indexUsed[i]) { - dims.emplace_back(i); + newToOldDimMap.emplace_back(i); dimValues.emplace_back(rewriter.create( loc, rewriter.getI64IntegerAttr(i))); permutedSizes.emplace_back(inputSizes[i]); @@ -7529,66 +7536,100 @@ public: } else { newInput = input; for (int i = 0; i < inputRank; i++) { - dims.emplace_back(i); + newToOldDimMap.emplace_back(i); } } - // manually generate new indices. - SmallVector listElements(inputRank); - - int64_t trailingDimCnt = 0; - int64_t i; - // handle trailing none index. - for (i = inputRank - 1; i >= 0; --i) { - int64_t oldI = dims[i]; - if (indexUsed[oldI]) - break; - Value v = - createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]); - auto vInfo = - unsqueezeTensorAtTrailingDim(op, rewriter, v, trailingDimCnt); - if (failed(vInfo)) { - return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor"); - } - listElements[i] = *vInfo; - trailingDimCnt++; + auto newIndeicesInfo = createNewIndices(op, rewriter, newInput, indices, + newToOldDimMap, indexUsed); + if (failed(newIndeicesInfo)) { + return rewriter.notifyMatchFailure(op, "failed to replcae `None` index"); } - // handle non-none index in between. - for (; i >= 0; --i) { - int64_t oldI = dims[i]; - if (!indexUsed[oldI]) - break; - auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, indices[oldI], - trailingDimCnt); - if (failed(vInfo)) { - return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor"); - } - listElements[i] = *vInfo; - } - - // handle possible leading none dimensions. - for (; i >= 0; --i) { - int64_t oldI = dims[i]; - if (indexUsed[oldI]) { - return rewriter.notifyMatchFailure( - op, "the indices are still unconsecutive after reordering input " - "tensor"); - } - Value v = - createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]); - auto vInfo = - unsqueezeTensorAtTrailingDim(op, rewriter, v, outputRank - 1 - i); - if (failed(vInfo)) { - return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor"); - } - listElements[i] = *vInfo; - } - - auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr); - auto newIndexList = rewriter.create( - loc, Torch::ListType::get(listElemType), listElements); rewriter.replaceOpWithNewOp( - op, op.getType(), newInput, newIndexList); + op, op.getType(), newInput, *newIndeicesInfo); + return success(); + } +}; + +// The goal of this pattern is to eliminate `None` index in aten.inde_put-like +// ops' `indices` param and transform it to aten.index_put.hacked_twin, for the +// ease of various backend. +template +class DecomposeAtenIndexPutLikeOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AtenIndexPutLikeOpT op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + SmallVector indices; + if (!getListConstructElements(op.getIndices(), indices)) + return rewriter.notifyMatchFailure(op, + "failed to get elements of `indices`"); + + auto input = op.getSelf(); + auto inputType = input.getType().template cast(); + if (!inputType.hasSizes()) { + return rewriter.notifyMatchFailure( + op, "only input with shape information is supported"); + } + auto inputSizes = inputType.getSizes(); + int64_t inputRank = inputSizes.size(); + + auto isTensor = [](Value v) { + return v.getType().isa(); + }; + + // directly replace current op with aten.index_put.hacked_twin + if (llvm::all_of(indices, isTensor)) { + // By default, we regard the first index type as the list element type. + auto indexElemType = indices[0] + .getType() + .template cast() + .getWithSizesAndDtype(std::nullopt, nullptr); + auto newIndex = rewriter.create( + loc, Torch::ListType::get(indexElemType), indices); + rewriter.replaceOpWithNewOp( + op, op.getType(), input, newIndex, op.getValues(), + op.getAccumulate()); + return success(); + } + + SmallVector indexUsed = + llvm::to_vector(llvm::map_range(indices, isTensor)); + for (int64_t i = indices.size(); i < inputRank; ++i) + indexUsed.emplace_back(false); + + // check if non-None index is consecutive + bool indexIsConsecutive = true; + int64_t firstUsedIndex = -1; + for (size_t i = 0; i < indices.size(); ++i) { + if (indexUsed[i] && firstUsedIndex == -1) { + firstUsedIndex = i; + } else if (indexUsed[i] && !indexUsed[i - 1]) { + indexIsConsecutive = false; + break; + } + } + if (!indexIsConsecutive) { + return rewriter.notifyMatchFailure( + op, "non consecutive indices is not supported"); + } + + SmallVector newToOldDimMap; + for (int i = 0; i < inputRank; i++) { + newToOldDimMap.emplace_back(i); + } + + auto newIndicesInfo = createNewIndices(op, rewriter, input, indices, + newToOldDimMap, indexUsed); + if (failed(newIndicesInfo)) { + return rewriter.notifyMatchFailure(op, "failed to replace `None` index"); + } + rewriter.replaceOpWithNewOp( + op, op.getType(), input, *newIndicesInfo, op.getValues(), + op.getAccumulate()); return success(); } }; @@ -7881,16 +7922,19 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal( + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal>( patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenIndexPutLikeOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenIndexPutLikeOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -7956,7 +8000,6 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); // More specific conv ops diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 36dc40711..3981cff44 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -466,13 +466,14 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -500,7 +501,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c3a300f36..26ff502e4 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1704,7 +1704,6 @@ TOSA_PASS_SET = { "HardswishModule_basic", "HardswishRandomModule_basic", "HardtanhBackward_basic", - "IndexPutImpl2DNoneIndexStaticModule_basic", "IndexTensorMultiIndexStaticModule_basic", "IndexTensorStaticModule_basic", "IscloseStaticModuleTrue_basic",