[MHLO] Support AtenMaskedFillScalar (#1839)

* [MHLO] Support MaskedFillScalar.

* Update.

* Update.

* Update.

---------

Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
pull/1875/head
Ziheng Jiang 2023-02-10 13:58:39 -08:00 committed by GitHub
parent 2f6fdb7f0b
commit f1b8d5e581
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 47 additions and 14 deletions

View File

@ -92,6 +92,8 @@ TORCHDYNAMO_XFAIL_SET = {
}
STABLEHLO_PASS_SET = {
"MaskedFillScalarIntValueStaticModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddSizeIntModule_basic",
"AddSizeIntNegDimModule_basic",
@ -631,6 +633,7 @@ TOSA_PASS_SET = {
"_LogSoftmaxModuleStable_basic",
"ElementwiseAtenWhereSelfModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"MaskedFillScalarIntValueModule_basic",
"MaskedFillScalarIntValueStaticModule_basic",
"MaskedFillTensorIntValueStaticModule_basic",
"ElementwiseAddScalarInt64Module_basic",

View File

@ -984,18 +984,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType();
return convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
}
if (auto maskedFillScalar = dyn_cast<AtenMaskedFillScalarOp>(op)) {
AtenMaskedFillScalarOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(maskedFillScalar.getType())
.cast<RankedTensorType>()
.getElementType();
Value input = payloadArgs[0];
Value mask = payloadArgs[1];
Value fillValue = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
return b.create<arith::SelectOp>(loc, mask, fillValue, input);
}
if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
AtenMaskedFillScalarOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(maskedFillTensor.getType())
@ -1105,7 +1093,7 @@ public:
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenCosOp, AtenNeScalarOp, AtenNegOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
@ -1583,7 +1571,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();

View File

@ -1390,6 +1390,25 @@ public:
};
} // namespace
// Decompose aten.masked_fill.Scalar into aten.where.self op.
namespace {
class DecomposeAtenMaskedFillScalarOp
: public OpRewritePattern<AtenMaskedFillScalarOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto resType = op.getType().cast<BaseTensorType>();
Value mask = op.getMask();
Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask,
value, op.getSelf());
return success();
}
};
} // namespace
// Decompose aten.convolution_overrideable to aten.convolution op.
namespace {
class DecomposeAtenConvolutionOverrideableOp
@ -3821,6 +3840,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenConvolutionBackwardOverrideableOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);

View File

@ -349,6 +349,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenWhereScalarOp>();
target.addIllegalOp<AtenWhereScalarOtherOp>();
target.addIllegalOp<AtenWhereScalarSelfOp>();
target.addIllegalOp<AtenMaskedFillScalarOp>();
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
target.addIllegalOp<AtenSizeOp>();
target.addIllegalOp<AtenReshapeOp>();

View File

@ -1382,6 +1382,27 @@ def MaskedFillScalarFloatValueModule_basic(module, tu: TestUtils):
tu.randint(2, 3, high=2).to(dtype=torch.bool))
class MaskedFillScalarFloatValueStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3], torch.int64, True),
([2, 3], torch.bool, True),
])
def forward(self, x, mask):
return torch.ops.aten.masked_fill(x, mask, value=-0.01)
@register_test_case(module_factory=lambda: MaskedFillScalarFloatValueStaticModule())
def MaskedFillScalarFloatValueStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 3, low=-10, high=10),
tu.randint(2, 3, high=2).to(dtype=torch.bool))
class MaskedFillTensorFloatValueModule(torch.nn.Module):
def __init__(self):