[MLIR][TORCH] Add E2E support for aten.ge.float_int op

This commit adds lowering of `aten.ge.float_int` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/802/head
Vivek Khandelwal 2022-04-25 19:07:43 +05:30
parent 564734b2d7
commit 1f102cc400
5 changed files with 68 additions and 4 deletions

View File

@ -6592,6 +6592,30 @@ def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [
}]; }];
} }
def Torch_AtenGeFloatIntOp : Torch_Op<"aten.ge.float_int", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::ge.float_int : (float, int) -> (bool)`";
let arguments = (ins
Torch_FloatType:$a,
Torch_IntType:$b
);
let results = (outs
Torch_BoolType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenGeFloatIntOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenGeFloatIntOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [ def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
@ -94,7 +95,7 @@ public:
} // namespace } // namespace
namespace { namespace {
// Lowers aten float comparison ops. // Lowers aten float and float_int comparison ops.
template <typename AtenOp, arith::CmpFPredicate Pred> template <typename AtenOp, arith::CmpFPredicate Pred>
class ConvertAtenFloatComparisonOp : public OpConversionPattern<AtenOp> { class ConvertAtenFloatComparisonOp : public OpConversionPattern<AtenOp> {
public: public:
@ -103,8 +104,9 @@ public:
matchAndRewrite(AtenOp op, matchAndRewrite(AtenOp op,
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor, typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, Pred, adaptor.a(), Value lhs = adaptor.a(), rhs = adaptor.b();
adaptor.b()); rhs = convertScalarToDtype(rewriter, op.getLoc(), rhs, lhs.getType());
rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, Pred, lhs, rhs);
return success(); return success();
} }
}; };
@ -209,10 +211,13 @@ public:
patterns.add< patterns.add<
ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>( ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
typeConverter, context); typeConverter, context);
target.addIllegalOp<AtenGeFloatOp>(); target.addIllegalOp<AtenGeFloatOp, AtenGeFloatIntOp>();
patterns.add< patterns.add<
ConvertAtenFloatComparisonOp<AtenGeFloatOp, arith::CmpFPredicate::UGE>>( ConvertAtenFloatComparisonOp<AtenGeFloatOp, arith::CmpFPredicate::UGE>>(
typeConverter, context); typeConverter, context);
patterns.add<ConvertAtenFloatComparisonOp<AtenGeFloatIntOp,
arith::CmpFPredicate::UGE>>(
typeConverter, context);
target.addIllegalOp<ValueTensorLiteralOp>(); target.addIllegalOp<ValueTensorLiteralOp>();
patterns.add<ConvertTorchTensorLiteralOp>(typeConverter, context); patterns.add<ConvertTorchTensorLiteralOp>(typeConverter, context);

View File

@ -495,6 +495,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::ge.float : (float, float) -> (bool)", has_folder=True) emit("aten::ge.float : (float, float) -> (bool)", has_folder=True)
emit("aten::lt.float : (float, float) -> (bool)", has_folder=True) emit("aten::lt.float : (float, float) -> (bool)", has_folder=True)
emit("aten::lt.float_int : (float, int) -> (bool)") emit("aten::lt.float_int : (float, int) -> (bool)")
emit("aten::ge.float_int : (float, int) -> (bool)")
emit("aten::__and__.bool : (bool, bool) -> (bool)") emit("aten::__and__.bool : (bool, bool) -> (bool)")
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)

View File

@ -88,3 +88,23 @@ class GeFloatModule(torch.nn.Module):
@register_test_case(module_factory=lambda: GeFloatModule()) @register_test_case(module_factory=lambda: GeFloatModule())
def GeFloatModule_basic(module, tu: TestUtils): def GeFloatModule_basic(module, tu: TestUtils):
module.forward(torch.randn(()), torch.randn(())) module.forward(torch.randn(()), torch.randn(()))
# ==============================================================================
class GeFloatIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.float64, True),
([], torch.int64, True),
])
def forward(self, lhs, rhs):
return float(lhs) >= int(rhs)
@register_test_case(module_factory=lambda: GeFloatIntModule())
def GeFloatIntModule_basic(module, tu: TestUtils):
module.forward(torch.randn(()), torch.randint(-100, 100, ()))

View File

@ -179,3 +179,17 @@ func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bo
%0 = torch.aten.ge.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.bool %0 = torch.aten.ge.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.bool
return %0 : !torch.bool return %0 : !torch.bool
} }
// CHECK-LABEL: func @torch.aten.ge.float_int(
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64
// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.ge.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
return %0 : !torch.bool
}