diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 7e718355c..844ddca29 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6934,30 +6934,6 @@ def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [
   }];
 }
 
-def Torch_AtenCeilFloatOp : Torch_Op<"aten.ceil.float", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::ceil.float : (float) -> (int)`";
-  let arguments = (ins
-    Torch_FloatType:$a
-  );
-  let results = (outs
-    Torch_IntType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenCeilFloatOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 1, 1);
-    }
-    void AtenCeilFloatOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 1, 1);
-    }
-  }];
-  let hasFolder = 1;
-}
-
 def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp
index e079d83e1..718c44c7a 100644
--- a/lib/Conversion/TorchToStd/TorchToStd.cpp
+++ b/lib/Conversion/TorchToStd/TorchToStd.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -78,25 +77,6 @@ public:
 };
 } // namespace
 
-namespace {
-template <typename AtenOp, typename UnaryOp>
-class ConvertAtenUnaryOp : public OpConversionPattern<AtenOp> {
-public:
-  using OpConversionPattern<AtenOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenOp op,
-                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Type resultType =
-        this->getTypeConverter()->convertType(op->getResult(0).getType());
-    Value result = rewriter.create<UnaryOp>(op.getLoc(), adaptor.a());
-    rewriter.replaceOp(
-        op, convertScalarToDtype(rewriter, op.getLoc(), result, resultType));
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 // Lowers aten integer comparison ops.
 template <typename AtenOp, arith::CmpIPredicate Pred>
@@ -202,7 +182,6 @@ public:
     registry.insert<arith::ArithmeticDialect>();
     registry.insert<tensor::TensorDialect>();
     registry.insert<cf::ControlFlowDialect>();
-    registry.insert<math::MathDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }
 
@@ -211,7 +190,7 @@ public:
     ConversionTarget target(*context);
     target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
                            arith::ArithmeticDialect, tensor::TensorDialect,
-                           cf::ControlFlowDialect, math::MathDialect>();
+                           cf::ControlFlowDialect>();
 
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
@@ -267,9 +246,6 @@ public:
     target.addIllegalOp<AtenDivFloatOp>();
     patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
         typeConverter, context);
-    target.addIllegalOp<AtenCeilFloatOp>();
-    patterns.add<ConvertAtenUnaryOp<AtenCeilFloatOp, math::CeilOp>>(
-        typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index c99bf720d..4438247d1 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1545,16 +1545,6 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
-// AtenCeilFloatOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
-  double c;
-  if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
-    return getI64IntegerAttr(getContext(), std::ceil(c));
-  return nullptr;
-}
-
 //===----------------------------------------------------------------------===//
 //===----------------------------------------------------------------------===//
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index f224316d3..e0d498d85 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -511,7 +511,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::_set_item.t : (t[], int, t) -> (t[])")
     emit("aten::div : (Scalar, Scalar) -> (float)")
     emit("aten::eq.device : (Device, Device) -> (bool)")
-    emit("aten::ceil.float : (float) -> (int)", has_folder=True)
 
     # backprop ops
     emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/scalar.py b/python/torch_mlir_e2e_test/test_suite/scalar.py
index 684ff2e74..38a9d7fa8 100644
--- a/python/torch_mlir_e2e_test/test_suite/scalar.py
+++ b/python/torch_mlir_e2e_test/test_suite/scalar.py
@@ -11,9 +11,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 # ==============================================================================
 
-
 class AddIntModule(torch.nn.Module):
-
     def __init__(self):
         super().__init__()
 
@@ -24,19 +22,16 @@ class AddIntModule(torch.nn.Module):
         ([], torch.int64, True),
     ])
     def forward(self, lhs, rhs):
-        return int(lhs) + int(rhs)
+        return int(lhs)+int(rhs)
 
 
 @register_test_case(module_factory=lambda: AddIntModule())
 def AddIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
-
+    module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
 
 # ==============================================================================
 
-
 class SubIntModule(torch.nn.Module):
-
     def __init__(self):
         super().__init__()
 
@@ -47,19 +42,16 @@ class SubIntModule(torch.nn.Module):
         ([], torch.int64, True),
     ])
     def forward(self, lhs, rhs):
-        return int(lhs) - int(rhs)
+        return int(lhs)-int(rhs)
 
 
 @register_test_case(module_factory=lambda: SubIntModule())
 def SubIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
-
+    module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
 
 # ==============================================================================
 
-
 class SubFloatModule(torch.nn.Module):
-
     def __init__(self):
         super().__init__()
 
@@ -70,19 +62,16 @@ class SubFloatModule(torch.nn.Module):
         ([], torch.float64, True),
     ])
     def forward(self, lhs, rhs):
-        return float(lhs) - float(rhs)
+        return float(lhs)-float(rhs)
 
 
 @register_test_case(module_factory=lambda: SubFloatModule())
 def SubFloatModule_basic(module, tu: TestUtils):
     module.forward(torch.rand(()).double(), torch.rand(()).double())
 
-
 # ==============================================================================
 
-
 class MulIntModule(torch.nn.Module):
-
     def __init__(self):
         super().__init__()
 
@@ -93,19 +82,16 @@ class MulIntModule(torch.nn.Module):
         ([], torch.int64, True),
     ])
     def forward(self, lhs, rhs):
-        return int(lhs) * int(rhs)
+        return int(lhs)*int(rhs)
 
 
 @register_test_case(module_factory=lambda: MulIntModule())
 def MulIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
-
+    module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
 
 # ==============================================================================
 
-
 class DivFloatModule(torch.nn.Module):
-
     def __init__(self):
         super().__init__()
 
@@ -116,33 +102,9 @@ class DivFloatModule(torch.nn.Module):
         ([], torch.float64, True),
     ])
     def forward(self, lhs, rhs):
-        return float(lhs) / float(rhs)
+        return float(lhs)/float(rhs)
 
 
 @register_test_case(module_factory=lambda: DivFloatModule())
 def DivFloatModule_basic(module, tu: TestUtils):
     module.forward(torch.rand(()).double(), torch.rand(()).double())
-
-
-# ==============================================================================
-
-
-class CeilFloatModule(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([], torch.float64, True),
-        ([], torch.float64, True),
-    ])
-    def forward(self, lhs, rhs):
-        sub = float(lhs) - float(rhs)
-        return torch.ops.aten.ceil(float(sub))
-
-
-@register_test_case(module_factory=lambda: CeilFloatModule())
-def CeilFloatModule_basic(module, tu: TestUtils):
-    module.forward(torch.rand(()).double(), torch.rand(()).double())
diff --git a/test/Conversion/TorchToStd/basic.mlir b/test/Conversion/TorchToStd/basic.mlir
index a5a07151e..5b3ca1792 100644
--- a/test/Conversion/TorchToStd/basic.mlir
+++ b/test/Conversion/TorchToStd/basic.mlir
@@ -207,15 +207,3 @@ func @torch.aten.ne.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.
   %0 = torch.aten.ne.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
   return %0 : !torch.bool
 }
-
-// CHECK-LABEL: func @torch.aten.ceil.float(
-// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int {
-// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
-// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG_F64]] : f64
-// CHECK: %[[CEIL_I64:.*]] = arith.fptosi %[[CEIL]] : f64 to i64
-// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[CEIL_I64]]
-// CHECK: return %[[OUT]] : !torch.int
-func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int {
-  %0 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
-  return %0 : !torch.int
-}
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 303789a83..afddc97c3 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1191,21 +1191,3 @@ func @torch.aten.ge.float$different_value() -> !torch.bool {
   %2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool
   return %2 : !torch.bool
 }
-
-// CHECK-LABEL: func @torch.aten.ceil.float$fold_cst() -> !torch.int {
-// CHECK: %[[CST2:.*]] = torch.constant.int 2
-// CHECK: return %[[CST2]] : !torch.int
-func @torch.aten.ceil.float$fold_cst() -> !torch.int {
-  %float = torch.constant.float 1.5
-  %1 = torch.aten.ceil.float %float : !torch.float -> !torch.int
-  return %1 : !torch.int
-}
-
-// CHECK-LABEL: func @torch.aten.ceil.float$no_fold(
-// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int {
-// CHECK: %[[RESULT:.*]] = torch.aten.ceil.float %[[ARG]] : !torch.float -> !torch.int
-// CHECK: return %[[RESULT]] : !torch.int
-func @torch.aten.ceil.float$no_fold(%arg0 : !torch.float) -> !torch.int {
-  %1 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
-  return %1 : !torch.int
-}