mirror of https://github.com/llvm/torch-mlir
[Torch] emit aten.ne.str and add folder (#3242)
parent
944a6df611
commit
f173a06fa7
|
@ -13423,6 +13423,31 @@ def Torch_AtenEqStrOp : Torch_Op<"aten.eq.str", [
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenNeStrOp : Torch_Op<"aten.ne.str", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::ne.str : (str, str) -> (bool)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_StringType:$a,
|
||||||
|
Torch_StringType:$b
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_BoolType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenNeStrOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenNeStrOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenLenStrOp : Torch_Op<"aten.len.str", [
|
def Torch_AtenLenStrOp : Torch_Op<"aten.len.str", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -2364,14 +2364,28 @@ OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) {
|
||||||
if (getOperand(0) == getOperand(1))
|
if (getOperand(0) == getOperand(1))
|
||||||
return getI1IntegerAttr(getContext(), true);
|
return getI1IntegerAttr(getContext(), true);
|
||||||
|
|
||||||
auto aStr = getA().getDefiningOp<ConstantStrOp>();
|
auto aStr = adaptor.getA();
|
||||||
auto bStr = getB().getDefiningOp<ConstantStrOp>();
|
auto bStr = adaptor.getB();
|
||||||
|
|
||||||
if (aStr && bStr)
|
if (aStr && bStr)
|
||||||
return getI1IntegerAttr(getContext(), aStr == bStr);
|
return getI1IntegerAttr(getContext(), aStr == bStr);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenNeStrOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) {
|
||||||
|
if (getOperand(0) == getOperand(1))
|
||||||
|
return getI1IntegerAttr(getContext(), false);
|
||||||
|
|
||||||
|
auto aStr = adaptor.getA();
|
||||||
|
auto bStr = adaptor.getB();
|
||||||
|
if (aStr && bStr)
|
||||||
|
return getI1IntegerAttr(getContext(), aStr != bStr);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenLtIntOp
|
// AtenLtIntOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -771,6 +771,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
# Str ops.
|
# Str ops.
|
||||||
emit("aten::add.str : (str, str) -> (str)")
|
emit("aten::add.str : (str, str) -> (str)")
|
||||||
emit("aten::eq.str : (str, str) -> (bool)", has_folder=True)
|
emit("aten::eq.str : (str, str) -> (bool)", has_folder=True)
|
||||||
|
emit("aten::ne.str : (str, str) -> (bool)", has_folder=True)
|
||||||
emit("aten::len.str : (str) -> (int)", has_folder=True)
|
emit("aten::len.str : (str) -> (int)", has_folder=True)
|
||||||
emit("aten::str : (t) -> (str)")
|
emit("aten::str : (t) -> (str)")
|
||||||
emit("aten::format : (...) -> (str)")
|
emit("aten::format : (...) -> (str)")
|
||||||
|
|
|
@ -521,6 +521,35 @@ func.func @torch.aten.eq.str$same_value() -> !torch.bool {
|
||||||
return %2 : !torch.bool
|
return %2 : !torch.bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: return %[[FALSE]] : !torch.bool
|
||||||
|
func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
||||||
|
%str4 = torch.constant.str "4"
|
||||||
|
%str5 = torch.constant.str "5"
|
||||||
|
%2 = torch.aten.ne.str %str4, %str5 : !torch.str, !torch.str -> !torch.bool
|
||||||
|
return %2 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.ne.str$same_operand(
|
||||||
|
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
|
||||||
|
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool false
|
||||||
|
// CHECK-NEXT: return %[[F]] : !torch.bool
|
||||||
|
func.func @torch.aten.ne.str$same_operand(%arg0: !torch.str) -> !torch.bool {
|
||||||
|
%0 = torch.aten.ne.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.ne.str$same_value() -> !torch.bool {
|
||||||
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: return %[[TRUE]] : !torch.bool
|
||||||
|
func.func @torch.aten.ne.str$same_value() -> !torch.bool {
|
||||||
|
%str4 = torch.constant.str "4"
|
||||||
|
%str4_0 = torch.constant.str "4"
|
||||||
|
%2 = torch.aten.ne.str %str4, %str4_0 : !torch.str, !torch.str -> !torch.bool
|
||||||
|
return %2 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.len.str() -> !torch.int {
|
// CHECK-LABEL: func.func @torch.aten.len.str() -> !torch.int {
|
||||||
// CHECK: %[[INT7:.*]] = torch.constant.int 7
|
// CHECK: %[[INT7:.*]] = torch.constant.int 7
|
||||||
// CHECK: return %[[INT7]] : !torch.int
|
// CHECK: return %[[INT7]] : !torch.int
|
||||||
|
|
Loading…
Reference in New Issue