mirror of https://github.com/llvm/torch-mlir
[Torch] Add AtenMaskedFillTensorOp support (#3561)
parent
15cf7106c4
commit
ea60d72489
|
@ -4252,6 +4252,26 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Decompose aten.masked_fill.Tensor into aten.where.self op.
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenMaskedFillTensorOp
|
||||||
|
: public OpRewritePattern<AtenMaskedFillTensorOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenMaskedFillTensorOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
|
if (!resType.hasDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getMask(),
|
||||||
|
op.getValue(), op.getSelf());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Decompose aten.masked_scatter:
|
// Decompose aten.masked_scatter:
|
||||||
// def masked_scatter(self: Tensor, mask: Tensor, source: Tensor) -> Tensor:
|
// def masked_scatter(self: Tensor, mask: Tensor, source: Tensor) -> Tensor:
|
||||||
// mask_int = mask + torch.zeros_like(self)
|
// mask_int = mask + torch.zeros_like(self)
|
||||||
|
@ -9182,6 +9202,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillTensorOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedScatterOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedScatterOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
|
||||||
|
|
|
@ -389,6 +389,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
||||||
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
||||||
target.addIllegalOp<AtenMaskedFillScalarOp>();
|
target.addIllegalOp<AtenMaskedFillScalarOp>();
|
||||||
|
target.addIllegalOp<AtenMaskedFillTensorOp>();
|
||||||
target.addIllegalOp<AtenMaskedScatterOp>();
|
target.addIllegalOp<AtenMaskedScatterOp>();
|
||||||
target.addIllegalOp<AtenSizeOp>();
|
target.addIllegalOp<AtenSizeOp>();
|
||||||
target.addIllegalOp<AtenReshapeOp>();
|
target.addIllegalOp<AtenReshapeOp>();
|
||||||
|
|
|
@ -1118,6 +1118,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"LinspaceTwoSizeModule_basic",
|
"LinspaceTwoSizeModule_basic",
|
||||||
"MaskedFillScalarFloatValueStaticModule_basic",
|
"MaskedFillScalarFloatValueStaticModule_basic",
|
||||||
"MaskedFillScalarIntValueStaticModule_basic",
|
"MaskedFillScalarIntValueStaticModule_basic",
|
||||||
|
"MaskedFillTensorIntValueStaticModule_basic",
|
||||||
"MaskedScatterStaticBasic_basic",
|
"MaskedScatterStaticBasic_basic",
|
||||||
"Matmul4dStatic_basic",
|
"Matmul4dStatic_basic",
|
||||||
"Matmul_2d",
|
"Matmul_2d",
|
||||||
|
|
Loading…
Reference in New Issue