From 708a51ae2e70d3d4756b89200e26f42a6c753177 Mon Sep 17 00:00:00 2001
From: Albert Sandru
Date: Fri, 3 Jun 2022 15:56:08 +0000
Subject: [PATCH] Add E2E support for aten.is_floating_point

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 23 ++++++++++
 lib/Conversion/TorchToStd/TorchToStd.cpp      | 20 +++++++++
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 .../torch_mlir_e2e_test/test_suite/basic.py   | 44 +++++++++++++++++++
 4 files changed, 88 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 91115490a..c18a45dce 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4149,6 +4149,29 @@ def Torch_AtenBoolTensorOp : Torch_Op<"aten.Bool.Tensor", [
   }];
 }
 
+def Torch_AtenIsFloatingPointOp : Torch_Op<"aten.is_floating_point", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::is_floating_point : (Tensor) -> (bool)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    Torch_BoolType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIsFloatingPointOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIsFloatingPointOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp
index 60627a167..00b969f79 100644
--- a/lib/Conversion/TorchToStd/TorchToStd.cpp
+++ b/lib/Conversion/TorchToStd/TorchToStd.cpp
@@ -50,6 +50,24 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenIsFloatingPointOp
+    : public OpConversionPattern<AtenIsFloatingPointOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenIsFloatingPointOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto tensorType = op.self().getType().cast<BaseTensorType>();
+    bool result =
+        tensorType.hasDtype() && tensorType.getDtype().isa<mlir::FloatType>();
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        op, BoolAttr::get(getContext(), result));
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
 public:
@@ -301,6 +319,8 @@ public:
     RewritePatternSet patterns(context);
     target.addIllegalOp<AtenDimOp>();
     patterns.add<ConvertAtenDimOp>(typeConverter, context);
+    target.addIllegalOp<AtenIsFloatingPointOp>();
+    patterns.add<ConvertAtenIsFloatingPointOp>(typeConverter, context);
     target.addIllegalOp<RuntimeAssertOp>();
     patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
     target.addIllegalOp<AtenNeIntOp>();
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 42a266b82..4ca88a3c9 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -392,6 +392,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::dim : (Tensor) -> (int)", has_folder=True)
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
     emit("aten::Bool.Tensor : (Tensor) -> (bool)")
+    emit("aten::is_floating_point : (Tensor) -> (bool)")
     emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index df8991ba6..b186b9b08 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -64,6 +64,50 @@ def BmmModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class IsFloatingPointInt(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, x):
+        return torch.is_floating_point(x)
+
+
+@register_test_case(module_factory=lambda: IsFloatingPointInt())
+def IsFloatingPointInt_False(module, tu: TestUtils):
+    module.forward(torch.randint(100, (3, 3)))
+
+
+# ==============================================================================
+
+
+class IsFloatingPointFloat(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.is_floating_point(x)
+
+
+@register_test_case(module_factory=lambda: IsFloatingPointFloat())
+def IsFloatingPointFloat_True(module, tu: TestUtils):
+    module.forward(tu.rand(3))
+
+
+# ==============================================================================
+
+
 # A subgraph with multiple mm ops.
 class MmDagModule(torch.nn.Module):