[MLIR][TORCH] Add E2E support for aten.ge.float op

This commit adds lowering of the `aten.ge.float` op, together with a constant folder and e2e/lit test coverage.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/802/head
Vivek Khandelwal 2022-04-25 18:42:45 +05:30
parent f5b6c4b601
commit 564734b2d7
7 changed files with 117 additions and 0 deletions
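For context, a minimal sketch of where `aten::ge.float` comes from on the frontend side (illustrative only, not part of this change; the function name `ge_float` is made up): scripting a `>=` comparison between two Python floats resolves to the `aten::ge.float` overload, which the patterns below lower to `arith.cmpf uge` and fold when both operands are known constants.

import torch

# Illustrative sketch: a scripted float comparison produces an aten::ge node
# with the (float, float) -> (bool) schema, i.e. the aten::ge.float overload
# that this commit teaches torch-mlir to handle.
@torch.jit.script
def ge_float(lhs: float, rhs: float) -> bool:
    return lhs >= rhs

print(ge_float.graph)      # the graph contains an aten::ge node on floats
print(ge_float(4.0, 2.0))  # True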

@@ -6518,6 +6518,31 @@ def Torch_AtenGtFloatOp : Torch_Op<"aten.gt.float", [
  let hasFolder = 1;
}

def Torch_AtenGeFloatOp : Torch_Op<"aten.ge.float", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::ge.float : (float, float) -> (bool)`";
  let arguments = (ins
    Torch_FloatType:$a,
    Torch_FloatType:$b
  );
  let results = (outs
    Torch_BoolType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenGeFloatOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenGeFloatOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
  let hasFolder = 1;
}

def Torch_AtenLtFloatOp : Torch_Op<"aten.lt.float", [
    AllowsTypeRefinement,
    HasValueSemantics,

@@ -93,6 +93,23 @@ public:
};
} // namespace

namespace {
// Lowers aten float comparison ops.
template <typename AtenOp, arith::CmpFPredicate Pred>
class ConvertAtenFloatComparisonOp : public OpConversionPattern<AtenOp> {
public:
  using OpConversionPattern<AtenOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenOp op,
                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, Pred, adaptor.a(),
                                               adaptor.b());
    return success();
  }
};
} // namespace
// Tensors with integer types need to be converted to signless integer
// element type. All tensors with element types other than integer can reuse
// existing elements attribute.
@@ -192,6 +209,10 @@ public:
    patterns.add<
        ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
        typeConverter, context);
    target.addIllegalOp<AtenGeFloatOp>();
    patterns.add<
        ConvertAtenFloatComparisonOp<AtenGeFloatOp, arith::CmpFPredicate::UGE>>(
        typeConverter, context);
    target.addIllegalOp<ValueTensorLiteralOp>();
    patterns.add<ConvertTorchTensorLiteralOp>(typeConverter, context);

@@ -796,6 +796,15 @@ OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
                                   [](double a, double b) { return a > b; });
}

//===----------------------------------------------------------------------===//
// AtenGeFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
  return floatComparatorFoldHelper(*this,
                                   [](double a, double b) { return a >= b; });
}

//===----------------------------------------------------------------------===//
// AtenEqFloatOp
//===----------------------------------------------------------------------===//

@@ -492,6 +492,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::neg.float : (float) -> (float)")
    emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
    emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
    emit("aten::ge.float : (float, float) -> (bool)", has_folder=True)
    emit("aten::lt.float : (float, float) -> (bool)", has_folder=True)
    emit("aten::lt.float_int : (float, int) -> (bool)")
    emit("aten::__and__.bool : (bool, bool) -> (bool)")

@@ -69,3 +69,22 @@ class GtIntModule(torch.nn.Module):
def GtIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
# ==============================================================================

class GeFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([], torch.float64, True),
        ([], torch.float64, True),
    ])
    def forward(self, lhs, rhs):
        return float(lhs) >= float(rhs)


@register_test_case(module_factory=lambda: GeFloatModule())
def GeFloatModule_basic(module, tu: TestUtils):
    # Inputs are cast to double so they match the float64 annotation above.
    module.forward(torch.randn(()).double(), torch.randn(()).double())

@@ -166,3 +166,16 @@ func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
  %0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
  return %0 : !torch.float
}

// CHECK-LABEL: func @torch.aten.ge.float(
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.bool {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
// CHECK: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bool {
  %0 = torch.aten.ge.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.bool
  return %0 : !torch.bool
}

@@ -1162,3 +1162,32 @@ func @torch.aten.div.float$fold_cst_operands() -> !torch.float {
  %0 = torch.aten.div.float %float4, %float2 : !torch.float, !torch.float -> !torch.float
  return %0 : !torch.float
}

// CHECK-LABEL: func @torch.aten.ge.float$same_operand(
// CHECK-SAME: %{{.*}}: !torch.float) -> !torch.bool {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
func @torch.aten.ge.float$same_operand(%arg0: !torch.float) -> !torch.bool {
  %2 = torch.aten.ge.float %arg0, %arg0: !torch.float, !torch.float -> !torch.bool
  return %2 : !torch.bool
}

// CHECK-LABEL: func @torch.aten.ge.float$same_value() -> !torch.bool {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
func @torch.aten.ge.float$same_value() -> !torch.bool {
  %float4 = torch.constant.float 4.0
  %float4_0 = torch.constant.float 4.0
  %2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool
  return %2 : !torch.bool
}

// CHECK-LABEL: func @torch.aten.ge.float$different_value() -> !torch.bool {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: return %[[FALSE]] : !torch.bool
func @torch.aten.ge.float$different_value() -> !torch.bool {
  %float4 = torch.constant.float 4.0
  %float4_0 = torch.constant.float 5.0
  %2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool
  return %2 : !torch.bool
}