mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] decompose all index_put-like op to aten.index_put.hacked_twin for stricter semantics (#3071)
This PR decomposes all index_put-like op to aten.index_put.hacked_twin for stricter semantics, i.e., no None index in indices argument.pull/3300/head
parent
abef114c0c
commit
346a536c9f
|
@ -675,12 +675,12 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
|
|||
return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
|
||||
}
|
||||
|
||||
class ConvertAten_IndexPutImplOp
|
||||
: public OpConversionPattern<Aten_IndexPutImplOp> {
|
||||
class ConvertAtenIndexPutHackedTwinOp
|
||||
: public OpConversionPattern<AtenIndexPutHackedTwinOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(Aten_IndexPutImplOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
@ -699,17 +699,6 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: the values tensor type must have sizes.");
|
||||
|
||||
// The unsafe should be either `False` or `none`.
|
||||
if (!op.getUnsafe().getType().isa<Torch::NoneType>()) {
|
||||
bool unsafe;
|
||||
if (!matchPattern(op.getUnsafe(), m_TorchConstantBool(&unsafe)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: unsafe must be a constant");
|
||||
else if (unsafe)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: unsafe is expected to be false");
|
||||
}
|
||||
|
||||
// The accumulate should be a torch constant of boolean type.
|
||||
bool accumulate;
|
||||
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
|
||||
|
@ -1621,8 +1610,8 @@ public:
|
|||
RewritePatternSet patterns(context);
|
||||
target.addIllegalOp<AtenBincountOp>();
|
||||
patterns.add<ConvertAtenBincountOp>(typeConverter, context);
|
||||
target.addIllegalOp<Aten_IndexPutImplOp>();
|
||||
patterns.add<ConvertAten_IndexPutImplOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
|
||||
patterns.add<ConvertAtenIndexPutHackedTwinOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
|
||||
patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
|
||||
context);
|
||||
|
|
|
@ -3575,8 +3575,8 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
|
||||
Aten_IndexPutImplOp op, OpAdaptor adaptor,
|
||||
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
||||
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// a = torch.tensor([[0, 1, 2, 3]])
|
||||
// a[..., 1:] = torch.tensor([4, 5, 6])
|
||||
|
@ -5331,7 +5331,7 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
||||
INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
|
||||
INSERT_ATENOP_PATTERN(AtenAbsOp);
|
||||
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
||||
|
|
|
@ -5523,23 +5523,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.indexPut` op into `valsem.aten.indexPutImpl` op.
|
||||
class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenIndexPutOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
||||
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
|
||||
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
|
||||
op.getAccumulate(),
|
||||
/*unsafe=*/cstFalse);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
@ -5635,44 +5618,6 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.indexPut.hackedTwin` op into `valsem.aten.indexPutImpl`
|
||||
// op.
|
||||
class DecomposeAtenIndexPutHackedTwinOp
|
||||
: public OpRewritePattern<AtenIndexPutHackedTwinOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenIndexPutHackedTwinOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
||||
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
|
||||
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
|
||||
op.getAccumulate(),
|
||||
/*unsafe=*/cstFalse);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten._unsafe_indexPut.hackedTwin` op into `aten._index_put_impl`
|
||||
// op.
|
||||
class DecomposeAten_UnsafeIndexPutHackedTwinOp
|
||||
: public OpRewritePattern<Aten_UnsafeIndexPutHackedTwinOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_UnsafeIndexPutHackedTwinOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
||||
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
|
||||
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
|
||||
op.getAccumulate(),
|
||||
/*unsafe=*/cstFalse);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.pad` op into `aten.constantPadNd` op.
|
||||
class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
|
||||
|
@ -7375,65 +7320,138 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// AtenIndexTensorOp
|
||||
// Torch ops related to indexing tensors, e.g., AtenIndexTensor, AtenIndexPut.
|
||||
namespace {
|
||||
// The goal of this pattern is to eliminate none index in aten.Index.Tensor's
|
||||
// `indices` param for the ease of various backend. The detailed steps are:
|
||||
// 1. reorder input tensor so that the non-none index appears at adjacent
|
||||
// positions.
|
||||
// 2. manually generate index tensor with some ops like iota, to replace the
|
||||
// none index in `indices`
|
||||
// 3. replace the old aten.Index.Tensor with a new
|
||||
// aten.Index.Tensor_hacked_twin.
|
||||
|
||||
// unsqueeze is more easily optimized than a generic view, and we prefer to
|
||||
// enjoy ops with more structure than less in compositions.
|
||||
static FailureOr<Value> unsqueezeTensorAtTrailingDim(Operation *op,
|
||||
PatternRewriter &rewriter,
|
||||
Value input, int count) {
|
||||
Location loc = op->getLoc();
|
||||
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(-1));
|
||||
Value result = input;
|
||||
while (count--) {
|
||||
auto unsqzTensorInfo =
|
||||
unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne);
|
||||
if (failed(unsqzTensorInfo)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
result = *unsqzTensorInfo;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static Value createIndexToReplaceNone(Operation *op, PatternRewriter &rewriter,
|
||||
Value input, int dimInt,
|
||||
int64_t dimSize) {
|
||||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
auto int64Dtype = getDtypeIntValueForType(
|
||||
rewriter, loc, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
||||
|
||||
auto resultType = ValueTensorType::get(
|
||||
context, {dimSize},
|
||||
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
||||
auto dim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(dimInt));
|
||||
auto end = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
|
||||
auto v = rewriter.create<Torch::AtenArangeOp>(
|
||||
loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memory=*/none);
|
||||
return v;
|
||||
}
|
||||
|
||||
static FailureOr<Value> createNewIndices(Operation *op,
|
||||
PatternRewriter &rewriter, Value input,
|
||||
llvm::ArrayRef<Value> oldIndices,
|
||||
llvm::ArrayRef<int64_t> newToOldDimMap,
|
||||
llvm::ArrayRef<bool> oldIndexUsed) {
|
||||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
if (!inputType.hasSizes()) {
|
||||
return failure();
|
||||
}
|
||||
auto inputSizes = inputType.getSizes();
|
||||
int64_t inputRank = inputSizes.size();
|
||||
|
||||
int64_t maxIndexRank = 0;
|
||||
for (auto index : oldIndices) {
|
||||
auto indexType = index.getType().dyn_cast<BaseTensorType>();
|
||||
if (!indexType) // None index
|
||||
continue;
|
||||
if (!indexType.hasSizes())
|
||||
return failure();
|
||||
int64_t indexRank = indexType.getSizes().size();
|
||||
maxIndexRank = maxIndexRank > indexRank ? maxIndexRank : indexRank;
|
||||
}
|
||||
|
||||
// manually generate new indices.
|
||||
SmallVector<Value> listElements(inputRank);
|
||||
|
||||
int64_t noneIndexCnt = 0;
|
||||
int64_t i;
|
||||
// handle trailing none indices.
|
||||
for (i = inputRank - 1; i >= 0; --i) {
|
||||
int64_t oldI = newToOldDimMap[i];
|
||||
if (oldIndexUsed[oldI])
|
||||
break;
|
||||
Value v = createIndexToReplaceNone(op, rewriter, input, i, inputSizes[i]);
|
||||
auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, v, noneIndexCnt);
|
||||
if (failed(vInfo)) {
|
||||
return failure();
|
||||
}
|
||||
listElements[i] = *vInfo;
|
||||
noneIndexCnt++;
|
||||
}
|
||||
// handle non-none index in between.
|
||||
for (; i >= 0; --i) {
|
||||
int64_t oldI = newToOldDimMap[i];
|
||||
if (!oldIndexUsed[oldI])
|
||||
break;
|
||||
auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, oldIndices[oldI],
|
||||
noneIndexCnt);
|
||||
if (failed(vInfo)) {
|
||||
return failure();
|
||||
}
|
||||
listElements[i] = *vInfo;
|
||||
}
|
||||
|
||||
// handle possible leading none indices.
|
||||
for (; i >= 0; --i) {
|
||||
int64_t oldI = newToOldDimMap[i];
|
||||
if (oldIndexUsed[oldI]) {
|
||||
return failure();
|
||||
}
|
||||
Value v = createIndexToReplaceNone(op, rewriter, input, i, inputSizes[i]);
|
||||
auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, v,
|
||||
noneIndexCnt + maxIndexRank);
|
||||
if (failed(vInfo)) {
|
||||
return failure();
|
||||
}
|
||||
listElements[i] = *vInfo;
|
||||
noneIndexCnt++;
|
||||
}
|
||||
|
||||
auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr);
|
||||
Value newIndexList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(listElemType), listElements);
|
||||
|
||||
return newIndexList;
|
||||
}
|
||||
|
||||
// The goal of this pattern is to eliminate `None` index in aten.Index.Tensor's
|
||||
// `indices` param and transform it to aten.index.Tensor_hacked_twin, for the
|
||||
// ease of various backend.
|
||||
class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
// TODO: It might be better to use aten.view op instead of mulitple
|
||||
// aten.unsqueeze. But currently, torch-to-linalg pass has limited support for
|
||||
// view on dynamic shapes, such as [?] -> [?,1,1,1]. Using aten.view op will
|
||||
// cause relevant e2e tests fail.
|
||||
static FailureOr<Value>
|
||||
unsqueezeTensorAtTrailingDim(Operation *op, PatternRewriter &rewriter,
|
||||
Value input, int count) {
|
||||
Location loc = op->getLoc();
|
||||
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(-1));
|
||||
Value result = input;
|
||||
while (count--) {
|
||||
auto unsqzTensorInfo =
|
||||
unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne);
|
||||
if (failed(unsqzTensorInfo)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
result = *unsqzTensorInfo;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static Value createIndexToReplaceNone(Operation *op,
|
||||
PatternRewriter &rewriter, Value input,
|
||||
int dimInt, int64_t dimSize) {
|
||||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
auto int64Dtype = getDtypeIntValueForType(
|
||||
rewriter, loc,
|
||||
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
||||
|
||||
auto resultType = ValueTensorType::get(
|
||||
context, {dimSize},
|
||||
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
||||
auto dim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(dimInt));
|
||||
auto end = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
|
||||
auto v = rewriter.create<Torch::AtenArangeOp>(
|
||||
loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memory=*/none);
|
||||
return v;
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(AtenIndexTensorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
|
@ -7451,12 +7469,6 @@ public:
|
|||
}
|
||||
auto inputSizes = inputType.getSizes();
|
||||
int64_t inputRank = inputSizes.size();
|
||||
auto outputType = cast<BaseTensorType>(op.getType());
|
||||
if (!outputType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only output with shape information is supported");
|
||||
}
|
||||
auto outputRank = outputType.getSizes().size();
|
||||
|
||||
auto isTensor = [](Value v) {
|
||||
return v.getType().isa<Torch::BaseTensorType>();
|
||||
|
@ -7464,19 +7476,15 @@ public:
|
|||
|
||||
// directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin
|
||||
if (llvm::all_of(indices, isTensor)) {
|
||||
if (indices.size() == 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "the indices is empty, it should be folded as a nop");
|
||||
}
|
||||
// By default, we regard the first index type as the list element type.
|
||||
auto indexElemType = indices[0]
|
||||
.getType()
|
||||
.template cast<BaseTensorType>()
|
||||
.getWithSizesAndDtype(std::nullopt, nullptr);
|
||||
auto newIndex = rewriter.create<PrimListConstructOp>(
|
||||
auto newIndices = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(indexElemType), indices);
|
||||
rewriter.replaceOpWithNewOp<AtenIndexTensorHackedTwinOp>(op, op.getType(),
|
||||
input, newIndex);
|
||||
rewriter.replaceOpWithNewOp<AtenIndexTensorHackedTwinOp>(
|
||||
op, op.getType(), input, newIndices);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -7484,6 +7492,7 @@ public:
|
|||
llvm::to_vector(llvm::map_range(indices, isTensor));
|
||||
for (int64_t i = indices.size(); i < inputRank; ++i)
|
||||
indexUsed.emplace_back(false);
|
||||
|
||||
bool indexIsConsecutive = true;
|
||||
int64_t firstUsedIndex = -1;
|
||||
for (size_t i = 0; i < indices.size(); ++i) {
|
||||
|
@ -7495,17 +7504,15 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
// use aten.permute to reorder the input
|
||||
Value newInput;
|
||||
// `dims` stores the mapping from new index to the old index of input
|
||||
// tensor.
|
||||
SmallVector<int64_t> dims;
|
||||
SmallVector<int64_t> newToOldDimMap;
|
||||
// permute input to make the non-none indices consecutive.
|
||||
if (!indexIsConsecutive) {
|
||||
SmallVector<Value> dimValues;
|
||||
SmallVector<int64_t> permutedSizes;
|
||||
for (int i = 0; i < inputRank; i++) {
|
||||
if (indexUsed[i]) {
|
||||
dims.emplace_back(i);
|
||||
newToOldDimMap.emplace_back(i);
|
||||
dimValues.emplace_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i)));
|
||||
permutedSizes.emplace_back(inputSizes[i]);
|
||||
|
@ -7513,7 +7520,7 @@ public:
|
|||
}
|
||||
for (int i = 0; i < inputRank; i++) {
|
||||
if (!indexUsed[i]) {
|
||||
dims.emplace_back(i);
|
||||
newToOldDimMap.emplace_back(i);
|
||||
dimValues.emplace_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i)));
|
||||
permutedSizes.emplace_back(inputSizes[i]);
|
||||
|
@ -7529,66 +7536,100 @@ public:
|
|||
} else {
|
||||
newInput = input;
|
||||
for (int i = 0; i < inputRank; i++) {
|
||||
dims.emplace_back(i);
|
||||
newToOldDimMap.emplace_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// manually generate new indices.
|
||||
SmallVector<Value> listElements(inputRank);
|
||||
|
||||
int64_t trailingDimCnt = 0;
|
||||
int64_t i;
|
||||
// handle trailing none index.
|
||||
for (i = inputRank - 1; i >= 0; --i) {
|
||||
int64_t oldI = dims[i];
|
||||
if (indexUsed[oldI])
|
||||
break;
|
||||
Value v =
|
||||
createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]);
|
||||
auto vInfo =
|
||||
unsqueezeTensorAtTrailingDim(op, rewriter, v, trailingDimCnt);
|
||||
if (failed(vInfo)) {
|
||||
return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
|
||||
}
|
||||
listElements[i] = *vInfo;
|
||||
trailingDimCnt++;
|
||||
auto newIndeicesInfo = createNewIndices(op, rewriter, newInput, indices,
|
||||
newToOldDimMap, indexUsed);
|
||||
if (failed(newIndeicesInfo)) {
|
||||
return rewriter.notifyMatchFailure(op, "failed to replcae `None` index");
|
||||
}
|
||||
// handle non-none index in between.
|
||||
for (; i >= 0; --i) {
|
||||
int64_t oldI = dims[i];
|
||||
if (!indexUsed[oldI])
|
||||
break;
|
||||
auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, indices[oldI],
|
||||
trailingDimCnt);
|
||||
if (failed(vInfo)) {
|
||||
return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
|
||||
}
|
||||
listElements[i] = *vInfo;
|
||||
}
|
||||
|
||||
// handle possible leading none dimensions.
|
||||
for (; i >= 0; --i) {
|
||||
int64_t oldI = dims[i];
|
||||
if (indexUsed[oldI]) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "the indices are still unconsecutive after reordering input "
|
||||
"tensor");
|
||||
}
|
||||
Value v =
|
||||
createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]);
|
||||
auto vInfo =
|
||||
unsqueezeTensorAtTrailingDim(op, rewriter, v, outputRank - 1 - i);
|
||||
if (failed(vInfo)) {
|
||||
return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
|
||||
}
|
||||
listElements[i] = *vInfo;
|
||||
}
|
||||
|
||||
auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr);
|
||||
auto newIndexList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(listElemType), listElements);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenIndexTensorHackedTwinOp>(
|
||||
op, op.getType(), newInput, newIndexList);
|
||||
op, op.getType(), newInput, *newIndeicesInfo);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// The goal of this pattern is to eliminate `None` index in aten.inde_put-like
|
||||
// ops' `indices` param and transform it to aten.index_put.hacked_twin, for the
|
||||
// ease of various backend.
|
||||
template <typename AtenIndexPutLikeOpT>
|
||||
class DecomposeAtenIndexPutLikeOp
|
||||
: public OpRewritePattern<AtenIndexPutLikeOpT> {
|
||||
public:
|
||||
using OpRewritePattern<AtenIndexPutLikeOpT>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(AtenIndexPutLikeOpT op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
SmallVector<Value> indices;
|
||||
if (!getListConstructElements(op.getIndices(), indices))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"failed to get elements of `indices`");
|
||||
|
||||
auto input = op.getSelf();
|
||||
auto inputType = input.getType().template cast<BaseTensorType>();
|
||||
if (!inputType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only input with shape information is supported");
|
||||
}
|
||||
auto inputSizes = inputType.getSizes();
|
||||
int64_t inputRank = inputSizes.size();
|
||||
|
||||
auto isTensor = [](Value v) {
|
||||
return v.getType().isa<Torch::BaseTensorType>();
|
||||
};
|
||||
|
||||
// directly replace current op with aten.index_put.hacked_twin
|
||||
if (llvm::all_of(indices, isTensor)) {
|
||||
// By default, we regard the first index type as the list element type.
|
||||
auto indexElemType = indices[0]
|
||||
.getType()
|
||||
.template cast<BaseTensorType>()
|
||||
.getWithSizesAndDtype(std::nullopt, nullptr);
|
||||
auto newIndex = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(indexElemType), indices);
|
||||
rewriter.replaceOpWithNewOp<AtenIndexPutHackedTwinOp>(
|
||||
op, op.getType(), input, newIndex, op.getValues(),
|
||||
op.getAccumulate());
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<bool> indexUsed =
|
||||
llvm::to_vector(llvm::map_range(indices, isTensor));
|
||||
for (int64_t i = indices.size(); i < inputRank; ++i)
|
||||
indexUsed.emplace_back(false);
|
||||
|
||||
// check if non-None index is consecutive
|
||||
bool indexIsConsecutive = true;
|
||||
int64_t firstUsedIndex = -1;
|
||||
for (size_t i = 0; i < indices.size(); ++i) {
|
||||
if (indexUsed[i] && firstUsedIndex == -1) {
|
||||
firstUsedIndex = i;
|
||||
} else if (indexUsed[i] && !indexUsed[i - 1]) {
|
||||
indexIsConsecutive = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!indexIsConsecutive) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "non consecutive indices is not supported");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> newToOldDimMap;
|
||||
for (int i = 0; i < inputRank; i++) {
|
||||
newToOldDimMap.emplace_back(i);
|
||||
}
|
||||
|
||||
auto newIndicesInfo = createNewIndices(op, rewriter, input, indices,
|
||||
newToOldDimMap, indexUsed);
|
||||
if (failed(newIndicesInfo)) {
|
||||
return rewriter.notifyMatchFailure(op, "failed to replace `None` index");
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<AtenIndexPutHackedTwinOp>(
|
||||
op, op.getType(), input, *newIndicesInfo, op.getValues(),
|
||||
op.getAccumulate());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -7881,16 +7922,19 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNewFullOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenDropoutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DeomposeAtenNativeDropoutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutLikeOp<AtenIndexPutOp>>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenIndexPutLikeOp<Aten_UnsafeIndexPutHackedTwinOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenIndexPutLikeOp<Aten_IndexPutImplOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
|
||||
|
@ -7956,7 +8000,6 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
|
||||
// More specific conv ops
|
||||
|
|
|
@ -466,13 +466,14 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenMishOp>();
|
||||
target.addIllegalOp<AtenFullLikeOp>();
|
||||
target.addIllegalOp<AtenNewFullOp>();
|
||||
target.addIllegalOp<AtenIndexPutOp>();
|
||||
target.addIllegalOp<AtenExpandAsOp>();
|
||||
target.addIllegalOp<Aten_ToCopyOp>();
|
||||
target.addIllegalOp<AtenDropoutOp>();
|
||||
target.addIllegalOp<AtenNativeDropoutOp>();
|
||||
target.addIllegalOp<AtenNewEmptyOp>();
|
||||
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
|
||||
target.addIllegalOp<AtenIndexTensorOp>();
|
||||
target.addIllegalOp<AtenIndexPutOp>();
|
||||
target.addIllegalOp<Aten_IndexPutImplOp>();
|
||||
target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
|
||||
target.addIllegalOp<AtenPadOp>();
|
||||
target.addIllegalOp<AtenPreluOp>();
|
||||
|
@ -500,7 +501,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<Aten_EmbeddingBagOp>();
|
||||
target.addIllegalOp<AtenLiftFreshCopyOp>();
|
||||
target.addIllegalOp<AtenLerpScalarOp>();
|
||||
target.addIllegalOp<AtenIndexTensorOp>();
|
||||
target.addIllegalOp<AtenMseLossOp>();
|
||||
target.addIllegalOp<AtenRandintLowOp>();
|
||||
target.addIllegalOp<AtenRandintOp>();
|
||||
|
|
|
@ -1704,7 +1704,6 @@ TOSA_PASS_SET = {
|
|||
"HardswishModule_basic",
|
||||
"HardswishRandomModule_basic",
|
||||
"HardtanhBackward_basic",
|
||||
"IndexPutImpl2DNoneIndexStaticModule_basic",
|
||||
"IndexTensorMultiIndexStaticModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IscloseStaticModuleTrue_basic",
|
||||
|
|
Loading…
Reference in New Issue