mirror of https://github.com/llvm/torch-mlir
Add folder for aten.gt/lt.float
parent
dcef4751f9
commit
9e7b6cab08
|
@ -3504,21 +3504,6 @@ def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [
|
|||
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::lt.float_int : (float, int) -> (bool)`";
|
||||
let arguments = (ins
|
||||
Torch_FloatType:$a,
|
||||
Torch_IntType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_BoolType:$result
|
||||
);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
@ -3535,6 +3520,53 @@ def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenGtFloatOp : Torch_Op<"aten.gt.float", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::gt.float : (float, float) -> (bool)`";
|
||||
let arguments = (ins
|
||||
Torch_FloatType:$a,
|
||||
Torch_FloatType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_BoolType:$result
|
||||
);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenLtFloatOp : Torch_Op<"aten.lt.float", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::lt.float : (float, float) -> (bool)`";
|
||||
let arguments = (ins
|
||||
Torch_FloatType:$a,
|
||||
Torch_FloatType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_BoolType:$result
|
||||
);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::lt.float_int : (float, int) -> (bool)`";
|
||||
let arguments = (ins
|
||||
Torch_FloatType:$a,
|
||||
Torch_IntType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_BoolType:$result
|
||||
);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -629,10 +629,52 @@ static IntegerAttr getI1IntegerAttr(MLIRContext *context, bool value) {
|
|||
static_cast<int64_t>(value));
|
||||
}
|
||||
|
||||
using ConstantFloatComparator = std::function<bool(double, double)>;
|
||||
template <typename OpTy>
|
||||
static OpFoldResult
|
||||
floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) {
|
||||
if (op.getOperand(0) == op.getOperand(1))
|
||||
return getI1IntegerAttr(op.getContext(), comparator(0, 0));
|
||||
|
||||
double lhs, rhs;
|
||||
if (!matchPattern(op.getOperand(0), m_TorchConstantFloat(&lhs)) ||
|
||||
!matchPattern(op.getOperand(1), m_TorchConstantFloat(&rhs)))
|
||||
return nullptr;
|
||||
|
||||
return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenLtFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
return floatComparatorFoldHelper(*this,
|
||||
[](double a, double b) { return a < b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenGtFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
return floatComparatorFoldHelper(*this,
|
||||
[](double a, double b) { return a > b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenEqFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
return floatComparatorFoldHelper(*this,
|
||||
[](double a, double b) { return a == b; });
|
||||
}
|
||||
|
||||
using ConstantIntComparator = std::function<bool(int64_t, int64_t)>;
|
||||
template <typename OpTy>
|
||||
static OpFoldResult comparatorFoldHelper(OpTy op,
|
||||
ConstantIntComparator comparator) {
|
||||
static OpFoldResult intComparatorFoldHelper(OpTy op,
|
||||
ConstantIntComparator comparator) {
|
||||
if (op.getOperand(0) == op.getOperand(1))
|
||||
return getI1IntegerAttr(op.getContext(), comparator(0, 0));
|
||||
|
||||
|
@ -649,8 +691,8 @@ static OpFoldResult comparatorFoldHelper(OpTy op,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return comparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a != b; });
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a != b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -658,22 +700,8 @@ OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return comparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a == b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenEqFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
double lhs, rhs;
|
||||
|
||||
if (!matchPattern(getOperand(0), m_TorchConstantFloat(&lhs)) ||
|
||||
!matchPattern(getOperand(1), m_TorchConstantFloat(&rhs)))
|
||||
return nullptr;
|
||||
|
||||
return getI1IntegerAttr(getContext(), lhs == rhs);
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a == b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -697,8 +725,8 @@ OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return comparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a < b; });
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a < b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -706,8 +734,8 @@ OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return comparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a <= b; });
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a <= b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -715,8 +743,8 @@ OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return comparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a > b; });
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a > b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -724,8 +752,8 @@ OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return comparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a >= b; });
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a >= b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -670,8 +670,10 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::add.float_int : (float, int) -> (float)")
|
||||
emit("aten::mul.float : (float, float) -> (float)")
|
||||
emit("aten::neg.float : (float) -> (float)")
|
||||
emit("aten::lt.float_int : (float, int) -> (bool)")
|
||||
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
|
||||
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
|
||||
emit("aten::lt.float : (float, float) -> (bool)", has_folder=True)
|
||||
emit("aten::lt.float_int : (float, int) -> (bool)")
|
||||
emit("aten::__and__.bool : (bool, bool) -> (bool)")
|
||||
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
|
||||
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
||||
|
|
|
@ -258,6 +258,56 @@ func @torch.aten.ge.int$same_value() -> !torch.bool {
|
|||
return %2 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.lt.float$evaluate_to_true() -> !torch.bool {
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: return %[[TRUE]] : !torch.bool
|
||||
func @torch.aten.lt.float$evaluate_to_true() -> !torch.bool {
|
||||
%float4 = torch.constant.float 4.0
|
||||
%float5 = torch.constant.float 5.0
|
||||
%2 = torch.aten.lt.float %float4, %float5 : !torch.float, !torch.float -> !torch.bool
|
||||
return %2 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.lt.float$same_operand(
|
||||
// CHECK-SAME: %{{.*}}: !torch.float) -> !torch.bool {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: return %[[FALSE]] : !torch.bool
|
||||
func @torch.aten.lt.float$same_operand(%arg0: !torch.float) -> !torch.bool {
|
||||
%2 = torch.aten.lt.float %arg0, %arg0: !torch.float, !torch.float -> !torch.bool
|
||||
return %2 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.lt.float$same_value() -> !torch.bool {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: return %[[FALSE]] : !torch.bool
|
||||
func @torch.aten.lt.float$same_value() -> !torch.bool {
|
||||
%float4 = torch.constant.float 4.0
|
||||
%float4_0 = torch.constant.float 4.0
|
||||
%2 = torch.aten.lt.float %float4, %float4_0 : !torch.float, !torch.float -> !torch.bool
|
||||
return %2 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.gt.float$evaluate_to_true() -> !torch.bool {
|
||||
// CHECK-NEXT: %[[T:.*]] = torch.constant.bool true
|
||||
// CHECK-NEXT: return %[[T]] : !torch.bool
|
||||
func @torch.aten.gt.float$evaluate_to_true() -> !torch.bool {
|
||||
%float2 = torch.constant.float 2.0
|
||||
%float4 = torch.constant.float 4.0
|
||||
%0 = torch.aten.gt.float %float4, %float2 : !torch.float, !torch.float -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.gt.float$evaluate_to_false() -> !torch.bool {
|
||||
// CHECK-NEXT: %[[T:.*]] = torch.constant.bool false
|
||||
// CHECK-NEXT: return %[[T]] : !torch.bool
|
||||
func @torch.aten.gt.float$evaluate_to_false() -> !torch.bool {
|
||||
%float2 = torch.constant.float 2.0
|
||||
%float4 = torch.constant.float 4.0
|
||||
%0 = torch.aten.gt.float %float2, %float4 : !torch.float, !torch.float -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.eq.float$different_value() -> !torch.bool {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: return %[[FALSE]] : !torch.bool
|
||||
|
|
Loading…
Reference in New Issue