mirror of https://github.com/llvm/torch-mlir
add lowerings for AtenLtIntOp and AtenLeIntOp (#3061)
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>pull/3073/head
parent
7825e12483
commit
5f325749f9
|
@ -408,7 +408,8 @@ public:
|
||||||
patterns.add<ConvertAtenDimOp>(typeConverter, context);
|
patterns.add<ConvertAtenDimOp>(typeConverter, context);
|
||||||
target.addIllegalOp<RuntimeAssertOp>();
|
target.addIllegalOp<RuntimeAssertOp>();
|
||||||
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
|
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp, AtenGeIntOp>();
|
target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp, AtenGeIntOp,
|
||||||
|
AtenLtIntOp, AtenLeIntOp>();
|
||||||
patterns
|
patterns
|
||||||
.add<ConvertAtenIntComparisonOp<AtenNeIntOp, arith::CmpIPredicate::ne>>(
|
.add<ConvertAtenIntComparisonOp<AtenNeIntOp, arith::CmpIPredicate::ne>>(
|
||||||
typeConverter, context);
|
typeConverter, context);
|
||||||
|
@ -418,9 +419,15 @@ public:
|
||||||
patterns.add<
|
patterns.add<
|
||||||
ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
|
ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
|
||||||
typeConverter, context);
|
typeConverter, context);
|
||||||
|
patterns.add<
|
||||||
|
ConvertAtenIntComparisonOp<AtenLtIntOp, arith::CmpIPredicate::slt>>(
|
||||||
|
typeConverter, context);
|
||||||
patterns.add<
|
patterns.add<
|
||||||
ConvertAtenIntComparisonOp<AtenGeIntOp, arith::CmpIPredicate::sge>>(
|
ConvertAtenIntComparisonOp<AtenGeIntOp, arith::CmpIPredicate::sge>>(
|
||||||
typeConverter, context);
|
typeConverter, context);
|
||||||
|
patterns.add<
|
||||||
|
ConvertAtenIntComparisonOp<AtenLeIntOp, arith::CmpIPredicate::sle>>(
|
||||||
|
typeConverter, context);
|
||||||
target.addIllegalOp<AtenGeFloatOp, AtenGeFloatIntOp, AtenNeFloatIntOp,
|
target.addIllegalOp<AtenGeFloatOp, AtenGeFloatIntOp, AtenNeFloatIntOp,
|
||||||
AtenGtFloatIntOp>();
|
AtenGtFloatIntOp>();
|
||||||
patterns.add<
|
patterns.add<
|
||||||
|
|
|
@ -79,6 +79,33 @@ func.func @torch.aten.ge.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo
|
||||||
return %0 : !torch.bool
|
return %0 : !torch.bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.lt.int(
|
||||||
|
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
|
||||||
|
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
||||||
|
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
|
||||||
|
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
|
||||||
|
// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[LHS_I64]], %[[RHS_I64]] : i64
|
||||||
|
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
|
||||||
|
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
||||||
|
func.func @torch.aten.lt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
|
||||||
|
%0 = torch.aten.lt.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.le.int(
|
||||||
|
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
|
||||||
|
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
||||||
|
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
|
||||||
|
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
|
||||||
|
// CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[LHS_I64]], %[[RHS_I64]] : i64
|
||||||
|
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
|
||||||
|
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
||||||
|
func.func @torch.aten.le.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
|
||||||
|
%0 = torch.aten.le.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
// CHECK-LABEL: func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
||||||
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
|
// CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
|
|
Loading…
Reference in New Issue