[Torch Dialect] decompose all index_put-like ops to aten.index_put.hacked_twin for stricter semantics (#3071)

This PR decomposes all index_put-like ops into aten.index_put.hacked_twin for stricter semantics, i.e., no None index is allowed in the indices argument.
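For illustration, a minimal PyTorch sketch of the two forms (not part of the patch; it assumes a recent PyTorch that exposes the hacked_twin overloads through torch.ops): a `None` entry in `indices` selects a whole dimension implicitly, and the decomposition replaces it with an explicit, broadcast-ready arange index.

import torch

x = torch.zeros(3, 4)
values = torch.tensor([[10., 20.], [30., 40.], [50., 60.]])
col = torch.tensor([1, 3])

# index_put with a None index: dim 0 is selected implicitly.
a = torch.ops.aten.index_put(x, [None, col], values)

# hacked_twin form: every index is a tensor; the None is replaced by an
# explicit arange over dim 0, unsqueezed so it broadcasts against `col`.
row = torch.arange(3).unsqueeze(-1)  # shape [3, 1]
b = torch.ops.aten.index_put.hacked_twin(x, [row, col], values, False)

assert torch.equal(a, b)

In the decomposed form every entry of `indices` is a tensor, which is the stricter contract the backends rely on.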
Jiawei Wu 2024-05-08 22:44:57 +08:00 committed by GitHub
parent abef114c0c
commit 346a536c9f
5 changed files with 241 additions and 210 deletions


@@ -675,12 +675,12 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
}
class ConvertAten_IndexPutImplOp
: public OpConversionPattern<Aten_IndexPutImplOp> {
class ConvertAtenIndexPutHackedTwinOp
: public OpConversionPattern<AtenIndexPutHackedTwinOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(Aten_IndexPutImplOp op, OpAdaptor adaptor,
matchAndRewrite(AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
@@ -699,17 +699,6 @@ public:
return rewriter.notifyMatchFailure(
op, "unimplemented: the values tensor type must have sizes.");
// The unsafe should be either `False` or `None`.
if (!op.getUnsafe().getType().isa<Torch::NoneType>()) {
bool unsafe;
if (!matchPattern(op.getUnsafe(), m_TorchConstantBool(&unsafe)))
return rewriter.notifyMatchFailure(
op, "unimplemented: unsafe must be a constant");
else if (unsafe)
return rewriter.notifyMatchFailure(
op, "unimplemented: unsafe is expected to be false");
}
// The accumulate should be a torch constant of boolean type.
bool accumulate;
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
@@ -1621,8 +1610,8 @@ public:
RewritePatternSet patterns(context);
target.addIllegalOp<AtenBincountOp>();
patterns.add<ConvertAtenBincountOp>(typeConverter, context);
target.addIllegalOp<Aten_IndexPutImplOp>();
patterns.add<ConvertAten_IndexPutImplOp>(typeConverter, context);
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
patterns.add<ConvertAtenIndexPutHackedTwinOp>(typeConverter, context);
target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
context);


@@ -3575,8 +3575,8 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
}
template <>
LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
Aten_IndexPutImplOp op, OpAdaptor adaptor,
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// a = torch.tensor([[0, 1, 2, 3]])
// a[..., 1:] = torch.tensor([4, 5, 6])
@@ -5331,7 +5331,7 @@ public:
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp);
INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenAbsOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);


@@ -5523,23 +5523,6 @@ public:
};
} // namespace
namespace {
// Decompose `aten.index_put` op into `valsem.aten.index_put_impl` op.
class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIndexPutOp op,
PatternRewriter &rewriter) const override {
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
op.getAccumulate(),
/*unsafe=*/cstFalse);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
using OpRewritePattern::OpRewritePattern;
@@ -5635,44 +5618,6 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
};
} // namespace
namespace {
// Decompose `aten.index_put.hacked_twin` op into `valsem.aten.index_put_impl`
// op.
class DecomposeAtenIndexPutHackedTwinOp
: public OpRewritePattern<AtenIndexPutHackedTwinOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIndexPutHackedTwinOp op,
PatternRewriter &rewriter) const override {
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
op.getAccumulate(),
/*unsafe=*/cstFalse);
return success();
}
};
} // namespace
namespace {
// Decompose `aten._unsafe_index_put.hacked_twin` op into `aten._index_put_impl`
// op.
class DecomposeAten_UnsafeIndexPutHackedTwinOp
: public OpRewritePattern<Aten_UnsafeIndexPutHackedTwinOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_UnsafeIndexPutHackedTwinOp op,
PatternRewriter &rewriter) const override {
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
op.getAccumulate(),
/*unsafe=*/cstFalse);
return success();
}
};
} // namespace
namespace {
// Decompose `aten.pad` op into `aten.constant_pad_nd` op.
class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
@@ -7375,65 +7320,138 @@ public:
};
} // namespace
// AtenIndexTensorOp
// Torch ops related to indexing tensors, e.g., AtenIndexTensor, AtenIndexPut.
namespace {
// The goal of this pattern is to eliminate `None` indices in
// aten.index.Tensor's `indices` param for the ease of various backends. The
// detailed steps are:
// 1. reorder the input tensor so that the non-None indices appear at
// adjacent positions.
// 2. manually generate index tensors with some ops like iota to replace the
// None indices in `indices`.
// 3. replace the old aten.index.Tensor with a new
// aten.index.Tensor_hacked_twin.
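// For example (hypothetical static shapes): for x: [3, 4] indexed as x[idx]
// with idx: [2] (i.e., indices = [idx, None]), the trailing `None` becomes
// arange(4), idx is unsqueezed to [2, 1], and the op becomes
// aten.index.Tensor_hacked_twin(x, [idx.unsqueeze(-1), arange(4)]) with
// result shape [2, 4].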
// unsqueeze is more easily optimized than a generic view, and we prefer
// ops with more structure over less in compositions.
static FailureOr<Value> unsqueezeTensorAtTrailingDim(Operation *op,
PatternRewriter &rewriter,
Value input, int count) {
Location loc = op->getLoc();
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-1));
Value result = input;
while (count--) {
auto unsqzTensorInfo =
unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne);
if (failed(unsqzTensorInfo)) {
return failure();
}
result = *unsqzTensorInfo;
}
return result;
}
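// Materialize an explicit index tensor that reproduces a `None` index at
// dim `dimInt`: an arange over that dim's size, with the end value taken
// from AtenSizeIntOp so dynamic dims are handled as well.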
static Value createIndexToReplaceNone(Operation *op, PatternRewriter &rewriter,
Value input, int dimInt,
int64_t dimSize) {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
auto int64Dtype = getDtypeIntValueForType(
rewriter, loc, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
auto resultType = ValueTensorType::get(
context, {dimSize},
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
auto dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimInt));
auto end = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
auto v = rewriter.create<Torch::AtenArangeOp>(
loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);
return v;
}
static FailureOr<Value> createNewIndices(Operation *op,
PatternRewriter &rewriter, Value input,
llvm::ArrayRef<Value> oldIndices,
llvm::ArrayRef<int64_t> newToOldDimMap,
llvm::ArrayRef<bool> oldIndexUsed) {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
auto inputType = input.getType().cast<BaseTensorType>();
if (!inputType.hasSizes()) {
return failure();
}
auto inputSizes = inputType.getSizes();
int64_t inputRank = inputSizes.size();
int64_t maxIndexRank = 0;
for (auto index : oldIndices) {
auto indexType = index.getType().dyn_cast<BaseTensorType>();
if (!indexType) // None index
continue;
if (!indexType.hasSizes())
return failure();
int64_t indexRank = indexType.getSizes().size();
maxIndexRank = maxIndexRank > indexRank ? maxIndexRank : indexRank;
}
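// Broadcast layout of the generated indices: each arange replacing a
// `None` is unsqueezed once per `None` index already emitted, so every
// replacement varies along its own trailing dim; the original indices are
// unsqueezed past all trailing `None`s; aranges for leading `None`s get
// maxIndexRank extra unsqueezes so they also broadcast past the dims
// produced by the original indices.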
// manually generate new indices.
SmallVector<Value> listElements(inputRank);
int64_t noneIndexCnt = 0;
int64_t i;
// handle trailing none indices.
for (i = inputRank - 1; i >= 0; --i) {
int64_t oldI = newToOldDimMap[i];
if (oldIndexUsed[oldI])
break;
Value v = createIndexToReplaceNone(op, rewriter, input, i, inputSizes[i]);
auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, v, noneIndexCnt);
if (failed(vInfo)) {
return failure();
}
listElements[i] = *vInfo;
noneIndexCnt++;
}
// handle non-none index in between.
for (; i >= 0; --i) {
int64_t oldI = newToOldDimMap[i];
if (!oldIndexUsed[oldI])
break;
auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, oldIndices[oldI],
noneIndexCnt);
if (failed(vInfo)) {
return failure();
}
listElements[i] = *vInfo;
}
// handle possible leading none indices.
for (; i >= 0; --i) {
int64_t oldI = newToOldDimMap[i];
if (oldIndexUsed[oldI]) {
return failure();
}
Value v = createIndexToReplaceNone(op, rewriter, input, i, inputSizes[i]);
auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, v,
noneIndexCnt + maxIndexRank);
if (failed(vInfo)) {
return failure();
}
listElements[i] = *vInfo;
noneIndexCnt++;
}
auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr);
Value newIndexList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(listElemType), listElements);
return newIndexList;
}
// The goal of this pattern is to eliminate `None` indices in
// aten.index.Tensor's `indices` param and transform it to
// aten.index.Tensor_hacked_twin, for the ease of various backends.
class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
// TODO: It might be better to use an aten.view op instead of multiple
// aten.unsqueeze ops. But currently, the torch-to-linalg pass has limited
// support for view on dynamic shapes, such as [?] -> [?,1,1,1], so using
// aten.view would cause the relevant e2e tests to fail.
static FailureOr<Value>
unsqueezeTensorAtTrailingDim(Operation *op, PatternRewriter &rewriter,
Value input, int count) {
Location loc = op->getLoc();
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-1));
Value result = input;
while (count--) {
auto unsqzTensorInfo =
unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne);
if (failed(unsqzTensorInfo)) {
return failure();
}
result = *unsqzTensorInfo;
}
return result;
}
static Value createIndexToReplaceNone(Operation *op,
PatternRewriter &rewriter, Value input,
int dimInt, int64_t dimSize) {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
auto int64Dtype = getDtypeIntValueForType(
rewriter, loc,
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
auto resultType = ValueTensorType::get(
context, {dimSize},
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
auto dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimInt));
auto end = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
auto v = rewriter.create<Torch::AtenArangeOp>(
loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);
return v;
}
LogicalResult matchAndRewrite(AtenIndexTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
@@ -7451,12 +7469,6 @@ public:
}
auto inputSizes = inputType.getSizes();
int64_t inputRank = inputSizes.size();
auto outputType = cast<BaseTensorType>(op.getType());
if (!outputType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "only output with shape information is supported");
}
auto outputRank = outputType.getSizes().size();
auto isTensor = [](Value v) {
return v.getType().isa<Torch::BaseTensorType>();
@@ -7464,19 +7476,15 @@ public:
// directly replace aten.index.Tensor with aten.index.Tensor_hacked_twin
if (llvm::all_of(indices, isTensor)) {
if (indices.size() == 0) {
return rewriter.notifyMatchFailure(
op, "the indices is empty, it should be folded as a nop");
}
// By default, we regard the first index type as the list element type.
auto indexElemType = indices[0]
.getType()
.template cast<BaseTensorType>()
.getWithSizesAndDtype(std::nullopt, nullptr);
auto newIndex = rewriter.create<PrimListConstructOp>(
auto newIndices = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(indexElemType), indices);
rewriter.replaceOpWithNewOp<AtenIndexTensorHackedTwinOp>(op, op.getType(),
input, newIndex);
rewriter.replaceOpWithNewOp<AtenIndexTensorHackedTwinOp>(
op, op.getType(), input, newIndices);
return success();
}
@@ -7484,6 +7492,7 @@ public:
llvm::to_vector(llvm::map_range(indices, isTensor));
for (int64_t i = indices.size(); i < inputRank; ++i)
indexUsed.emplace_back(false);
bool indexIsConsecutive = true;
int64_t firstUsedIndex = -1;
for (size_t i = 0; i < indices.size(); ++i) {
@@ -7495,17 +7504,15 @@ public:
}
}
// use aten.permute to reorder the input
Value newInput;
// `dims` stores the mapping from new index to the old index of input
// tensor.
SmallVector<int64_t> dims;
SmallVector<int64_t> newToOldDimMap;
// permute input to make the non-none indices consecutive.
if (!indexIsConsecutive) {
SmallVector<Value> dimValues;
SmallVector<int64_t> permutedSizes;
for (int i = 0; i < inputRank; i++) {
if (indexUsed[i]) {
dims.emplace_back(i);
newToOldDimMap.emplace_back(i);
dimValues.emplace_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
permutedSizes.emplace_back(inputSizes[i]);
@@ -7513,7 +7520,7 @@ public:
}
for (int i = 0; i < inputRank; i++) {
if (!indexUsed[i]) {
dims.emplace_back(i);
newToOldDimMap.emplace_back(i);
dimValues.emplace_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
permutedSizes.emplace_back(inputSizes[i]);
@@ -7529,66 +7536,100 @@ public:
} else {
newInput = input;
for (int i = 0; i < inputRank; i++) {
dims.emplace_back(i);
newToOldDimMap.emplace_back(i);
}
}
// manually generate new indices.
SmallVector<Value> listElements(inputRank);
int64_t trailingDimCnt = 0;
int64_t i;
// handle trailing none index.
for (i = inputRank - 1; i >= 0; --i) {
int64_t oldI = dims[i];
if (indexUsed[oldI])
break;
Value v =
createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]);
auto vInfo =
unsqueezeTensorAtTrailingDim(op, rewriter, v, trailingDimCnt);
if (failed(vInfo)) {
return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
}
listElements[i] = *vInfo;
trailingDimCnt++;
auto newIndicesInfo = createNewIndices(op, rewriter, newInput, indices,
newToOldDimMap, indexUsed);
if (failed(newIndicesInfo)) {
return rewriter.notifyMatchFailure(op, "failed to replace `None` index");
}
// handle non-none index in between.
for (; i >= 0; --i) {
int64_t oldI = dims[i];
if (!indexUsed[oldI])
break;
auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, indices[oldI],
trailingDimCnt);
if (failed(vInfo)) {
return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
}
listElements[i] = *vInfo;
}
// handle possible leading none dimensions.
for (; i >= 0; --i) {
int64_t oldI = dims[i];
if (indexUsed[oldI]) {
return rewriter.notifyMatchFailure(
op, "the indices are still unconsecutive after reordering input "
"tensor");
}
Value v =
createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]);
auto vInfo =
unsqueezeTensorAtTrailingDim(op, rewriter, v, outputRank - 1 - i);
if (failed(vInfo)) {
return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
}
listElements[i] = *vInfo;
}
auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr);
auto newIndexList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(listElemType), listElements);
rewriter.replaceOpWithNewOp<Torch::AtenIndexTensorHackedTwinOp>(
op, op.getType(), newInput, newIndexList);
op, op.getType(), newInput, *newIndicesInfo);
return success();
}
};
// The goal of this pattern is to eliminate `None` indices in
// aten.index_put-like ops' `indices` param and transform them to
// aten.index_put.hacked_twin, for the ease of various backends.
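// For example (hypothetical static shapes): `x[:, col] = values` with
// x: [3, 4], col: [2], values: [3, 2] carries indices = [None, col]; the
// leading `None` becomes arange(3) unsqueezed to [3, 1], yielding
// aten.index_put.hacked_twin(x, [arange(3).unsqueeze(-1), col], values,
// accumulate).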
template <typename AtenIndexPutLikeOpT>
class DecomposeAtenIndexPutLikeOp
: public OpRewritePattern<AtenIndexPutLikeOpT> {
public:
using OpRewritePattern<AtenIndexPutLikeOpT>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIndexPutLikeOpT op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
SmallVector<Value> indices;
if (!getListConstructElements(op.getIndices(), indices))
return rewriter.notifyMatchFailure(op,
"failed to get elements of `indices`");
auto input = op.getSelf();
auto inputType = input.getType().template cast<BaseTensorType>();
if (!inputType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "only input with shape information is supported");
}
auto inputSizes = inputType.getSizes();
int64_t inputRank = inputSizes.size();
auto isTensor = [](Value v) {
return v.getType().isa<Torch::BaseTensorType>();
};
// directly replace current op with aten.index_put.hacked_twin
if (llvm::all_of(indices, isTensor)) {
// By default, we regard the first index type as the list element type.
auto indexElemType = indices[0]
.getType()
.template cast<BaseTensorType>()
.getWithSizesAndDtype(std::nullopt, nullptr);
auto newIndex = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(indexElemType), indices);
rewriter.replaceOpWithNewOp<AtenIndexPutHackedTwinOp>(
op, op.getType(), input, newIndex, op.getValues(),
op.getAccumulate());
return success();
}
SmallVector<bool> indexUsed =
llvm::to_vector(llvm::map_range(indices, isTensor));
for (int64_t i = indices.size(); i < inputRank; ++i)
indexUsed.emplace_back(false);
// check if the non-None indices are consecutive
bool indexIsConsecutive = true;
int64_t firstUsedIndex = -1;
for (size_t i = 0; i < indices.size(); ++i) {
if (indexUsed[i] && firstUsedIndex == -1) {
firstUsedIndex = i;
} else if (indexUsed[i] && !indexUsed[i - 1]) {
indexIsConsecutive = false;
break;
}
}
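// e.g., indices = [idx0, None, idx1] is non-consecutive; unlike the
// aten.index.Tensor pattern above, this pattern does not permute the
// input tensor, so such cases are rejected below.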
if (!indexIsConsecutive) {
return rewriter.notifyMatchFailure(
op, "non consecutive indices is not supported");
}
SmallVector<int64_t> newToOldDimMap;
for (int i = 0; i < inputRank; i++) {
newToOldDimMap.emplace_back(i);
}
auto newIndicesInfo = createNewIndices(op, rewriter, input, indices,
newToOldDimMap, indexUsed);
if (failed(newIndicesInfo)) {
return rewriter.notifyMatchFailure(op, "failed to replace `None` index");
}
rewriter.replaceOpWithNewOp<AtenIndexPutHackedTwinOp>(
op, op.getType(), input, *newIndicesInfo, op.getValues(),
op.getAccumulate());
return success();
}
};
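// This template is instantiated below for aten.index_put,
// aten._unsafe_index_put.hacked_twin, and aten._index_put_impl, replacing
// the three separate decompose-to-_index_put_impl patterns removed above.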
@@ -7881,16 +7922,19 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewFullOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenDropoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DeomposeAtenNativeDropoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutLikeOp<AtenIndexPutOp>>(
patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenIndexPutLikeOp<Aten_UnsafeIndexPutHackedTwinOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenIndexPutLikeOp<Aten_IndexPutImplOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
@@ -7956,7 +8000,6 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
// More specific conv ops


@@ -466,13 +466,14 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenMishOp>();
target.addIllegalOp<AtenFullLikeOp>();
target.addIllegalOp<AtenNewFullOp>();
target.addIllegalOp<AtenIndexPutOp>();
target.addIllegalOp<AtenExpandAsOp>();
target.addIllegalOp<Aten_ToCopyOp>();
target.addIllegalOp<AtenDropoutOp>();
target.addIllegalOp<AtenNativeDropoutOp>();
target.addIllegalOp<AtenNewEmptyOp>();
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
target.addIllegalOp<AtenIndexTensorOp>();
target.addIllegalOp<AtenIndexPutOp>();
target.addIllegalOp<Aten_IndexPutImplOp>();
target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
target.addIllegalOp<AtenPadOp>();
target.addIllegalOp<AtenPreluOp>();
@@ -500,7 +501,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<Aten_EmbeddingBagOp>();
target.addIllegalOp<AtenLiftFreshCopyOp>();
target.addIllegalOp<AtenLerpScalarOp>();
target.addIllegalOp<AtenIndexTensorOp>();
target.addIllegalOp<AtenMseLossOp>();
target.addIllegalOp<AtenRandintLowOp>();
target.addIllegalOp<AtenRandintOp>();


@@ -1704,7 +1704,6 @@ TOSA_PASS_SET = {
"HardswishModule_basic",
"HardswishRandomModule_basic",
"HardtanhBackward_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorStaticModule_basic",
"IscloseStaticModuleTrue_basic",