mirror of https://github.com/llvm/torch-mlir
[Torch] Add support for Aten__Or__BoolOp (#3574)
parent
d3efab984b
commit
f49b9c14f1
|
@ -15548,6 +15548,31 @@ def Torch_Aten__Not__Op : Torch_Op<"aten.__not__", [
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_Aten__Or__BoolOp : Torch_Op<"aten.__or__.bool", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::__or__.bool : (bool, bool) -> (bool)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_BoolType:$a,
|
||||||
|
Torch_BoolType:$b
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_BoolType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult Aten__Or__BoolOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void Aten__Or__BoolOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
|
def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -732,6 +732,21 @@ OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
|
||||||
return IntegerAttr::get(IntegerType::get(getContext(), 1), !value);
|
return IntegerAttr::get(IntegerType::get(getContext(), 1), !value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Aten__Or__Op
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) {
|
||||||
|
auto valueA = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
|
||||||
|
auto valueB = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
|
||||||
|
if (!valueA || !valueB) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return IntegerAttr::get(IntegerType::get(getContext(), 1),
|
||||||
|
valueA.getValue() | valueB.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenNeBoolOp
|
// AtenNeBoolOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -855,6 +855,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
||||||
"AddIntModule_basic",
|
"AddIntModule_basic",
|
||||||
"AliasModule_basic",
|
"AliasModule_basic",
|
||||||
|
"TrueFalseOrBoolOpModule_basic",
|
||||||
"AllBoolFalseModule_basic",
|
"AllBoolFalseModule_basic",
|
||||||
"AllBoolTrueModule_basic",
|
"AllBoolTrueModule_basic",
|
||||||
"AnyBoolFalseModule_basic",
|
"AnyBoolFalseModule_basic",
|
||||||
|
@ -1576,6 +1577,7 @@ TOSA_PASS_SET = {
|
||||||
"AtenInstanceNormModule_basic",
|
"AtenInstanceNormModule_basic",
|
||||||
"AtenToDeviceModule_basic",
|
"AtenToDeviceModule_basic",
|
||||||
"Aten_CastFloatModule_basic",
|
"Aten_CastFloatModule_basic",
|
||||||
|
"TrueFalseOrBoolOpModule_basic",
|
||||||
"BaddbmmBroadcast1DInputModule_basic",
|
"BaddbmmBroadcast1DInputModule_basic",
|
||||||
"BaddbmmBroadcast2DInputModule_basic",
|
"BaddbmmBroadcast2DInputModule_basic",
|
||||||
"BaddbmmDynamicModule_basic",
|
"BaddbmmDynamicModule_basic",
|
||||||
|
|
|
@ -1077,6 +1077,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
||||||
emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True)
|
emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True)
|
||||||
emit("aten::__not__ : (bool) -> (bool)", has_folder=True)
|
emit("aten::__not__ : (bool) -> (bool)", has_folder=True)
|
||||||
|
emit("aten::__or__.bool : (bool, bool) -> (bool)", has_folder=True)
|
||||||
emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True)
|
emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True)
|
||||||
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
|
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
|
||||||
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
||||||
|
|
|
@ -528,3 +528,21 @@ class AtenItemFpOpModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: AtenItemFpOpModule())
|
@register_test_case(module_factory=lambda: AtenItemFpOpModule())
|
||||||
def AtenItemFpOpModule_basic(module, tu: TestUtils):
|
def AtenItemFpOpModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1))
|
module.forward(tu.rand(1))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TrueFalseOrBoolOpModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([None, ([], torch.bool, True), ([], torch.bool, True)])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return a | b
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TrueFalseOrBoolOpModule())
|
||||||
|
def TrueFalseOrBoolOpModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(low=0, high=1).bool(), tu.randint(low=1, high=2).bool())
|
||||||
|
|
Loading…
Reference in New Issue