mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] fix aten.nan_to_num's decomposition when inf=None (#3530)
also add shape infer in decomposition, see https://github.com/llvm/torch-mlir/issues/3312pull/3541/head
parent
5342aa70cf
commit
b38585e077
|
@ -3906,37 +3906,50 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenNanToNumOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
mlir::FloatType f64Type = rewriter.getF64Type();
|
||||
Value nan = op.getNan();
|
||||
Value posinf = op.getPosinf();
|
||||
Value neginf = op.getNeginf();
|
||||
auto baseType =
|
||||
ValueTensorType::getWithLeastStaticInformation(op.getContext());
|
||||
if (dyn_cast_or_null<ConstantNoneOp>(nan.getDefiningOp()))
|
||||
nan = rewriter.create<ConstantFloatOp>(
|
||||
loc, rewriter.getFloatAttr(
|
||||
f64Type, APFloat::getZero(f64Type.getFloatSemantics())));
|
||||
if (dyn_cast_or_null<ConstantNoneOp>(posinf.getDefiningOp()))
|
||||
auto outputType = cast<BaseTensorType>(op.getResult().getType());
|
||||
if (!outputType.hasDtype() ||
|
||||
!isa<mlir::FloatType>(outputType.getDtype())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expect output type to have float dtype");
|
||||
}
|
||||
mlir::FloatType outputElementType =
|
||||
cast<mlir::FloatType>(outputType.getDtype());
|
||||
|
||||
if (isa<Torch::NoneType>(nan.getType())) {
|
||||
nan =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
||||
}
|
||||
if (isa<Torch::NoneType>(posinf.getType())) {
|
||||
posinf = rewriter.create<ConstantFloatOp>(
|
||||
loc, rewriter.getFloatAttr(
|
||||
f64Type, APFloat::getInf(f64Type.getFloatSemantics())));
|
||||
if (dyn_cast_or_null<ConstantNoneOp>(neginf.getDefiningOp()))
|
||||
loc, rewriter.getF64FloatAttr(
|
||||
APFloat::getLargest(outputElementType.getFloatSemantics())
|
||||
.convertToDouble()));
|
||||
}
|
||||
if (isa<Torch::NoneType>(neginf.getType())) {
|
||||
neginf = rewriter.create<ConstantFloatOp>(
|
||||
loc,
|
||||
rewriter.getFloatAttr(
|
||||
f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true)));
|
||||
loc, rewriter.getF64FloatAttr(
|
||||
APFloat::getLargest(outputElementType.getFloatSemantics(),
|
||||
/*Negative=*/true)
|
||||
.convertToDouble()));
|
||||
}
|
||||
|
||||
auto compareType = outputType.getWithSizesAndDtype(
|
||||
outputType.getOptionalSizes(), rewriter.getI1Type());
|
||||
Value isNan =
|
||||
rewriter.create<Torch::AtenIsnanOp>(loc, baseType, op.getSelf());
|
||||
rewriter.create<Torch::AtenIsnanOp>(loc, compareType, op.getSelf());
|
||||
Value where = rewriter.create<Torch::AtenWhereScalarSelfOp>(
|
||||
loc, baseType, isNan, nan, op.getSelf());
|
||||
loc, outputType, isNan, nan, op.getSelf());
|
||||
Value isposinf =
|
||||
rewriter.create<Torch::AtenIsposinfOp>(loc, baseType, where);
|
||||
rewriter.create<Torch::AtenIsposinfOp>(loc, compareType, where);
|
||||
where = rewriter.create<Torch::AtenWhereScalarSelfOp>(
|
||||
loc, baseType, isposinf, posinf, where);
|
||||
loc, outputType, isposinf, posinf, where);
|
||||
Value isneginf =
|
||||
rewriter.create<Torch::AtenIsneginfOp>(loc, baseType, where);
|
||||
rewriter.create<Torch::AtenIsneginfOp>(loc, compareType, where);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenWhereScalarSelfOp>(
|
||||
op, op.getType(), isneginf, neginf, where);
|
||||
op, outputType, isneginf, neginf, where);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1029,6 +1029,7 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseLog2Module_basic",
|
||||
"ElementwiseLog10IntModule_basic",
|
||||
"ElementwiseLog2IntModule_basic",
|
||||
"ElementwiseNanToNumWithNoneModule_Basic",
|
||||
"ElementwiseNanToNumModule_Basic",
|
||||
"ElementwiseNeFloatTensorStaticModule_basic",
|
||||
"ElementwiseNeIntTensorStaticModule_basic",
|
||||
|
@ -1761,6 +1762,7 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseUnaryModule_basic",
|
||||
"ElementwiseUnsqueezeBroadcastModule_basic",
|
||||
"ElementwiseWhereScalarModule_basic",
|
||||
"ElementwiseNanToNumWithNoneModule_Basic",
|
||||
"ElementwiseNanToNumModule_Basic",
|
||||
"EmbeddingModule1DIndices_basic",
|
||||
"EmbeddingModuleI32Static_basic",
|
||||
|
|
|
@ -610,6 +610,29 @@ def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseNanToNumWithNoneModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([3, 4], torch.float32, True)])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.nan_to_num(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseNanToNumWithNoneModule())
|
||||
def ElementwiseNanToNumWithNoneModule_Basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor(
|
||||
[
|
||||
[float("nan"), 0.0, float("nan"), 1.0],
|
||||
[float("inf"), 2.0, float("inf"), 3.0],
|
||||
[float("-inf"), -1.0, float("-inf"), 4.0],
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ElementwiseNanToNumModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -617,7 +640,7 @@ class ElementwiseNanToNumModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([None, ([3, 4], torch.float32, True)])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.nan_to_num(a, 0.0, 1.0, -1.0)
|
||||
return torch.ops.aten.nan_to_num(a, 0.1, 1.0, -1.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseNanToNumModule())
|
||||
|
@ -625,9 +648,9 @@ def ElementwiseNanToNumModule_Basic(module, tu: TestUtils):
|
|||
module.forward(
|
||||
torch.tensor(
|
||||
[
|
||||
[float("nan"), 0.0, float("nan"), 0.0],
|
||||
[float("inf"), 0.0, float("inf"), 0.0],
|
||||
[float("-inf"), 0.0, float("-inf"), 0.0],
|
||||
[float("nan"), 0.0, float("nan"), 1.0],
|
||||
[float("inf"), 2.0, float("inf"), 3.0],
|
||||
[float("-inf"), -1.0, float("-inf"), 4.0],
|
||||
]
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue