[Torch Dialect] add support for AtenIsnanOp (#2170)

* add support for mhlo

* Add Test for torch.ne

* fix torch.ne shape/add static test case

* add support for static torch.ne

---------

Co-authored-by: root <root@n31-177-039.byted.org>
pull/2194/head
JianzheXiao 2023-06-07 10:06:27 +08:00 committed by GitHub
parent 2480cb7a51
commit e4f8fb1b8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 196 additions and 6 deletions

View File

@ -378,6 +378,7 @@ STABLEHLO_PASS_SET = {
"CumsumStaticModule_basic", "CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic", "CumsumStaticNegativeDimModule_basic",
"DetachModule_basic", "DetachModule_basic",
"ElementwiseIsnanModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalNotOpModule_basic", "ElementwiseAtenLogicalNotOpModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic",
@ -423,6 +424,9 @@ STABLEHLO_PASS_SET = {
"ElementwiseEqDiffWidthScalarModule_basic", "ElementwiseEqDiffWidthScalarModule_basic",
"ElementwiseEqFloatScalarModule_basic", "ElementwiseEqFloatScalarModule_basic",
"ElementwiseEqIntScalarModule_basic", "ElementwiseEqIntScalarModule_basic",
"ElementwiseNeFloatScalarModule_basic",
"ElementwiseNeFloatTensorStaticModule_basic",
"ElementwiseNeIntTensorStaticModule_basic",
"ElementwiseErfModule_basic", "ElementwiseErfModule_basic",
"ElementwiseGeluModule_basic", "ElementwiseGeluModule_basic",
"ElementwiseGtFloatScalarModule_basic", "ElementwiseGtFloatScalarModule_basic",
@ -443,7 +447,6 @@ STABLEHLO_PASS_SET = {
"ElementwiseMulScalarModule_basic", "ElementwiseMulScalarModule_basic",
"ElementwiseMulScalarModule_float", "ElementwiseMulScalarModule_float",
"ElementwiseMulScalarModule_int", "ElementwiseMulScalarModule_int",
"ElementwiseNeFloatTensorModule_basic",
"ElementwiseNeIntScalarModule_basic", "ElementwiseNeIntScalarModule_basic",
"ElementwiseReciprocalModule_basic", "ElementwiseReciprocalModule_basic",
"ElementwiseRelu6Module_basic", "ElementwiseRelu6Module_basic",
@ -875,6 +878,11 @@ TOSA_PASS_SET = {
"ElementwiseEqDiffWidthScalarModule_basic", "ElementwiseEqDiffWidthScalarModule_basic",
"ElementwiseEqFloatTensorModule_basic", "ElementwiseEqFloatTensorModule_basic",
"ElementwiseEqIntTensorModule_basic", "ElementwiseEqIntTensorModule_basic",
"ElementwiseNeFloatScalarModule_basic",
"ElementwiseNeFloatTensorModule_basic",
"ElementwiseNeFloatTensorStaticModule_basic",
"ElementwiseNeIntTensorModule_basic",
"ElementwiseNeIntTensorStaticModule_basic",
"ElementwiseMulScalarModule_int", "ElementwiseMulScalarModule_int",
"ElementwiseMulScalarModule_float", "ElementwiseMulScalarModule_float",
"ElementwiseMulTensorIntModule_basic", "ElementwiseMulTensorIntModule_basic",
@ -885,6 +893,7 @@ TOSA_PASS_SET = {
"ElementwiseMulScalarModule_float", "ElementwiseMulScalarModule_float",
"ElementwiseCeilModule_basic", "ElementwiseCeilModule_basic",
"ElementwiseReciprocalModule_basic", "ElementwiseReciprocalModule_basic",
"ElementwiseIsnanModule_basic",
"TypePromotionAlphaWiderModule_basic", "TypePromotionAlphaWiderModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic",
"BatchNorm1DModule_basic", "BatchNorm1DModule_basic",
@ -955,7 +964,6 @@ TOSA_PASS_SET = {
"ElementwiseGeluModule_basic", "ElementwiseGeluModule_basic",
"GeluBackwardModule_basic", "GeluBackwardModule_basic",
"ElementwiseNeIntScalarModule_basic", "ElementwiseNeIntScalarModule_basic",
"ElementwiseNeFloatTensorModule_basic",
"Convolution2DStaticModule_basic", "Convolution2DStaticModule_basic",
"ElementwiseNegModule_basic", "ElementwiseNegModule_basic",
"TestMultipleTensorReturn_basic", "TestMultipleTensorReturn_basic",

View File

@ -6524,6 +6524,29 @@ def Torch_Aten_ShapeAsTensorOp : Torch_Op<"aten._shape_as_tensor", [
}]; }];
} }
def Torch_AtenIsnanOp : Torch_Op<"aten.isnan", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::isnan : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIsnanOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenIsnanOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenAllOp : Torch_Op<"aten.all", [ def Torch_AtenAllOp : Torch_Op<"aten.all", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -142,7 +142,8 @@ static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
std::is_same<OpTy, AtenLeTensorOp>() || std::is_same<OpTy, AtenLeTensorOp>() ||
std::is_same<OpTy, AtenGtTensorOp>() || std::is_same<OpTy, AtenGtTensorOp>() ||
std::is_same<OpTy, AtenGeTensorOp>() || std::is_same<OpTy, AtenGeTensorOp>() ||
std::is_same<OpTy, AtenEqTensorOp>(), std::is_same<OpTy, AtenEqTensorOp>() ||
std::is_same<OpTy, AtenNeTensorOp>(),
"unimplemented: op type not supported"); "unimplemented: op type not supported");
Type lhsDtype = lhs.getType(); Type lhsDtype = lhs.getType();
@ -172,6 +173,9 @@ static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) { if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
return createEqual(b, loc, elementalType, lhs, rhs); return createEqual(b, loc, elementalType, lhs, rhs);
} }
if constexpr (std::is_same<OpTy, AtenNeTensorOp>()) {
return createNotEqual(b, loc, elementalType, lhs, rhs);
}
llvm_unreachable("unimplemented: op type not supported"); llvm_unreachable("unimplemented: op type not supported");
} }
@ -595,6 +599,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0], return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
payloadArgs[1]); payloadArgs[1]);
} }
if (auto neTensor = dyn_cast<AtenNeTensorOp>(op)) {
return createCompareTensorOp(b, loc, neTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto div = dyn_cast<AtenDivTensorOp>(op)) { if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
AtenDivTensorOp::Adaptor adaptor(operands); AtenDivTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(div.getType()) Type dtype = converter->convertType(div.getType())
@ -1156,7 +1164,7 @@ public:
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
@ -1689,7 +1697,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,

View File

@ -6341,6 +6341,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.isnan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.ne.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.eq.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.eq.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -8627,6 +8635,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int11 = torch.constant.int 11\n" " %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n" " return %int11 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.isnan\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n" " %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n" " return %int11 : !torch.int\n"

View File

@ -351,6 +351,20 @@ public:
}; };
} // namespace } // namespace
namespace {
class DecomposeAtenIsnanOp : public OpRewritePattern<AtenIsnanOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIsnanOp op,
PatternRewriter &rewriter) const override {
Value input = op.getSelf();
// Create a new aten.ne operation with the same type and input value.
rewriter.replaceOpWithNewOp<AtenNeTensorOp>(op, op.getType(), input, input);
return success();
}
};
} // namespace
namespace { namespace {
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> { class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
public: public:
@ -4572,6 +4586,7 @@ public:
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns); DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);

View File

@ -422,6 +422,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenBernoulliPOp>(); target.addIllegalOp<AtenBernoulliPOp>();
target.addIllegalOp<AtenBernoulliTensorOp>(); target.addIllegalOp<AtenBernoulliTensorOp>();
target.addIllegalOp<AtenZeroOp>(); target.addIllegalOp<AtenZeroOp>();
target.addIllegalOp<AtenIsnanOp>();
target.addIllegalOp<AtenRandLikeOp>(); target.addIllegalOp<AtenRandLikeOp>();
target.addIllegalOp<AtenHardsigmoidOp>(); target.addIllegalOp<AtenHardsigmoidOp>();
target.addIllegalOp<AtenRelu6Op>(); target.addIllegalOp<AtenRelu6Op>();

View File

@ -218,6 +218,12 @@ def atenlift_fresh_copy〡shape(self: List[int]) -> List[int]:
def aten_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]: def aten_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
return upstream_shape_functions.unary(grad_output) return upstream_shape_functions.unary(grad_output)
def atenisnan〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenneTensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def ateneqScalar〡shape(self: List[int], other: float) -> List[int]: def ateneqScalar〡shape(self: List[int], other: float) -> List[int]:
return upstream_shape_functions.unary(self) return upstream_shape_functions.unary(self)
@ -1976,6 +1982,14 @@ def atenltTensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
def atenleTensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: def atenleTensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
return torch.bool return torch.bool
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenisnan〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return torch.bool
@check_dtype_function(_check_two_tensor_op())
def atenneTensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
return torch.bool
@check_dtype_function( @check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))

