[MLIR][TORCH] Add E2E support for aten.ge.int op

This commit adds lowering of the `aten.ge.int` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/1122/head
Vivek Khandelwal 2022-07-22 18:21:26 +05:30
parent 11a8901078
commit 7247c6a3a7
3 changed files with 40 additions and 1 deletion

@@ -323,7 +323,7 @@ public:
     patterns.add<ConvertAtenIsFloatingPointOp>(typeConverter, context);
     target.addIllegalOp<RuntimeAssertOp>();
     patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
-    target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp>();
+    target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp, AtenGeIntOp>();
     patterns
         .add<ConvertAtenIntComparisonOp<AtenNeIntOp, arith::CmpIPredicate::ne>>(
             typeConverter, context);
@@ -333,6 +333,9 @@ public:
     patterns.add<
         ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
         typeConverter, context);
+    patterns.add<
+        ConvertAtenIntComparisonOp<AtenGeIntOp, arith::CmpIPredicate::sge>>(
+        typeConverter, context);
     target.addIllegalOp<AtenGeFloatOp, AtenGeFloatIntOp, AtenNeFloatIntOp,
                         AtenGtFloatIntOp>();
     patterns.add<
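For context, `ConvertAtenIntComparisonOp` is an existing pattern template defined earlier in this file and not shown in the hunk. A minimal sketch of its likely shape, inferred from how the other integer comparisons are registered here (accessor names and details are assumptions, not verbatim upstream code):

// Sketch only: generic lowering of a Torch scalar integer comparison.
// The template parameters pick the Torch op and the arith.cmpi predicate.
template <typename AtenOp, arith::CmpIPredicate Pred>
class ConvertAtenIntComparisonOp : public OpConversionPattern<AtenOp> {
public:
  using OpConversionPattern<AtenOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenOp op,
                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The adaptor operands are already type-converted (!torch.int -> i64),
    // so the whole op reduces to one arith.cmpi with the chosen predicate.
    rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, Pred, adaptor.a(),
                                               adaptor.b());
    return success();
  }
};

Given that template, supporting `aten.ge.int` takes only the two changes above: mark `AtenGeIntOp` illegal and instantiate the pattern with `arith::CmpIPredicate::sge` (signed greater-or-equal, matching Torch's signed integer semantics).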

@@ -81,6 +81,29 @@ def GtIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class GeIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int64, True),
+        ([], torch.int64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.ops.aten.ge(int(lhs), int(rhs))
+
+
+@register_test_case(module_factory=lambda: GeIntModule())
+def GeIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
+
+
+# ==============================================================================
+
+
 class GeFloatModule(torch.nn.Module):
     def __init__(self):

@@ -66,6 +66,19 @@ func.func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool
   return %0 : !torch.bool
 }
 
+// CHECK-LABEL: func.func @torch.aten.ge.int(
+// CHECK-SAME:      %[[LHS:.*]]: !torch.int,
+// CHECK-SAME:      %[[RHS:.*]]: !torch.int) -> !torch.bool {
+// CHECK:         %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
+// CHECK:         %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
+// CHECK:         %[[CMP:.*]] = arith.cmpi sge, %[[LHS_I64]], %[[RHS_I64]] : i64
+// CHECK:         %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
+// CHECK:         return %[[CMP_TORCH_BOOL]] : !torch.bool
+func.func @torch.aten.ge.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
+  %0 = torch.aten.ge.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
+  return %0 : !torch.bool
+}
+
 // CHECK-LABEL: func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
 // CHECK:         %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
 // CHECK:         %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
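A note on the `torch_c.to_i64` / `torch_c.from_i1` casts the ge.int CHECK lines expect: they are produced by the pass's type converter, not by the comparison pattern itself. Roughly (a sketch with simplified names, not the exact upstream setup), the converter maps Torch scalar types to builtin integer types and materializes casts at the boundaries:

// Sketch (assumed names): map !torch.int -> i64 and !torch.bool -> i1.
typeConverter.addConversion([](Torch::IntType type) -> Type {
  return IntegerType::get(type.getContext(), 64);
});
typeConverter.addConversion([](Torch::BoolType type) -> Type {
  return IntegerType::get(type.getContext(), 1);
});
// Source/target materializations then insert torch_c.to_i64 / torch_c.from_i1
// wherever converted values cross into unconverted IR, which is what the
// captures %[[LHS_I64]], %[[RHS_I64]], and %[[CMP_TORCH_BOOL]] match above.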