mirror of https://github.com/llvm/torch-mlir
Add aten.ne.bool to make CI pass
parent
7c47b9a0c8
commit
abfaf8c577
|
@ -2382,6 +2382,22 @@ def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
|
|||
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenNeBoolOp : Torch_Op<"aten.ne.bool", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::ne.bool : (bool, bool) -> (bool)`";
|
||||
let arguments = (ins
|
||||
Torch_BoolType:$a,
|
||||
Torch_BoolType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_BoolType:$result
|
||||
);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten__Is__Op : Torch_Op<"aten.__is__", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -435,7 +435,23 @@ OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenLenTOp
|
||||
// AtenNeBoolOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (getOperand(0) == getOperand(1))
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 1), false);
|
||||
|
||||
bool a, b;
|
||||
if (!matchPattern(getOperand(0), m_TorchConstantBool(&a)))
|
||||
return nullptr;
|
||||
if (!matchPattern(getOperand(1), m_TorchConstantBool(&b)))
|
||||
return nullptr;
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenDimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
|
|
@ -590,6 +590,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::neg.float : (float) -> (float)")
|
||||
emit("aten::lt.float_int : (float, int) -> (bool)")
|
||||
emit("aten::__and__.bool : (bool, bool) -> (bool)")
|
||||
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
|
||||
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
||||
emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True)
|
||||
emit("aten::__not__ : (bool) -> (bool)", has_folder=True)
|
||||
|
|
|
@ -32,6 +32,36 @@ func @torch.aten.__isnot__$none_isnot_none(%arg0: !torch.none, %arg1: !torch.non
|
|||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.ne.bool() -> !torch.bool {
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: return %[[TRUE]] : !torch.bool
|
||||
func @torch.aten.ne.bool() -> !torch.bool {
|
||||
%a = torch.constant.bool true
|
||||
%b = torch.constant.bool false
|
||||
%0 = torch.aten.ne.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.ne.bool$same_operand(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: return %[[FALSE]] : !torch.bool
|
||||
func @torch.aten.ne.bool$same_operand(%arg0: !torch.bool) -> !torch.bool {
|
||||
%0 = torch.aten.ne.bool %arg0, %arg0: !torch.bool, !torch.bool -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.ne.bool$different_operand(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RET:.*]] = torch.aten.ne.bool %[[ARG0]], %[[FALSE]] : !torch.bool, !torch.bool -> !torch.bool
|
||||
// CHECK: return %[[RET]] : !torch.bool
|
||||
func @torch.aten.ne.bool$different_operand(%a: !torch.bool) -> !torch.bool {
|
||||
%b = torch.constant.bool false
|
||||
%0 = torch.aten.ne.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.size$canonicalize_to_list(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.list<!torch.int> {
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
|
|
Loading…
Reference in New Issue