mirror of https://github.com/llvm/torch-mlir
[MHLO] Support AtenMaskedFillScalar (#1839)
* [MHLO] Support MaskedFillScalar.
* Update.
* Update.
* Update.

---------

Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
parent 2f6fdb7f0b
commit f1b8d5e581
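For context, `aten.masked_fill.Scalar` replaces the elements of a tensor selected by a boolean mask with a scalar fill value. A minimal PyTorch sketch of the semantics this commit adds lowering support for (example values are illustrative, not taken from the diff):

```python
import torch

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])
mask = torch.tensor([[True, False, True],
                     [False, True, False]])

# Where mask is True the element is replaced by the scalar;
# elsewhere the input element is kept.
filled = torch.ops.aten.masked_fill(x, mask, value=-0.01)
# -> [[-0.01, 2., -0.01], [4., -0.01, 6.]]
```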
@@ -92,6 +92,8 @@ TORCHDYNAMO_XFAIL_SET = {
 }
 
 STABLEHLO_PASS_SET = {
+    "MaskedFillScalarIntValueStaticModule_basic",
+    "MaskedFillScalarFloatValueStaticModule_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AddSizeIntModule_basic",
     "AddSizeIntNegDimModule_basic",
@@ -631,6 +633,7 @@ TOSA_PASS_SET = {
     "_LogSoftmaxModuleStable_basic",
     "ElementwiseAtenWhereSelfModule_basic",
     "ElementwiseUnsqueezeBroadcastModule_basic",
     "MaskedFillScalarIntValueModule_basic",
+    "MaskedFillScalarIntValueStaticModule_basic",
     "MaskedFillTensorIntValueStaticModule_basic",
     "ElementwiseAddScalarInt64Module_basic",
@@ -984,18 +984,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
                      .getElementType();
     return convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
   }
-  if (auto maskedFillScalar = dyn_cast<AtenMaskedFillScalarOp>(op)) {
-    AtenMaskedFillScalarOp::Adaptor adaptor(operands);
-    Type dtype = converter->convertType(maskedFillScalar.getType())
-                     .cast<RankedTensorType>()
-                     .getElementType();
-
-    Value input = payloadArgs[0];
-    Value mask = payloadArgs[1];
-    Value fillValue = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
-
-    return b.create<arith::SelectOp>(loc, mask, fillValue, input);
-  }
   if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
     AtenMaskedFillScalarOp::Adaptor adaptor(operands);
    Type dtype = converter->convertType(maskedFillTensor.getType())
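The removed block lowered `aten.masked_fill.Scalar` directly as an elementwise linalg payload; after this commit the op is decomposed before it ever reaches this code (see the `DecomposeAtenMaskedFillScalarOp` pattern below). A minimal Python sketch of what the deleted payload computed per element, assuming the scalar has already been converted to the result element type:

```python
def masked_fill_payload(input_elem, mask_elem, fill_value):
    # Mirrors the deleted `arith.select %mask, %fillValue, %input`.
    return fill_value if mask_elem else input_elem

assert masked_fill_payload(3.0, True, -0.01) == -0.01
assert masked_fill_payload(3.0, False, -0.01) == 3.0
```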
@@ -1105,7 +1093,7 @@ public:
           AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
           AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
           AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
-          AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
+          AtenCosOp, AtenNeScalarOp, AtenNegOp,
           AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
           AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
           AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
@@ -1583,7 +1571,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
       AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
       AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
-      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
+      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
       AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
       AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
       AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();
@@ -1390,6 +1390,25 @@ public:
 };
 } // namespace
 
+// Decompose aten.masked_fill.Scalar into aten.where.self op.
+namespace {
+class DecomposeAtenMaskedFillScalarOp
+    : public OpRewritePattern<AtenMaskedFillScalarOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto resType = op.getType().cast<BaseTensorType>();
+    Value mask = op.getMask();
+    Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask,
+                                                 value, op.getSelf());
+    return success();
+  }
+};
+
+} // namespace
 // Decompose aten.convolution_overrideable to aten.convolution op.
 namespace {
 class DecomposeAtenConvolutionOverrideableOp
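The pattern materializes the scalar as a rank-0 tensor and rewrites the op as `aten.where.self`, so any backend that already handles `where` picks up `masked_fill.Scalar` without a dedicated lowering. A sketch of the equivalent computation in PyTorch (dtype handling is simplified: `createRank0Tensor` uses the result type, approximated here by the input dtype):

```python
import torch

def decomposed_masked_fill_scalar(self, mask, value):
    # Rank-0 tensor holding the scalar, analogous to createRank0Tensor.
    value_tensor = torch.tensor(value, dtype=self.dtype)
    # aten.where.self(mask, value, self): fill where mask is True.
    return torch.ops.aten.where(mask, value_tensor, self)

x = torch.randn(2, 3)
mask = torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)
assert torch.equal(decomposed_masked_fill_scalar(x, mask, -0.01),
                   torch.ops.aten.masked_fill(x, mask, value=-0.01))
```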
@@ -3821,6 +3840,7 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenConvolutionBackwardOverrideableOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
@@ -349,6 +349,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenWhereScalarOp>();
   target.addIllegalOp<AtenWhereScalarOtherOp>();
   target.addIllegalOp<AtenWhereScalarSelfOp>();
+  target.addIllegalOp<AtenMaskedFillScalarOp>();
   target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
   target.addIllegalOp<AtenSizeOp>();
   target.addIllegalOp<AtenReshapeOp>();
@@ -1382,6 +1382,27 @@ def MaskedFillScalarFloatValueModule_basic(module, tu: TestUtils):
                    tu.randint(2, 3, high=2).to(dtype=torch.bool))
 
 
+class MaskedFillScalarFloatValueStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3], torch.int64, True),
+        ([2, 3], torch.bool, True),
+    ])
+    def forward(self, x, mask):
+        return torch.ops.aten.masked_fill(x, mask, value=-0.01)
+
+
+@register_test_case(module_factory=lambda: MaskedFillScalarFloatValueStaticModule())
+def MaskedFillScalarFloatValueStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 3, low=-10, high=10),
+                   tu.randint(2, 3, high=2).to(dtype=torch.bool))
+
+
 class MaskedFillTensorFloatValueModule(torch.nn.Module):
 
     def __init__(self):
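The new static-shape test can be reproduced outside the e2e harness with plain PyTorch (a sketch; `tu.randint(2, 3, high=2)` draws integers in [0, 2), giving a random boolean mask):

```python
import torch

x = torch.randint(-10, 10, (2, 3))                       # tu.randint(2, 3, low=-10, high=10)
mask = torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)  # tu.randint(2, 3, high=2).to(...)
print(torch.ops.aten.masked_fill(x, mask, value=-0.01))
```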