mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.ge.int op
This commit adds lowering of `aten.ge.int` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/1122/head
parent
11a8901078
commit
7247c6a3a7
|
@ -323,7 +323,7 @@ public:
|
|||
patterns.add<ConvertAtenIsFloatingPointOp>(typeConverter, context);
|
||||
target.addIllegalOp<RuntimeAssertOp>();
|
||||
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp>();
|
||||
target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp, AtenGeIntOp>();
|
||||
patterns
|
||||
.add<ConvertAtenIntComparisonOp<AtenNeIntOp, arith::CmpIPredicate::ne>>(
|
||||
typeConverter, context);
|
||||
|
@ -333,6 +333,9 @@ public:
|
|||
patterns.add<
|
||||
ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
|
||||
typeConverter, context);
|
||||
patterns.add<
|
||||
ConvertAtenIntComparisonOp<AtenGeIntOp, arith::CmpIPredicate::sge>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenGeFloatOp, AtenGeFloatIntOp, AtenNeFloatIntOp,
|
||||
AtenGtFloatIntOp>();
|
||||
patterns.add<
|
||||
|
|
|
@ -81,6 +81,29 @@ def GtIntModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class GeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.ops.aten.ge(int(lhs), int(rhs))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: GeIntModule())
|
||||
def GeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class GeFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
@ -66,6 +66,19 @@ func.func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo
|
|||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.ge.int(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
||||
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
|
||||
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
|
||||
// CHECK: %[[CMP:.*]] = arith.cmpi sge, %[[LHS_I64]], %[[RHS_I64]] : i64
|
||||
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
|
||||
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
||||
func.func @torch.aten.ge.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
|
||||
%0 = torch.aten.ge.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
||||
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
|
|
Loading…
Reference in New Issue