diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 072fdd7df..7c804bcd0 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8024,6 +8024,29 @@ def Torch_AtenIsnanOp : Torch_Op<"aten.isnan", [
   }];
 }
 
+def Torch_AtenIsinfOp : Torch_Op<"aten.isinf", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::isinf : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIsinfOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIsinfOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenAllOp : Torch_Op<"aten.all", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 9c862e410..7f6f55e87 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -89,7 +89,7 @@ static Value createLessThanOrEqual(OpBuilder &b, Location loc,
 
 static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
                          Value lhs, Value rhs) {
-  return createComparisonTemplate<arith::CmpFPredicate::UEQ,
+  return createComparisonTemplate<arith::CmpFPredicate::OEQ,
                                   arith::CmpIPredicate::eq,
                                   arith::CmpIPredicate::eq>(
       b, loc, elementalType, lhs, rhs);
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index f6bdff8b4..4e38970bb 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6507,6 +6507,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.isinf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.ne.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -9186,6 +9190,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %int11 = torch.constant.int 11\n"
 "    return %int11 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.isinf\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.ne.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    return %int11 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 1a61cf237..2e0020b05 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -530,6 +530,26 @@ public:
 };
 } // namespace
 
+namespace {
+class DecomposeAtenIsinfOp : public OpRewritePattern<AtenIsinfOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenIsinfOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+
+    mlir::FloatType f64Type = rewriter.getF64Type();
+    Value inf = rewriter.create<ConstantFloatOp>(
+        loc, rewriter.getFloatAttr(
+                 f64Type, APFloat::getInf(f64Type.getFloatSemantics())));
+    Value abs = rewriter.create<AtenAbsOp>(loc, self.getType(), self);
+    rewriter.replaceOpWithNewOp<AtenEqScalarOp>(op, op.getType(), abs, inf);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
 public:
@@ -5458,6 +5478,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 4b823e517..627a33a8e 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -426,6 +426,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenIsinfOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 0b7029fbb..624831b8d 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1066,6 +1066,7 @@ TOSA_PASS_SET = {
     "ElementwiseCeilModule_basic",
     "ElementwiseReciprocalModule_basic",
     "ElementwiseIsnanModule_basic",
+    "ElementwiseIsinfModule_basic",
     "TypePromotionAlphaWiderModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise",
@@ -1460,4 +1461,5 @@ LTC_XFAIL_SET = {
    "ElementwiseBitwiseAndScalarInt64Module_basic",
    "ElementwiseBitwiseAndScalarInt32Module_basic",
    "ElementwiseBitwiseAndScalarInt8Module_basic",
+   "ElementwiseIsinfModule_basic",
 }
diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index aa931c2de..41c1f4da5 100644
--- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -243,6 +243,9 @@ def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: Li
 def aten〇isnan〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇isinf〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇ne〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
@@ -2236,6 +2239,10 @@ def aten〇le〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
 def aten〇isnan〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     return torch.bool
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇isinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    return torch.bool
+
 @check_dtype_function(_check_two_tensor_op())
 def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
     return torch.bool
diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 55e124d92..d4037ba7e 100644
--- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -546,6 +546,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)")
     emit("aten::isnan : (Tensor) -> (Tensor)")
+    emit("aten::isinf : (Tensor) -> (Tensor)")
     emit("aten::all : (Tensor) -> (Tensor)")
     emit("aten::all.bool : (bool[]) -> (bool)")
     emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py
index 67b168e7b..57a549309 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py
@@ -454,7 +454,7 @@ class ElementwiseEqFloatScalarModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
 def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
     module.forward(
-        torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32))
+        torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32))
 
 # ==============================================================================
 
@@ -534,7 +534,7 @@ class ElementwiseEqFloatTensorModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ElementwiseEqFloatTensorModule())
 def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
     module.forward(
-        torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
+        torch.tensor([[1.0, 2.2, 6.0], [torch.nan, 2.0, 3.1]]).to(torch.float32),
         torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
 
 # ==============================================================================
@@ -575,7 +575,7 @@ class ElementwiseNeFloatScalarModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule())
 def ElementwiseNeFloatScalarModule_basic(module, tu: TestUtils):
     module.forward(
-        torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]]).to(torch.float32))
+        torch.tensor([[1.0, 2.2, 2.0], [torch.nan, 2.0, 3.1]]).to(torch.float32))
 
 # ==============================================================================
 
@@ -765,7 +765,7 @@ class ElementwiseIsnanModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-1, -1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.isnan(x)
@@ -773,5 +773,25 @@ class ElementwiseIsnanModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: ElementwiseIsnanModule())
 def ElementwiseIsnanModule_basic(module, tu: TestUtils):
-    x = torch.full((1, 1, 32), torch.nan)
+    x = torch.tensor([1.0, torch.nan, torch.inf, -torch.inf])
     module.forward(x)
+
+# ==============================================================================
+
+class ElementwiseIsinfModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.isinf(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseIsinfModule())
+def ElementwiseIsinfModule_basic(module, tu: TestUtils):
+    x = torch.tensor([1.0, torch.nan, torch.inf, -torch.inf])
+    module.forward(x)
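
Note: the decomposition added above rewrites `aten.isinf` as `aten.abs` followed by `aten.eq.Scalar` against +inf. That only works if equality is an ordered float comparison, which is why `createEqual` moves from an unordered to an ordered predicate; with unordered equality, NaN would compare equal to infinity and `isinf(nan)` would wrongly return true. The `torch.nan` values added to the eq/ne test inputs exercise exactly this. A minimal eager-mode sketch of the same identity (illustrative only, not part of the patch; `isinf_decomposed` is a hypothetical helper name):

    import torch

    def isinf_decomposed(x: torch.Tensor) -> torch.Tensor:
        # Mirrors DecomposeAtenIsinfOp: isinf(x) <=> (abs(x) == +inf).
        # == is an ordered comparison, so NaN == inf is False and NaN is
        # correctly reported as not infinite.
        return torch.abs(x) == torch.inf

    x = torch.tensor([1.0, torch.nan, torch.inf, -torch.inf])
    assert torch.equal(isinf_decomposed(x), torch.isinf(x))
    print(isinf_decomposed(x))  # tensor([False, False,  True,  True])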