Add aten.ne.bool to make CI pass

pull/369/head
Yi Zhang 2021-10-21 11:50:01 -04:00
parent 7c47b9a0c8
commit abfaf8c577
4 changed files with 64 additions and 1 deletions

View File

@ -2382,6 +2382,22 @@ def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
}
def Torch_AtenNeBoolOp : Torch_Op<"aten.ne.bool", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::ne.bool : (bool, bool) -> (bool)`";
let arguments = (ins
Torch_BoolType:$a,
Torch_BoolType:$b
);
let results = (outs
Torch_BoolType:$result
);
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
let hasFolder = 1;
}
def Torch_Aten__Is__Op : Torch_Op<"aten.__is__", [
AllowsTypeRefinement,
HasValueSemantics

View File

@ -435,7 +435,23 @@ OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
}
//===----------------------------------------------------------------------===//
// AtenLenTOp
// AtenNeBoolOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
if (getOperand(0) == getOperand(1))
return IntegerAttr::get(IntegerType::get(getContext(), 1), false);
bool a, b;
if (!matchPattern(getOperand(0), m_TorchConstantBool(&a)))
return nullptr;
if (!matchPattern(getOperand(1), m_TorchConstantBool(&b)))
return nullptr;
return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b);
}
//===----------------------------------------------------------------------===//
// AtenDimOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {

View File

@ -590,6 +590,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::neg.float : (float) -> (float)")
emit("aten::lt.float_int : (float, int) -> (bool)")
emit("aten::__and__.bool : (bool, bool) -> (bool)")
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True)
emit("aten::__not__ : (bool) -> (bool)", has_folder=True)

View File

@ -32,6 +32,36 @@ func @torch.aten.__isnot__$none_isnot_none(%arg0: !torch.none, %arg1: !torch.non
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.ne.bool() -> !torch.bool {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
func @torch.aten.ne.bool() -> !torch.bool {
%a = torch.constant.bool true
%b = torch.constant.bool false
%0 = torch.aten.ne.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.ne.bool$same_operand(
// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: return %[[FALSE]] : !torch.bool
func @torch.aten.ne.bool$same_operand(%arg0: !torch.bool) -> !torch.bool {
%0 = torch.aten.ne.bool %arg0, %arg0: !torch.bool, !torch.bool -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.ne.bool$different_operand(
// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RET:.*]] = torch.aten.ne.bool %[[ARG0]], %[[FALSE]] : !torch.bool, !torch.bool -> !torch.bool
// CHECK: return %[[RET]] : !torch.bool
func @torch.aten.ne.bool$different_operand(%a: !torch.bool) -> !torch.bool {
%b = torch.constant.bool false
%0 = torch.aten.ne.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.size$canonicalize_to_list(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.list<!torch.int> {
// CHECK: %[[C2:.*]] = torch.constant.int 2