From 44c7b181d34dcb1d832fa0d742791fd7fee4be64 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Thu, 28 Apr 2022 13:36:56 +0000 Subject: [PATCH] Revert "[MLIR][TORCH] Add E2E support for aten.ge.float op" This reverts commit 564734b2d7573e66b6b26ac81d35e0df01952102. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 ---------------- lib/Conversion/TorchToStd/TorchToStd.cpp | 21 -------------- lib/Dialect/Torch/IR/TorchOps.cpp | 9 ------ .../jit_ir/build_tools/torch_ods_gen.py | 1 - .../test_suite/scalar_comparison.py | 19 ------------ test/Conversion/TorchToStd/basic.mlir | 13 --------- test/Dialect/Torch/canonicalize.mlir | 29 ------------------- 7 files changed, 117 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index dcf49dfed..37fc651c9 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6569,31 +6569,6 @@ def Torch_AtenGtFloatOp : Torch_Op<"aten.gt.float", [ let hasFolder = 1; } -def Torch_AtenGeFloatOp : Torch_Op<"aten.ge.float", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::ge.float : (float, float) -> (bool)`"; - let arguments = (ins - Torch_FloatType:$a, - Torch_FloatType:$b - ); - let results = (outs - Torch_BoolType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenGeFloatOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenGeFloatOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; - let hasFolder = 1; -} - def Torch_AtenLtFloatOp : Torch_Op<"aten.lt.float", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp index b4e89410f..2aef720f9 100644 --- a/lib/Conversion/TorchToStd/TorchToStd.cpp +++ b/lib/Conversion/TorchToStd/TorchToStd.cpp @@ -93,23 +93,6 @@ public: }; } // namespace -namespace { -// Lowers aten float comparison ops. -template -class ConvertAtenFloatComparisonOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenOp op, - typename OpConversionPattern::OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, Pred, adaptor.a(), - adaptor.b()); - return success(); - } -}; -} // namespace - // Tensors with integer types need to be converted to signless integer // element type. All tensors with element types other than integer can reuse // existing elements attribute. @@ -209,10 +192,6 @@ public: patterns.add< ConvertAtenIntComparisonOp>( typeConverter, context); - target.addIllegalOp(); - patterns.add< - ConvertAtenFloatComparisonOp>( - typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 4438247d1..21201e107 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -796,15 +796,6 @@ OpFoldResult AtenGtFloatOp::fold(ArrayRef operands) { [](double a, double b) { return a > b; }); } -//===----------------------------------------------------------------------===// -// AtenGeFloatOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenGeFloatOp::fold(ArrayRef operands) { - return floatComparatorFoldHelper(*this, - [](double a, double b) { return a >= b; }); -} - //===----------------------------------------------------------------------===// // AtenEqFloatOp //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 79ecad386..d3147fd20 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -494,7 +494,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::neg.float : (float) -> (float)") emit("aten::eq.float : (float, float) -> (bool)", has_folder=True) emit("aten::gt.float : (float, float) -> (bool)", has_folder=True) - emit("aten::ge.float : (float, float) -> (bool)", has_folder=True) emit("aten::lt.float : (float, float) -> (bool)", has_folder=True) emit("aten::lt.float_int : (float, int) -> (bool)") emit("aten::__and__.bool : (bool, bool) -> (bool)") diff --git a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py index f23c375b9..5bfeaf788 100644 --- a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py @@ -69,22 +69,3 @@ class GtIntModule(torch.nn.Module): def GtIntModule_basic(module, tu: TestUtils): module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) -# ============================================================================== - -class GeFloatModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([], torch.float64, True), - ([], torch.float64, True), - ]) - def forward(self, lhs, rhs): - return float(lhs) >= float(rhs) - - -@register_test_case(module_factory=lambda: GeFloatModule()) -def GeFloatModule_basic(module, tu: TestUtils): - module.forward(torch.randn(()), torch.randn(())) diff --git a/test/Conversion/TorchToStd/basic.mlir b/test/Conversion/TorchToStd/basic.mlir index 1f9dc2dac..09441d61b 100644 --- a/test/Conversion/TorchToStd/basic.mlir +++ b/test/Conversion/TorchToStd/basic.mlir @@ -166,16 +166,3 @@ func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.f %0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float return %0 : !torch.float } - -// CHECK-LABEL: func @torch.aten.ge.float( -// CHECK-SAME: %[[LHS:.*]]: !torch.float, -// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.bool { -// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] -// CHECK: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]] -// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64 -// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] -// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool -func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bool { - %0 = torch.aten.ge.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.bool - return %0 : !torch.bool -} diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index afddc97c3..263603d68 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1162,32 +1162,3 @@ func @torch.aten.div.float$fold_cst_operands() -> !torch.float { %0 = torch.aten.div.float %float4, %float2 : !torch.float, !torch.float -> !torch.float return %0 : !torch.float } - -// CHECK-LABEL: func @torch.aten.ge.float$same_operand( -// CHECK-SAME: %{{.*}}: !torch.float) -> !torch.bool { -// CHECK: %[[TRUE:.*]] = torch.constant.bool true -// CHECK: return %[[TRUE]] : !torch.bool -func @torch.aten.ge.float$same_operand(%arg0: !torch.float) -> !torch.bool { - %2 = torch.aten.ge.float %arg0, %arg0: !torch.float, !torch.float -> !torch.bool - return %2 : !torch.bool -} - -// CHECK-LABEL: func @torch.aten.ge.float$same_value() -> !torch.bool { -// CHECK: %[[TRUE:.*]] = torch.constant.bool true -// CHECK: return %[[TRUE]] : !torch.bool -func @torch.aten.ge.float$same_value() -> !torch.bool { - %float4 = torch.constant.float 4.0 - %float4_0 = torch.constant.float 4.0 - %2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool - return %2 : !torch.bool -} - -// CHECK-LABEL: func @torch.aten.ge.float$different_value() -> !torch.bool { -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: return %[[FALSE]] : !torch.bool -func @torch.aten.ge.float$different_value() -> !torch.bool { - %float4 = torch.constant.float 4.0 - %float4_0 = torch.constant.float 5.0 - %2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool - return %2 : !torch.bool -}