add lowerings for AtenLtIntOp and AtenLeIntOp (#3061)

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
Xida Ren (Cedar) 2024-03-27 10:06:43 -07:00 committed by GitHub
parent 7825e12483
commit 5f325749f9
2 changed files with 35 additions and 1 deletion

@@ -408,7 +408,8 @@ public:
   patterns.add<ConvertAtenDimOp>(typeConverter, context);
   target.addIllegalOp<RuntimeAssertOp>();
   patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
-  target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp, AtenGeIntOp>();
+  target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp, AtenGeIntOp,
+                      AtenLtIntOp, AtenLeIntOp>();
   patterns
       .add<ConvertAtenIntComparisonOp<AtenNeIntOp, arith::CmpIPredicate::ne>>(
           typeConverter, context);
@@ -418,9 +419,15 @@ public:
   patterns.add<
       ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
       typeConverter, context);
+  patterns.add<
+      ConvertAtenIntComparisonOp<AtenLtIntOp, arith::CmpIPredicate::slt>>(
+      typeConverter, context);
   patterns.add<
       ConvertAtenIntComparisonOp<AtenGeIntOp, arith::CmpIPredicate::sge>>(
       typeConverter, context);
+  patterns.add<
+      ConvertAtenIntComparisonOp<AtenLeIntOp, arith::CmpIPredicate::sle>>(
+      typeConverter, context);
   target.addIllegalOp<AtenGeFloatOp, AtenGeFloatIntOp, AtenNeFloatIntOp,
                       AtenGtFloatIntOp>();
   patterns.add<
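
For readers without the rest of the file: every registration above instantiates one class template over an op and an arith.cmpi predicate, so supporting lt/le is purely a matter of listing the new ops. A minimal sketch of what such a pattern looks like, assuming the usual OpConversionPattern shape and the getA()/getB() adaptor accessors torch-mlir generates for binary scalar ops (the in-tree definition may differ in detail):

// Sketch only: rewrites a scalar Torch integer comparison (e.g.
// torch.aten.lt.int) into arith.cmpi with the given predicate. By the
// time this fires, the type converter has materialized the !torch.int
// operands as i64 (torch_c.to_i64), so the adaptor yields builtin ints.
template <typename AtenOp, arith::CmpIPredicate Pred>
class ConvertAtenIntComparisonOp : public OpConversionPattern<AtenOp> {
public:
  using OpConversionPattern<AtenOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenOp op,
                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, Pred, adaptor.getA(),
                                               adaptor.getB());
    return success();
  }
};

The signed predicates (slt, sle) are the right choice because !torch.int lowers to a signed i64.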

@@ -79,6 +79,33 @@ func.func @torch.aten.ge.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool
   return %0 : !torch.bool
 }
 
+// CHECK-LABEL: func.func @torch.aten.lt.int(
+// CHECK-SAME:    %[[LHS:.*]]: !torch.int,
+// CHECK-SAME:    %[[RHS:.*]]: !torch.int) -> !torch.bool {
+// CHECK:         %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
+// CHECK:         %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
+// CHECK:         %[[CMP:.*]] = arith.cmpi slt, %[[LHS_I64]], %[[RHS_I64]] : i64
+// CHECK:         %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
+// CHECK:         return %[[CMP_TORCH_BOOL]] : !torch.bool
+func.func @torch.aten.lt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
+  %0 = torch.aten.lt.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
+  return %0 : !torch.bool
+}
+
+// CHECK-LABEL: func.func @torch.aten.le.int(
+// CHECK-SAME:    %[[LHS:.*]]: !torch.int,
+// CHECK-SAME:    %[[RHS:.*]]: !torch.int) -> !torch.bool {
+// CHECK:         %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
+// CHECK:         %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
+// CHECK:         %[[CMP:.*]] = arith.cmpi sle, %[[LHS_I64]], %[[RHS_I64]] : i64
+// CHECK:         %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
+// CHECK:         return %[[CMP_TORCH_BOOL]] : !torch.bool
+func.func @torch.aten.le.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
+  %0 = torch.aten.le.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
+  return %0 : !torch.bool
+}
+
 // CHECK-LABEL: func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
 // CHECK:         %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
 // CHECK:         %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
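
To exercise the new cases locally, the test file goes through the usual lit/FileCheck flow. A minimal sketch of the invocation, assuming the pass flag is -convert-torch-to-arith and the file carries a standard RUN header (the flag and path are illustrative, not confirmed by this diff):

// RUN: torch-mlir-opt <%s -convert-torch-to-arith -split-input-file -verify-diagnostics | FileCheck %s

lit picks this up when the suite runs, e.g. llvm-lit -v test/Conversion/TorchToArith/basic.mlir.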