[Torch Dialect] fix aten.nan_to_num's decomposition when inf=None (#3530)

also adds shape inference in the decomposition; see
https://github.com/llvm/torch-mlir/issues/3312
pull/3541/head
Yuanqiang Liu 2024-07-11 08:46:40 +08:00 committed by GitHub
parent 5342aa70cf
commit b38585e077
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 62 additions and 24 deletions

View File

@ -3906,37 +3906,50 @@ public:
// Decompose aten.nan_to_num into isnan/isposinf/isneginf + where.scalar ops.
// Handles nan/posinf/neginf operands that are `None` by materializing the
// PyTorch-documented defaults, and propagates the result's static shape
// into the intermediate ops instead of using an unranked tensor type.
LogicalResult matchAndRewrite(AtenNanToNumOp op,
                              PatternRewriter &rewriter) const override {
  Location loc = op.getLoc();
  Value nan = op.getNan();
  Value posinf = op.getPosinf();
  Value neginf = op.getNeginf();
  // Use the op's result type (which carries any statically known sizes) so
  // the decomposed ops keep the shape information.
  auto outputType = cast<BaseTensorType>(op.getResult().getType());
  if (!outputType.hasDtype() ||
      !isa<mlir::FloatType>(outputType.getDtype())) {
    return rewriter.notifyMatchFailure(
        op, "expect output type to have float dtype");
  }
  mlir::FloatType outputElementType =
      cast<mlir::FloatType>(outputType.getDtype());
  // Defaults per aten.nan_to_num semantics: nan -> 0.0; posinf/neginf ->
  // the largest finite value representable by the output element type
  // (NOT +/-inf, which would leave the infinities in place).
  if (isa<Torch::NoneType>(nan.getType())) {
    nan =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
  }
  if (isa<Torch::NoneType>(posinf.getType())) {
    posinf = rewriter.create<ConstantFloatOp>(
        loc, rewriter.getF64FloatAttr(
                 APFloat::getLargest(outputElementType.getFloatSemantics())
                     .convertToDouble()));
  }
  if (isa<Torch::NoneType>(neginf.getType())) {
    neginf = rewriter.create<ConstantFloatOp>(
        loc, rewriter.getF64FloatAttr(
                 APFloat::getLargest(outputElementType.getFloatSemantics(),
                                     /*Negative=*/true)
                     .convertToDouble()));
  }
  // Boolean mask type: same sizes as the output, but i1 element type.
  auto compareType = outputType.getWithSizesAndDtype(
      outputType.getOptionalSizes(), rewriter.getI1Type());
  Value isNan =
      rewriter.create<Torch::AtenIsnanOp>(loc, compareType, op.getSelf());
  Value where = rewriter.create<Torch::AtenWhereScalarSelfOp>(
      loc, outputType, isNan, nan, op.getSelf());
  Value isposinf =
      rewriter.create<Torch::AtenIsposinfOp>(loc, compareType, where);
  where = rewriter.create<Torch::AtenWhereScalarSelfOp>(
      loc, outputType, isposinf, posinf, where);
  Value isneginf =
      rewriter.create<Torch::AtenIsneginfOp>(loc, compareType, where);
  rewriter.replaceOpWithNewOp<Torch::AtenWhereScalarSelfOp>(
      op, outputType, isneginf, neginf, where);
  return success();
}
};

View File

@ -1029,6 +1029,7 @@ STABLEHLO_PASS_SET = {
"ElementwiseLog2Module_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseNanToNumWithNoneModule_Basic",
"ElementwiseNanToNumModule_Basic",
"ElementwiseNeFloatTensorStaticModule_basic",
"ElementwiseNeIntTensorStaticModule_basic",
@ -1761,6 +1762,7 @@ TOSA_PASS_SET = {
"ElementwiseUnaryModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"ElementwiseWhereScalarModule_basic",
"ElementwiseNanToNumWithNoneModule_Basic",
"ElementwiseNanToNumModule_Basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleI32Static_basic",

View File

@ -610,6 +610,29 @@ def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseNanToNumWithNoneModule(torch.nn.Module):
    """Exercises aten.nan_to_num with all replacement args left as None,
    i.e. the default mapping: nan -> 0.0, +inf/-inf -> the largest
    finite value of the dtype."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([3, 4], torch.float32, True)])
    def forward(self, a):
        # No nan/posinf/neginf arguments: relies on the op's defaults.
        return torch.ops.aten.nan_to_num(a)


@register_test_case(module_factory=lambda: ElementwiseNanToNumWithNoneModule())
def ElementwiseNanToNumWithNoneModule_Basic(module, tu: TestUtils):
    # Input deliberately mixes nan, +inf, -inf with ordinary finite values.
    module.forward(
        torch.tensor(
            [
                [float("nan"), 0.0, float("nan"), 1.0],
                [float("inf"), 2.0, float("inf"), 3.0],
                [float("-inf"), -1.0, float("-inf"), 4.0],
            ]
        )
    )
class ElementwiseNanToNumModule(torch.nn.Module):
    """Exercises aten.nan_to_num with explicit replacement values:
    nan -> 0.1, +inf -> 1.0, -inf -> -1.0."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([3, 4], torch.float32, True)])
    def forward(self, a):
        # 0.1 (not 0.0) for nan so the test distinguishes an explicit
        # replacement from the default behavior.
        return torch.ops.aten.nan_to_num(a, 0.1, 1.0, -1.0)


@register_test_case(module_factory=lambda: ElementwiseNanToNumModule())
def ElementwiseNanToNumModule_Basic(module, tu: TestUtils):
    # Input deliberately mixes nan, +inf, -inf with ordinary finite values.
    module.forward(
        torch.tensor(
            [
                [float("nan"), 0.0, float("nan"), 1.0],
                [float("inf"), 2.0, float("inf"), 3.0],
                [float("-inf"), -1.0, float("-inf"), 4.0],
            ]
        )
    )