diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 77fcaa9b2..638dd5661 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6883,6 +6883,30 @@ def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [
   }];
 }
 
+def Torch_AtenCeilFloatOp : Torch_Op<"aten.ceil.float", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::ceil.float : (float) -> (int)`";
+  let arguments = (ins
+    Torch_FloatType:$a
+  );
+  let results = (outs
+    Torch_IntType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCeilFloatOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenCeilFloatOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp
index 718c44c7a..e079d83e1 100644
--- a/lib/Conversion/TorchToStd/TorchToStd.cpp
+++ b/lib/Conversion/TorchToStd/TorchToStd.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -77,6 +78,25 @@ public:
 };
 } // namespace
 
+namespace {
+template <typename AtenOp, typename UnaryOp>
+class ConvertAtenUnaryOp : public OpConversionPattern<AtenOp> {
+public:
+  using OpConversionPattern<AtenOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenOp op,
+                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType =
+        this->getTypeConverter()->convertType(op->getResult(0).getType());
+    Value result = rewriter.create<UnaryOp>(op.getLoc(), adaptor.a());
+    rewriter.replaceOp(
+        op, convertScalarToDtype(rewriter, op.getLoc(), result, resultType));
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Lowers aten integer comparison ops.
 template <typename AtenOp, arith::CmpIPredicate Pred>
@@ -182,6 +202,7 @@ public:
     registry.insert<arith::ArithmeticDialect>();
     registry.insert<tensor::TensorDialect>();
     registry.insert<cf::ControlFlowDialect>();
+    registry.insert<math::MathDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }
 
@@ -190,7 +211,7 @@ public:
     ConversionTarget target(*context);
     target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
                            arith::ArithmeticDialect, tensor::TensorDialect,
-                           cf::ControlFlowDialect>();
+                           cf::ControlFlowDialect, math::MathDialect>();
 
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
@@ -246,6 +267,9 @@ public:
     target.addIllegalOp<AtenDivFloatOp>();
     patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
         typeConverter, context);
+    target.addIllegalOp<AtenCeilFloatOp>();
+    patterns.add<ConvertAtenUnaryOp<AtenCeilFloatOp, math::CeilOp>>(
+        typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 4438247d1..c99bf720d 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1545,6 +1545,16 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenCeilFloatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
+  double c;
+  if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
+    return getI64IntegerAttr(getContext(), std::ceil(c));
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 //===----------------------------------------------------------------------===//
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 63677c1e3..bd75a7451 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -509,6 +509,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::_set_item.t : (t[], int, t) -> (t[])")
     emit("aten::div : (Scalar, Scalar) -> (float)")
     emit("aten::eq.device : (Device, Device) -> (bool)")
+    emit("aten::ceil.float : (float) -> (int)", has_folder=True)
 
     # backprop ops
     emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/scalar.py b/python/torch_mlir_e2e_test/test_suite/scalar.py
index 38a9d7fa8..684ff2e74 100644
--- a/python/torch_mlir_e2e_test/test_suite/scalar.py
+++ b/python/torch_mlir_e2e_test/test_suite/scalar.py
@@ -11,7 +11,9 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 # ==============================================================================
 
+
 class AddIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -22,16 +24,19 @@ class AddIntModule(torch.nn.Module):
         ([], torch.int64, True),
     ])
     def forward(self, lhs, rhs):
-        return int(lhs)+int(rhs)
+        return int(lhs) + int(rhs)
 
 
 @register_test_case(module_factory=lambda: AddIntModule())
 def AddIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
+    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
+
 
 # ==============================================================================
 
+
 class SubIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -42,16 +47,19 @@ class SubIntModule(torch.nn.Module):
         ([], torch.int64, True),
     ])
     def forward(self, lhs, rhs):
-        return int(lhs)-int(rhs)
+        return int(lhs) - int(rhs)
 
 
 @register_test_case(module_factory=lambda: SubIntModule())
 def SubIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
+    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
+
 
 # ==============================================================================
 
+
 class SubFloatModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -62,16 +70,19 @@ class SubFloatModule(torch.nn.Module):
         ([], torch.float64, True),
     ])
     def forward(self, lhs, rhs):
-        return float(lhs)-float(rhs)
+        return float(lhs) - float(rhs)
 
 
 @register_test_case(module_factory=lambda: SubFloatModule())
 def SubFloatModule_basic(module, tu: TestUtils):
     module.forward(torch.rand(()).double(), torch.rand(()).double())
+
 
 # ==============================================================================
 
+
 class MulIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -82,16 +93,19 @@ class MulIntModule(torch.nn.Module):
         ([], torch.int64, True),
     ])
     def forward(self, lhs, rhs):
-        return int(lhs)*int(rhs)
+        return int(lhs) * int(rhs)
 
 
 @register_test_case(module_factory=lambda: MulIntModule())
 def MulIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
+    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
+
 
 # ==============================================================================
 
+
 class DivFloatModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -102,9 +116,33 @@ class DivFloatModule(torch.nn.Module):
         ([], torch.float64, True),
     ])
     def forward(self, lhs, rhs):
-        return float(lhs)/float(rhs)
+        return float(lhs) / float(rhs)
 
 
 @register_test_case(module_factory=lambda: DivFloatModule())
 def DivFloatModule_basic(module, tu: TestUtils):
     module.forward(torch.rand(()).double(), torch.rand(()).double())
+
+
+# ==============================================================================
+
+
+class CeilFloatModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.float64, True),
+        ([], torch.float64, True),
+    ])
+    def forward(self, lhs, rhs):
+        sub = float(lhs) - float(rhs)
+        return torch.ops.aten.ceil(float(sub))
+
+
+@register_test_case(module_factory=lambda: CeilFloatModule())
+def CeilFloatModule_basic(module, tu: TestUtils):
+    module.forward(torch.rand(()).double(), torch.rand(()).double())
diff --git a/test/Conversion/TorchToStd/basic.mlir b/test/Conversion/TorchToStd/basic.mlir
index 5b3ca1792..a5a07151e 100644
--- a/test/Conversion/TorchToStd/basic.mlir
+++ b/test/Conversion/TorchToStd/basic.mlir
@@ -207,3 +207,15 @@ func @torch.aten.ne.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
   %0 = torch.aten.ne.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
   return %0 : !torch.bool
 }
+
+// CHECK-LABEL: func @torch.aten.ceil.float(
+// CHECK-SAME:    %[[ARG:.*]]: !torch.float) -> !torch.int {
+// CHECK:         %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
+// CHECK:         %[[CEIL:.*]] = math.ceil %[[ARG_F64]] : f64
+// CHECK:         %[[CEIL_I64:.*]] = arith.fptosi %[[CEIL]] : f64 to i64
+// CHECK:         %[[OUT:.*]] = torch_c.from_i64 %[[CEIL_I64]]
+// CHECK:         return %[[OUT]] : !torch.int
+func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int {
+  %0 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
+  return %0 : !torch.int
+}
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index afddc97c3..303789a83 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1191,3 +1191,21 @@ func @torch.aten.ge.float$different_value() -> !torch.bool {
   %2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool
   return %2 : !torch.bool
 }
+
+// CHECK-LABEL: func @torch.aten.ceil.float$fold_cst() -> !torch.int {
+// CHECK:         %[[CST2:.*]] = torch.constant.int 2
+// CHECK:         return %[[CST2]] : !torch.int
+func @torch.aten.ceil.float$fold_cst() -> !torch.int {
+  %float = torch.constant.float 1.5
+  %1 = torch.aten.ceil.float %float : !torch.float -> !torch.int
+  return %1 : !torch.int
+}
+
+// CHECK-LABEL: func @torch.aten.ceil.float$no_fold(
+// CHECK-SAME:    %[[ARG:.*]]: !torch.float) -> !torch.int {
+// CHECK:         %[[RESULT:.*]] = torch.aten.ceil.float %[[ARG]] : !torch.float -> !torch.int
+// CHECK:         return %[[RESULT]] : !torch.int
+func @torch.aten.ceil.float$no_fold(%arg0 : !torch.float) -> !torch.int {
+  %1 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
+  return %1 : !torch.int
+}