[Torch] Decompose AtenMaskedScatterOp (#3353)

Co-authored-by: Yuanqiang Liu <liuyuanqiang.yqliu@bytedance.com>
pull/3085/head
Xinyu Yang 2024-05-16 15:27:25 +08:00 committed by GitHub
parent a9edefb3cf
commit 7faba75696
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 143 additions and 0 deletions

View File

@ -72,6 +72,9 @@ bool isBuiltInType(Type type);
// std::nullopt is returned if the tensorRank can't be determined. // std::nullopt is returned if the tensorRank can't be determined.
std::optional<unsigned> getTensorRank(Value tensor); std::optional<unsigned> getTensorRank(Value tensor);
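// Helper function to get the number of elements in a tensor.
// std::nullopt is returned if the tensor type has no sizes, and
// ShapedType::kDynamic is returned if any dimension equals
// ShapedType::kDynamic.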
std::optional<int64_t> getTensorNumel(Value tensor);
bool isViewLikeOp(Operation *op); bool isViewLikeOp(Operation *op);
Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc, Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,

View File

@ -3371,6 +3371,104 @@ public:
}; };
} // namespace } // namespace
// Decompose aten.masked_scatter:
// def masked_scatter(self: Tensor, mask: Tensor, source: Tensor) -> Tensor:
//   mask_int = mask + torch.zeros_like(self, dtype=torch.int64)
//   prefix_sum = torch.cumsum(mask_int.flatten(), dim=0)
//   mask_prefix = torch.clamp(prefix_sum - 1, min=0)
//   source = source.flatten().index_select(0, mask_prefix).view(self.shape)
//   return torch.where(mask, source, self)
namespace {
// Rewrites aten.masked_scatter into a chain of simpler ops:
//   1. add the mask to an int64 zeros_like(self) to get an integer mask with
//      self's shape,
//   2. flatten it and take a cumulative sum, subtract one and clamp at zero to
//      obtain, for every position, an index into the flattened source,
//   3. gather from the flattened source with index_select and view the result
//      with self's size list,
//   4. pick between the gathered values and self with aten.where.self.
//
// The pattern only matches when the result has a dtype, self has a fully
// static shape, and source has a fully static shape and a dtype.
class DecomposeAtenMaskedScatterOp
    : public OpRewritePattern<AtenMaskedScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMaskedScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto context = op.getContext();
    Value mask = op.getMask();
    Value source = op.getSource();
    Value self = op.getSelf();

    // cast<> asserts on a type mismatch and never yields a null type, so no
    // null checks are needed on these.
    auto selfTy = cast<BaseTensorType>(self.getType());
    auto resTy = cast<BaseTensorType>(op.getType());
    auto sourceTy = cast<BaseTensorType>(source.getType());

    if (!resTy.hasDtype())
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    if (!selfTy.areAllSizesKnown())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: self must have a statically known shape");
    if (!sourceTy.areAllSizesKnown() || !sourceTy.hasDtype())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: source must have a statically known shape and "
              "a dtype");

    int64_t selfNumel = getTensorNumel(self).value(); // as selfTy has sizes
    int64_t sourceNumel =
        getTensorNumel(source).value(); // as sourceTy has sizes
    int64_t selfRank = selfTy.getSizes().size();
    int64_t sourceRank = sourceTy.getSizes().size();

    // The gather below reads selfNumel elements from the flattened source
    // using indices clamped to >= 0; with an empty source there is no element
    // 0 to read, so refuse to decompose instead of emitting an invalid gather.
    if (sourceNumel == 0 && selfNumel > 0)
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: source must not be empty when self has elements");

    Value constZero = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    Value constOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value constNone = rewriter.create<ConstantNoneOp>(loc);
    Value selfLastDim = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(selfRank - 1));
    Value sourceLastDim = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(sourceRank - 1));

    // Signed 64-bit integer type, used both as the element type of the
    // intermediate index tensors and to derive the dtype operand.
    auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
    auto int64Dtype = getDtypeIntValueForType(rewriter, loc, si64Type);

    // mask_int = mask + zeros_like(self, dtype=int64)
    auto selfIntType = selfTy.getWithSizesAndDtype(selfTy.getSizes(), si64Type);
    Value zerosLike = rewriter.create<Torch::AtenZerosLikeOp>(
        loc, selfIntType, self, int64Dtype, constNone, constNone, constNone,
        constNone);
    Value maskInt = rewriter.create<Torch::AtenAddTensorOp>(
        loc, selfIntType, mask, zerosLike, constOne);

    // mask_prefix = clamp(cumsum(flatten(mask_int)) - 1, min=0)
    auto flattenMaskedType = selfTy.getWithSizesAndDtype(
        /*optionalSizes=*/{selfNumel}, si64Type);
    Value maskIntFlatten = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
        loc, flattenMaskedType, maskInt, constZero, selfLastDim);
    Value prefixSum = rewriter.create<Torch::AtenCumsumOp>(
        loc, flattenMaskedType, maskIntFlatten,
        /*dim=*/constZero, constNone);
    Value prefixSumMinusOne = rewriter.create<Torch::AtenSubScalarOp>(
        loc, flattenMaskedType, prefixSum, constOne, constOne);
    Value maskPrefix = rewriter.create<Torch::AtenClampOp>(
        loc, flattenMaskedType, prefixSumMinusOne, /*min=*/constZero,
        /*max=*/constNone);

    // Gather selfNumel elements out of the flattened source.
    auto sourceFlattenType = sourceTy.getWithSizesAndDtype(
        /*optionalSizes=*/{sourceNumel}, sourceTy.getDtype());
    Value sourceFlatten = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
        loc, sourceFlattenType, source, constZero, sourceLastDim);
    auto selectSourceType = sourceTy.getWithSizesAndDtype(
        /*optionalSizes=*/{selfNumel}, sourceTy.getDtype());
    Value selectSource = rewriter.create<Torch::AtenIndexSelectOp>(
        loc, selectSourceType, sourceFlatten, constZero, maskPrefix);

    // Reshape the gathered values back to self's shape. The data comes from
    // source, so the result type carries source's dtype with self's sizes.
    auto selfShape = rewriter.create<AtenSizeOp>(
        loc, Torch::ListType::get(IntType::get(context)), self);
    auto sourceReshapeType =
        selfTy.getWithSizesAndDtype(selfTy.getSizes(), sourceTy.getDtype());
    Value sourceReshape = rewriter.create<Torch::AtenViewOp>(
        loc, sourceReshapeType, selectSource, selfShape);

    rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(op, resTy, mask,
                                                        sourceReshape, self);
    return success();
  }
};
} // namespace
// Decompose aten._convolution-like to aten.convolution // Decompose aten._convolution-like to aten.convolution
namespace { namespace {
template <typename ConvolutionLikeOp> template <typename ConvolutionLikeOp>
@ -7839,6 +7937,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedScatterOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>( addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(

View File

@ -390,6 +390,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenWhereScalarOtherOp>(); target.addIllegalOp<AtenWhereScalarOtherOp>();
target.addIllegalOp<AtenWhereScalarSelfOp>(); target.addIllegalOp<AtenWhereScalarSelfOp>();
target.addIllegalOp<AtenMaskedFillScalarOp>(); target.addIllegalOp<AtenMaskedFillScalarOp>();
target.addIllegalOp<AtenMaskedScatterOp>();
target.addIllegalOp<AtenSizeOp>(); target.addIllegalOp<AtenSizeOp>();
target.addIllegalOp<AtenReshapeOp>(); target.addIllegalOp<AtenReshapeOp>();
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>(); target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();

View File

@ -209,6 +209,19 @@ std::optional<unsigned> Torch::getTensorRank(Value tensor) {
return tensorType.getSizes().size(); return tensorType.getSizes().size();
} }
// Computes the product of all dimension sizes of `tensor`'s type.
// Returns std::nullopt when the type carries no sizes, and
// ShapedType::kDynamic as soon as a dimension equal to ShapedType::kDynamic is
// encountered.
// NOTE(review): Torch tensor types may encode unknown dimensions with a
// different sentinel than ShapedType::kDynamic -- confirm this comparison
// matches the sentinel used by BaseTensorType::getSizes().
std::optional<int64_t> Torch::getTensorNumel(Value tensor) {
  auto type = cast<BaseTensorType>(tensor.getType());
  if (type.hasSizes()) {
    int64_t count = 1;
    for (int64_t extent : type.getSizes()) {
      if (extent == ShapedType::kDynamic)
        return ShapedType::kDynamic;
      count *= extent;
    }
    return count;
  }
  return std::nullopt;
}
bool Torch::isViewLikeOp(Operation *op) { bool Torch::isViewLikeOp(Operation *op) {
// AtenContiguousOp might return a view, so this is conservatively // AtenContiguousOp might return a view, so this is conservatively
// correct. We could potentially be more precise and identify the cases // correct. We could potentially be more precise and identify the cases

View File

@ -1072,6 +1072,7 @@ STABLEHLO_PASS_SET = {
"LinspaceTwoSizeModule_basic", "LinspaceTwoSizeModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic", "MaskedFillScalarFloatValueStaticModule_basic",
"MaskedFillScalarIntValueStaticModule_basic", "MaskedFillScalarIntValueStaticModule_basic",
"MaskedScatterStaticBasic_basic",
"Matmul4dStatic_basic", "Matmul4dStatic_basic",
"Matmul_2d", "Matmul_2d",
"Matmul_dot", "Matmul_dot",
@ -2366,6 +2367,7 @@ ONNX_XFAIL_SET = {
"LinalgNormKeepDimComplexModule_basic", "LinalgNormKeepDimComplexModule_basic",
"LinalgVectorNormComplexModule_basic", "LinalgVectorNormComplexModule_basic",
"LogSoftmaxBackwardModule_basic", "LogSoftmaxBackwardModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxPool1dCeilModeTrueModule_basic", "MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic", "MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dModule_basic", "MaxPool1dModule_basic",

View File

@ -12,6 +12,31 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ============================================================================== # ==============================================================================
class MaskedScatterStaticBasic(torch.nn.Module):
    """Module whose forward pass is a single ``torch.masked_scatter`` call.

    All three inputs are annotated with static shapes: a 4x4 float32 tensor,
    a 4x4 bool mask, and an 8x8 float32 source.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4, 4], torch.float32, True),  # x
            ([4, 4], torch.bool, True),  # mask
            ([8, 8], torch.float32, True),  # y
        ]
    )
    def forward(self, x, mask, y):
        """Return ``torch.masked_scatter(x, mask, y)``."""
        scattered = torch.masked_scatter(x, mask, y)
        return scattered
@register_test_case(module_factory=lambda: MaskedScatterStaticBasic())
def MaskedScatterStaticBasic_basic(module, tu: TestUtils):
    """Run the module on a random 4x4 input, a random 4x4 bool mask, and a
    random 8x8 source."""
    # Keep the draw order (input, mask, source) so the random stream is
    # consumed in the same sequence.
    dest = torch.rand(4, 4)
    selector = torch.rand(4, 4) > 0.5
    values = torch.rand(8, 8)
    module.forward(dest, selector, values)
class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module): class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()