From 96fabc0036fadc521b7d455eb35aaf11e14a93f2 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 25 Apr 2022 18:42:45 +0530 Subject: [PATCH] [MLIR][TORCH] E2E support for [ge|ceil].float, [ge|ne|gt].float_int op This commit adds lowering of `aten.ge.float`, `aten.ge.float_int`, `aten.ne.float_int`, `aten.gt.float_int` and `aten.ceil.float` op. This commit also fixes formatting for the file scalar.py and scalar_comparison.py. Signed-Off By: Vivek Khandelwal --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 121 ++++++++++++++++++ lib/Conversion/TorchToStd/TorchToStd.cpp | 59 ++++++++- lib/Dialect/Torch/IR/TorchOps.cpp | 19 +++ .../jit_ir/build_tools/torch_ods_gen.py | 5 + .../torch_mlir_e2e_test/test_suite/scalar.py | 54 ++++++-- .../test_suite/scalar_comparison.py | 99 ++++++++++++++ test/Conversion/TorchToStd/basic.mlir | 67 ++++++++++ test/Dialect/Torch/canonicalize.mlir | 47 +++++++ 8 files changed, 462 insertions(+), 9 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 212487a57..da80ad9a8 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6678,6 +6678,31 @@ def Torch_AtenGtFloatOp : Torch_Op<"aten.gt.float", [ let hasFolder = 1; } +def Torch_AtenGeFloatOp : Torch_Op<"aten.ge.float", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::ge.float : (float, float) -> (bool)`"; + let arguments = (ins + Torch_FloatType:$a, + Torch_FloatType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGeFloatOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenGeFloatOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenLtFloatOp : Torch_Op<"aten.lt.float", [ AllowsTypeRefinement, HasValueSemantics, @@ -6727,6 +6752,78 @@ def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [ }]; } +def Torch_AtenGeFloatIntOp : Torch_Op<"aten.ge.float_int", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::ge.float_int : (float, int) -> (bool)`"; + let arguments = (ins + Torch_FloatType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGeFloatIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenGeFloatIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenNeFloatIntOp : Torch_Op<"aten.ne.float_int", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::ne.float_int : (float, int) -> (bool)`"; + let arguments = (ins + Torch_FloatType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNeFloatIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenNeFloatIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenGtFloatIntOp : Torch_Op<"aten.gt.float_int", [ + 
AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::gt.float_int : (float, int) -> (bool)`"; + let arguments = (ins + Torch_FloatType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGtFloatIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenGtFloatIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [ AllowsTypeRefinement, HasValueSemantics, @@ -6970,6 +7067,30 @@ def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [ }]; } +def Torch_AtenCeilFloatOp : Torch_Op<"aten.ceil.float", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::ceil.float : (float) -> (int)`"; + let arguments = (ins + Torch_FloatType:$a + ); + let results = (outs + Torch_IntType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCeilFloatOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenCeilFloatOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasFolder = 1; +} + def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp index 2aef720f9..8455acf13 100644 --- a/lib/Conversion/TorchToStd/TorchToStd.cpp +++ b/lib/Conversion/TorchToStd/TorchToStd.cpp @@ -13,9 +13,11 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" @@ -76,6 +78,25 @@ public: }; } // namespace +namespace { +template +class ConvertAtenUnaryOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenOp op, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + this->getTypeConverter()->convertType(op->getResult(0).getType()); + Value result = rewriter.create(op.getLoc(), adaptor.a()); + rewriter.replaceOp( + op, convertScalarToDtype(rewriter, op.getLoc(), result, resultType)); + return success(); + } +}; +} // namespace + namespace { // Lowers aten integer comparison ops. template @@ -93,6 +114,24 @@ public: }; } // namespace +namespace { +// Lowers aten float and float_int comparison ops. 
+template +class ConvertAtenFloatComparisonOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenOp op, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.a(), rhs = adaptor.b(); + rhs = convertScalarToDtype(rewriter, op.getLoc(), rhs, lhs.getType()); + rewriter.replaceOpWithNewOp(op, Pred, lhs, rhs); + return success(); + } +}; +} // namespace + // Tensors with integer types need to be converted to signless integer // element type. All tensors with element types other than integer can reuse // existing elements attribute. @@ -163,6 +202,7 @@ public: registry.insert(); registry.insert(); registry.insert(); + registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); } @@ -171,7 +211,7 @@ public: ConversionTarget target(*context); target.addLegalDialect(); + cf::ControlFlowDialect, math::MathDialect>(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); @@ -192,6 +232,20 @@ public: patterns.add< ConvertAtenIntComparisonOp>( typeConverter, context); + target.addIllegalOp(); + patterns.add< + ConvertAtenFloatComparisonOp>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); @@ -217,6 +271,9 @@ public: target.addIllegalOp(); patterns.add>( typeConverter, context); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 68aa4fe85..294871b18 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -857,6 +857,15 @@ OpFoldResult AtenGtFloatOp::fold(ArrayRef operands) { [](double a, double b) { return a > b; }); } +//===----------------------------------------------------------------------===// +// AtenGeFloatOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenGeFloatOp::fold(ArrayRef operands) { + return floatComparatorFoldHelper(*this, + [](double a, double b) { return a >= b; }); +} + //===----------------------------------------------------------------------===// // AtenEqFloatOp //===----------------------------------------------------------------------===// @@ -1604,6 +1613,16 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef operands) { return nullptr; } +// AtenCeilFloatOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenCeilFloatOp::fold(ArrayRef operands) { + double c; + if (matchPattern(getOperand(), m_TorchConstantFloat(&c))) + return getI64IntegerAttr(getContext(), std::ceil(c)); + return nullptr; +} + //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index b10992641..adb70fa97 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -501,8 +501,12 @@ def emit_ops(emitter_td: TextEmitter, 
registry: Registry): emit("aten::neg.float : (float) -> (float)") emit("aten::eq.float : (float, float) -> (bool)", has_folder=True) emit("aten::gt.float : (float, float) -> (bool)", has_folder=True) + emit("aten::ge.float : (float, float) -> (bool)", has_folder=True) emit("aten::lt.float : (float, float) -> (bool)", has_folder=True) emit("aten::lt.float_int : (float, int) -> (bool)") + emit("aten::ge.float_int : (float, int) -> (bool)") + emit("aten::ne.float_int : (float, int) -> (bool)") + emit("aten::gt.float_int : (float, int) -> (bool)") emit("aten::__and__.bool : (bool, bool) -> (bool)") emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) @@ -515,6 +519,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::_set_item.t : (t[], int, t) -> (t[])") emit("aten::div : (Scalar, Scalar) -> (float)") emit("aten::eq.device : (Device, Device) -> (bool)") + emit("aten::ceil.float : (float) -> (int)", has_folder=True) # backprop ops emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/scalar.py b/python/torch_mlir_e2e_test/test_suite/scalar.py index 38a9d7fa8..684ff2e74 100644 --- a/python/torch_mlir_e2e_test/test_suite/scalar.py +++ b/python/torch_mlir_e2e_test/test_suite/scalar.py @@ -11,7 +11,9 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export # ============================================================================== + class AddIntModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -22,16 +24,19 @@ class AddIntModule(torch.nn.Module): ([], torch.int64, True), ]) def forward(self, lhs, rhs): - return int(lhs)+int(rhs) + return int(lhs) + int(rhs) @register_test_case(module_factory=lambda: AddIntModule()) def AddIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,())) + module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + # ============================================================================== + class SubIntModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -42,16 +47,19 @@ class SubIntModule(torch.nn.Module): ([], torch.int64, True), ]) def forward(self, lhs, rhs): - return int(lhs)-int(rhs) + return int(lhs) - int(rhs) @register_test_case(module_factory=lambda: SubIntModule()) def SubIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,())) + module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + # ============================================================================== + class SubFloatModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -62,16 +70,19 @@ class SubFloatModule(torch.nn.Module): ([], torch.float64, True), ]) def forward(self, lhs, rhs): - return float(lhs)-float(rhs) + return float(lhs) - float(rhs) @register_test_case(module_factory=lambda: SubFloatModule()) def SubFloatModule_basic(module, tu: TestUtils): module.forward(torch.rand(()).double(), torch.rand(()).double()) + # ============================================================================== + class MulIntModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -82,16 +93,19 @@ class MulIntModule(torch.nn.Module): ([], torch.int64, True), ]) def forward(self, lhs, rhs): - return int(lhs)*int(rhs) + return int(lhs) * int(rhs) @register_test_case(module_factory=lambda: MulIntModule()) 
def MulIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,())) + module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + # ============================================================================== + class DivFloatModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -102,9 +116,33 @@ class DivFloatModule(torch.nn.Module): ([], torch.float64, True), ]) def forward(self, lhs, rhs): - return float(lhs)/float(rhs) + return float(lhs) / float(rhs) @register_test_case(module_factory=lambda: DivFloatModule()) def DivFloatModule_basic(module, tu: TestUtils): module.forward(torch.rand(()).double(), torch.rand(()).double()) + + +# ============================================================================== + + +class CeilFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ([], torch.float64, True), + ]) + def forward(self, lhs, rhs): + sub = float(lhs) - float(rhs) + return torch.ops.aten.ceil(float(sub)) + + +@register_test_case(module_factory=lambda: CeilFloatModule()) +def CeilFloatModule_basic(module, tu: TestUtils): + module.forward(torch.rand(()).double(), torch.rand(()).double()) diff --git a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py index 5bfeaf788..8a626d962 100644 --- a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py @@ -11,7 +11,9 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export # ============================================================================== + class NeIntModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -29,9 +31,12 @@ class NeIntModule(torch.nn.Module): def NeIntModule_basic(module, tu: TestUtils): module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + # ============================================================================== + class EqIntModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -49,9 +54,12 @@ class EqIntModule(torch.nn.Module): def EqIntModule_basic(module, tu: TestUtils): module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + # ============================================================================== + class GtIntModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -69,3 +77,94 @@ class GtIntModule(torch.nn.Module): def GtIntModule_basic(module, tu: TestUtils): module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + +# ============================================================================== + + +class GeFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ([], torch.float64, True), + ]) + def forward(self, lhs, rhs): + return float(lhs) >= float(rhs) + + +@register_test_case(module_factory=lambda: GeFloatModule()) +def GeFloatModule_basic(module, tu: TestUtils): + module.forward(torch.randn(()).double(), torch.randn(()).double()) + + +# ============================================================================== + + +class GeFloatIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ([], torch.int64, True), + ]) + def forward(self, lhs, rhs): + return float(lhs) >= 
int(rhs) + + +@register_test_case(module_factory=lambda: GeFloatIntModule()) +def GeFloatIntModule_basic(module, tu: TestUtils): + module.forward(torch.randn(()).double(), torch.randint(-100, 100, ())) + + +# ============================================================================== + + +class NeFloatIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ([], torch.int64, True), + ]) + def forward(self, lhs, rhs): + return float(lhs) != int(rhs) + + +@register_test_case(module_factory=lambda: NeFloatIntModule()) +def NeFloatIntModule_basic(module, tu: TestUtils): + module.forward(torch.randn(()).double(), torch.randint(-100, 100, ())) + + +# ============================================================================== + + +class GtFloatIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ([], torch.int64, True), + ]) + def forward(self, lhs, rhs): + return float(lhs) > int(rhs) + + +@register_test_case(module_factory=lambda: GtFloatIntModule()) +def GtFloatIntModule_basic(module, tu: TestUtils): + module.forward(torch.randn(()).double(), torch.randint(-100, 100, ())) diff --git a/test/Conversion/TorchToStd/basic.mlir b/test/Conversion/TorchToStd/basic.mlir index 09441d61b..f735972cc 100644 --- a/test/Conversion/TorchToStd/basic.mlir +++ b/test/Conversion/TorchToStd/basic.mlir @@ -166,3 +166,70 @@ func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.f %0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float return %0 : !torch.float } + +// CHECK-LABEL: func @torch.aten.ge.float( +// CHECK-SAME: %[[LHS:.*]]: !torch.float, +// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.bool { +// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]] +// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64 +// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] +// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool +func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bool { + %0 = torch.aten.ge.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.ge.float_int( +// CHECK-SAME: %[[LHS:.*]]: !torch.float, +// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { +// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64 +// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64 +// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] +// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool +func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool { + %0 = torch.aten.ge.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.ne.float_int( +// CHECK-SAME: %[[LHS:.*]]: !torch.float, +// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { +// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64 +// CHECK: %[[CMP:.*]] = arith.cmpf une, %[[LHS_F64]], %[[RHS_F64]] : f64 +// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] +// CHECK: return %[[CMP_TORCH_BOOL]] : 
!torch.bool +func @torch.aten.ne.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool { + %0 = torch.aten.ne.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.ceil.float( +// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int { +// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]] +// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG_F64]] : f64 +// CHECK: %[[CEIL_I64:.*]] = arith.fptosi %[[CEIL]] : f64 to i64 +// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[CEIL_I64]] +// CHECK: return %[[OUT]] : !torch.int +func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int { + %0 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int + return %0 : !torch.int +} + +// CHECK-LABEL: func @torch.aten.gt.float_int( +// CHECK-SAME: %[[LHS:.*]]: !torch.float, +// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { +// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64 +// CHECK: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS_F64]], %[[RHS_F64]] : f64 +// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] +// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool +func @torch.aten.gt.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool { + %0 = torch.aten.gt.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool + return %0 : !torch.bool +} diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index ea27cba57..7bfcb0f52 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1186,3 +1186,50 @@ func @torch.aten.to.dtype_layout$same_dtype(%arg0: !torch.tensor<[?,?],f32>) -> %0 = torch.aten.to.dtype_layout %arg0, %int6, %none, %none, %none, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32> return %0 : !torch.tensor<[?,?],f32> } + +// CHECK-LABEL: func @torch.aten.ge.float$same_operand( +// CHECK-SAME: %{{.*}}: !torch.float) -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.ge.float$same_operand(%arg0: !torch.float) -> !torch.bool { + %2 = torch.aten.ge.float %arg0, %arg0: !torch.float, !torch.float -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.ge.float$same_value() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.ge.float$same_value() -> !torch.bool { + %float4 = torch.constant.float 4.0 + %float4_0 = torch.constant.float 4.0 + %2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.ge.float$different_value() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func @torch.aten.ge.float$different_value() -> !torch.bool { + %float4 = torch.constant.float 4.0 + %float4_0 = torch.constant.float 5.0 + %2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.ceil.float$fold_cst() -> !torch.int { +// CHECK: %[[CST2:.*]] = torch.constant.int 2 +// CHECK: return %[[CST2]] : !torch.int +func @torch.aten.ceil.float$fold_cst() -> !torch.int { + %float = torch.constant.float 1.5 + 
%1 = torch.aten.ceil.float %float : !torch.float -> !torch.int + return %1 : !torch.int +} + +// CHECK-LABEL: func @torch.aten.ceil.float$no_fold( +// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int { +// CHECK: %[[RESULT:.*]] = torch.aten.ceil.float %[[ARG]] : !torch.float -> !torch.int +// CHECK: return %[[RESULT]] : !torch.int +func @torch.aten.ceil.float$no_fold(%arg0 : !torch.float) -> !torch.int { + %1 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int + return %1 : !torch.int +}
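
A minimal sketch follows (not part of the patch; the module name ScalarOpsExample is made up for illustration, and it assumes stock PyTorch). It shows the TorchScript overload resolution that produces the scalar ops this commit lowers: a Python `float >= int` comparison is expected to resolve to `aten::ge.float_int`, and `torch.ops.aten.ceil` on a float is expected to resolve to `aten::ceil.float`, which the new patterns then convert to `arith.cmpf uge` and `math.ceil`/`arith.fptosi` respectively.

import torch


class ScalarOpsExample(torch.nn.Module):
    # Hypothetical example module, not taken from the patch.
    def forward(self, lhs: float, rhs: int) -> int:
        # float >= int should resolve to aten::ge.float_int : (float, int) -> (bool).
        if lhs >= rhs:
            # torch.ops.aten.ceil on a float should resolve to
            # aten::ceil.float : (float) -> (int), the op given a folder above.
            return torch.ops.aten.ceil(lhs)
        return rhs


scripted = torch.jit.script(ScalarOpsExample())
print(scripted.graph)    # inspect the aten::ge.float_int / aten::ceil.float calls
print(scripted(2.3, 2))  # expected: 3

The e2e modules added in scalar.py and scalar_comparison.py exercise the same overloads through the torch-mlir test harness; the sketch above is only meant to make the TorchScript-level op names visible outside that harness.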