[Torch] Add support for Aten__Or__BoolOp (#3574)

pull/3576/head
yyp0 2024-07-31 17:23:53 +08:00 committed by GitHub
parent d3efab984b
commit f49b9c14f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 61 additions and 0 deletions

View File

@ -15548,6 +15548,31 @@ def Torch_Aten__Not__Op : Torch_Op<"aten.__not__", [
let hasFolder = 1; let hasFolder = 1;
} }
def Torch_Aten__Or__BoolOp : Torch_Op<"aten.__or__.bool", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::__or__.bool : (bool, bool) -> (bool)`";
let arguments = (ins
Torch_BoolType:$a,
Torch_BoolType:$b
);
let results = (outs
Torch_BoolType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten__Or__BoolOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten__Or__BoolOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -732,6 +732,21 @@ OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
return IntegerAttr::get(IntegerType::get(getContext(), 1), !value); return IntegerAttr::get(IntegerType::get(getContext(), 1), !value);
} }
//===----------------------------------------------------------------------===//
// Aten__Or__Op
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) {
auto valueA = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
auto valueB = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
if (!valueA || !valueB) {
return nullptr;
}
return IntegerAttr::get(IntegerType::get(getContext(), 1),
valueA.getValue() | valueB.getValue());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenNeBoolOp // AtenNeBoolOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -855,6 +855,7 @@ STABLEHLO_PASS_SET = {
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AddIntModule_basic", "AddIntModule_basic",
"AliasModule_basic", "AliasModule_basic",
"TrueFalseOrBoolOpModule_basic",
"AllBoolFalseModule_basic", "AllBoolFalseModule_basic",
"AllBoolTrueModule_basic", "AllBoolTrueModule_basic",
"AnyBoolFalseModule_basic", "AnyBoolFalseModule_basic",
@ -1576,6 +1577,7 @@ TOSA_PASS_SET = {
"AtenInstanceNormModule_basic", "AtenInstanceNormModule_basic",
"AtenToDeviceModule_basic", "AtenToDeviceModule_basic",
"Aten_CastFloatModule_basic", "Aten_CastFloatModule_basic",
"TrueFalseOrBoolOpModule_basic",
"BaddbmmBroadcast1DInputModule_basic", "BaddbmmBroadcast1DInputModule_basic",
"BaddbmmBroadcast2DInputModule_basic", "BaddbmmBroadcast2DInputModule_basic",
"BaddbmmDynamicModule_basic", "BaddbmmDynamicModule_basic",

View File

@ -1077,6 +1077,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True)
emit("aten::__not__ : (bool) -> (bool)", has_folder=True) emit("aten::__not__ : (bool) -> (bool)", has_folder=True)
emit("aten::__or__.bool : (bool, bool) -> (bool)", has_folder=True)
emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True) emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True)
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True) emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
emit("aten::_set_item.t : (t[], int, t) -> (t[])") emit("aten::_set_item.t : (t[], int, t) -> (t[])")

View File

@ -528,3 +528,21 @@ class AtenItemFpOpModule(torch.nn.Module):
@register_test_case(module_factory=lambda: AtenItemFpOpModule()) @register_test_case(module_factory=lambda: AtenItemFpOpModule())
def AtenItemFpOpModule_basic(module, tu: TestUtils): def AtenItemFpOpModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1)) module.forward(tu.rand(1))
# ==============================================================================
class TrueFalseOrBoolOpModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([], torch.bool, True), ([], torch.bool, True)])
def forward(self, a, b):
return a | b
@register_test_case(module_factory=lambda: TrueFalseOrBoolOpModule())
def TrueFalseOrBoolOpModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=0, high=1).bool(), tu.randint(low=1, high=2).bool())