[Torch] Add AtenMaskedFillTensorOp support (#3561)

pull/3565/head
yyp0 2024-07-26 15:32:13 +08:00 committed by GitHub
parent 15cf7106c4
commit ea60d72489
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 0 deletions

View File

@ -4252,6 +4252,26 @@ public:
}; };
} // namespace } // namespace
// Decompose aten.masked_fill.Tensor into aten.where.self op.
namespace {
class DecomposeAtenMaskedFillTensorOp
: public OpRewritePattern<AtenMaskedFillTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenMaskedFillTensorOp op,
PatternRewriter &rewriter) const override {
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getMask(),
op.getValue(), op.getSelf());
return success();
}
};
} // namespace
// Decompose aten.masked_scatter: // Decompose aten.masked_scatter:
// def masked_scatter(self: Tensor, mask: Tensor, source: Tensor) -> Tensor: // def masked_scatter(self: Tensor, mask: Tensor, source: Tensor) -> Tensor:
// mask_int = mask + torch.zeros_like(self) // mask_int = mask + torch.zeros_like(self)
@ -9182,6 +9202,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedScatterOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedScatterOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);

View File

@ -389,6 +389,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenWhereScalarOtherOp>(); target.addIllegalOp<AtenWhereScalarOtherOp>();
target.addIllegalOp<AtenWhereScalarSelfOp>(); target.addIllegalOp<AtenWhereScalarSelfOp>();
target.addIllegalOp<AtenMaskedFillScalarOp>(); target.addIllegalOp<AtenMaskedFillScalarOp>();
target.addIllegalOp<AtenMaskedFillTensorOp>();
target.addIllegalOp<AtenMaskedScatterOp>(); target.addIllegalOp<AtenMaskedScatterOp>();
target.addIllegalOp<AtenSizeOp>(); target.addIllegalOp<AtenSizeOp>();
target.addIllegalOp<AtenReshapeOp>(); target.addIllegalOp<AtenReshapeOp>();

View File

@ -1118,6 +1118,7 @@ STABLEHLO_PASS_SET = {
"LinspaceTwoSizeModule_basic", "LinspaceTwoSizeModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic", "MaskedFillScalarFloatValueStaticModule_basic",
"MaskedFillScalarIntValueStaticModule_basic", "MaskedFillScalarIntValueStaticModule_basic",
"MaskedFillTensorIntValueStaticModule_basic",
"MaskedScatterStaticBasic_basic", "MaskedScatterStaticBasic_basic",
"Matmul4dStatic_basic", "Matmul4dStatic_basic",
"Matmul_2d", "Matmul_2d",