From ea60d724891f9f19017108fdd863346a01106b8b Mon Sep 17 00:00:00 2001 From: yyp0 Date: Fri, 26 Jul 2024 15:32:13 +0800 Subject: [PATCH] [Torch] Add AtenMaskedFillTensorOp support (#3561) --- .../Torch/Transforms/DecomposeComplexOps.cpp | 21 +++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 1 + 3 files changed, 23 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index faf7f7ce2..f073d1405 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -4252,6 +4252,26 @@ public: }; } // namespace +// Decompose aten.masked_fill.Tensor into aten.where.self op. +namespace { +class DecomposeAtenMaskedFillTensorOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenMaskedFillTensorOp op, + PatternRewriter &rewriter) const override { + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + rewriter.replaceOpWithNewOp(op, resType, op.getMask(), + op.getValue(), op.getSelf()); + + return success(); + } +}; +} // namespace + // Decompose aten.masked_scatter: // def masked_scatter(self: Tensor, mask: Tensor, source: Tensor) -> Tensor: // mask_int = mask + torch.zeros_like(self) @@ -9182,6 +9202,7 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 31ad13158..161f9516f 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -389,6 +389,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index fd8f7fc07..a82a2e913 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1118,6 +1118,7 @@ STABLEHLO_PASS_SET = { "LinspaceTwoSizeModule_basic", "MaskedFillScalarFloatValueStaticModule_basic", "MaskedFillScalarIntValueStaticModule_basic", + "MaskedFillTensorIntValueStaticModule_basic", "MaskedScatterStaticBasic_basic", "Matmul4dStatic_basic", "Matmul_2d",