View File

@ -466,6 +466,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)") emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)") emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)")
emit("aten::isnan : (Tensor) -> (Tensor)")
emit("aten::all : (Tensor) -> (Tensor)") emit("aten::all : (Tensor) -> (Tensor)")
emit("aten::all.bool : (bool[]) -> (bool)") emit("aten::all.bool : (bool[]) -> (bool)")
emit("aten::any : (Tensor) -> (Tensor)") emit("aten::any : (Tensor) -> (Tensor)")

View File

@ -553,7 +553,7 @@ class ElementwiseNeFloatScalarModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule()) @register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule())
def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils): def ElementwiseNeFloatScalarModule_basic(module, tu: TestUtils):
module.forward( module.forward(
torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]]).to(torch.float32)) torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]]).to(torch.float32))
@ -578,6 +578,90 @@ def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ElementwiseNeFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, x, y):
return torch.ne(x, y)
@register_test_case(module_factory=lambda: ElementwiseNeFloatTensorModule())
def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32))
# ==============================================================================
class ElementwiseNeIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, y):
return torch.ne(x, y)
@register_test_case(module_factory=lambda: ElementwiseNeIntTensorModule())
def ElementwiseNeIntTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
# ==============================================================================
class ElementwiseNeFloatTensorStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3], torch.float32, True),
([2, 3], torch.float32, True),
])
def forward(self, x, y):
return torch.ne(x, y)
@register_test_case(module_factory=lambda: ElementwiseNeFloatTensorStaticModule())
def ElementwiseNeFloatTensorStaticModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32))
# ==============================================================================
class ElementwiseNeIntTensorStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([8, 5], torch.int64, True),
([5], torch.int64, True),
])
def forward(self, x, y):
return torch.ne(x, y)
@register_test_case(module_factory=lambda: ElementwiseNeIntTensorStaticModule())
def ElementwiseNeIntTensorStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
# ==============================================================================
class AnyBoolTrueModule(torch.nn.Module): class AnyBoolTrueModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -651,3 +735,23 @@ class AllBoolFalseModule(torch.nn.Module):
@register_test_case(module_factory=lambda: AllBoolFalseModule()) @register_test_case(module_factory=lambda: AllBoolFalseModule())
def AllBoolFalseModule_basic(module, tu: TestUtils): def AllBoolFalseModule_basic(module, tu: TestUtils):
module.forward() module.forward()
# ==============================================================================
class ElementwiseIsnanModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.isnan(x)
@register_test_case(module_factory=lambda: ElementwiseIsnanModule())
def ElementwiseIsnanModule_basic(module, tu: TestUtils):
x = torch.full((1, 1, 32), torch.nan)
module.forward(x)