mirror of https://github.com/llvm/torch-mlir
[Torch] emit aten.__contains__.str_list and add folder (#3249)
parent
9f64748f97
commit
aed2cf3351
|
@ -13626,6 +13626,31 @@ def Torch_AtenWarnOp : Torch_Op<"aten.warn", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_Aten__Contains__StrListOp : Torch_Op<"aten.__contains__.str_list", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::__contains__.str_list : (str[], str) -> (bool)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchListOfTorchStringType:$l,
|
||||||
|
Torch_StringType:$item
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_BoolType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult Aten__Contains__StrListOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void Aten__Contains__StrListOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
|
def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -239,6 +239,37 @@ m_TorchListOfConstantBools(SmallVectorImpl<bool> &bind_values) {
|
||||||
return detail::torch_list_of_constant_bools_op_binder(bind_values);
|
return detail::torch_list_of_constant_bools_op_binder(bind_values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
/// Matches the constant strs stored in a `torch.ListConstruct`.
|
||||||
|
struct torch_list_of_constant_strs_op_binder {
|
||||||
|
SmallVectorImpl<std::string> &bind_values;
|
||||||
|
|
||||||
|
/// Creates a matcher instance that binds the value to bvs if match succeeds.
|
||||||
|
torch_list_of_constant_strs_op_binder(SmallVectorImpl<std::string> &bvs)
|
||||||
|
: bind_values(bvs) {}
|
||||||
|
|
||||||
|
bool match(Operation *op) {
|
||||||
|
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
|
||||||
|
if (!listConstruct)
|
||||||
|
return false;
|
||||||
|
for (Value value : listConstruct.getElements()) {
|
||||||
|
std::string str;
|
||||||
|
if (matchPattern(value, m_TorchConstantStr(str)))
|
||||||
|
bind_values.push_back(str);
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
/// Matches the constant strs stored in a `torch.prim.ListConstruct`.
|
||||||
|
inline detail::torch_list_of_constant_strs_op_binder
|
||||||
|
m_TorchListOfConstantStrs(SmallVectorImpl<std::string> &bind_values) {
|
||||||
|
return detail::torch_list_of_constant_strs_op_binder(bind_values);
|
||||||
|
}
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
/// Matches the expected tensor and dim from `torch.aten.size.int`.
|
/// Matches the expected tensor and dim from `torch.aten.size.int`.
|
||||||
struct torch_tensor_size_int_op_binder {
|
struct torch_tensor_size_int_op_binder {
|
||||||
|
|
|
@ -2385,6 +2385,30 @@ OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Aten__Contains__StrListOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) {
|
||||||
|
StringAttr item = dyn_cast<StringAttr>(adaptor.getItem());
|
||||||
|
if (!item)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
if (auto listConstruct = getL().getDefiningOp<Torch::PrimListConstructOp>()) {
|
||||||
|
if (isListPotentiallyMutated(listConstruct))
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
llvm::SmallVector<std::string> strs;
|
||||||
|
if (matchPattern(getL(), m_TorchListOfConstantStrs(strs))) {
|
||||||
|
for (const auto &str : strs) {
|
||||||
|
if (item.getValue().str() == str)
|
||||||
|
return getI1IntegerAttr(getContext(), true);
|
||||||
|
}
|
||||||
|
return getI1IntegerAttr(getContext(), false);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenLtIntOp
|
// AtenLtIntOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -974,6 +974,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::format : (...) -> (str)")
|
emit("aten::format : (...) -> (str)")
|
||||||
emit("aten::join : (str, str[]) -> (str)")
|
emit("aten::join : (str, str[]) -> (str)")
|
||||||
emit("aten::warn : (str, int) -> ()")
|
emit("aten::warn : (str, int) -> ()")
|
||||||
|
emit("aten::__contains__.str_list : (str[], str) -> (bool)", has_folder=True)
|
||||||
|
|
||||||
# Type conversion ops.
|
# Type conversion ops.
|
||||||
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
|
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
|
||||||
|
|
|
@ -504,8 +504,8 @@ func.func @torch.aten.eq.str$different_value() -> !torch.bool {
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.eq.str$same_operand(
|
// CHECK-LABEL: func.func @torch.aten.eq.str$same_operand(
|
||||||
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
|
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
|
||||||
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true
|
// CHECK-NEXT: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
// CHECK-NEXT: return %[[F]] : !torch.bool
|
// CHECK-NEXT: return %[[TRUE]] : !torch.bool
|
||||||
func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool {
|
func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool {
|
||||||
%0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
|
%0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
|
||||||
return %0 : !torch.bool
|
return %0 : !torch.bool
|
||||||
|
@ -522,8 +522,8 @@ func.func @torch.aten.eq.str$same_value() -> !torch.bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
// CHECK-LABEL: func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
||||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool true
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
// CHECK: return %[[FALSE]] : !torch.bool
|
// CHECK: return %[[TRUE]] : !torch.bool
|
||||||
func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
||||||
%str4 = torch.constant.str "4"
|
%str4 = torch.constant.str "4"
|
||||||
%str5 = torch.constant.str "5"
|
%str5 = torch.constant.str "5"
|
||||||
|
@ -533,16 +533,16 @@ func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.ne.str$same_operand(
|
// CHECK-LABEL: func.func @torch.aten.ne.str$same_operand(
|
||||||
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
|
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
|
||||||
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool false
|
// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
// CHECK-NEXT: return %[[F]] : !torch.bool
|
// CHECK-NEXT: return %[[FALSE]] : !torch.bool
|
||||||
func.func @torch.aten.ne.str$same_operand(%arg0: !torch.str) -> !torch.bool {
|
func.func @torch.aten.ne.str$same_operand(%arg0: !torch.str) -> !torch.bool {
|
||||||
%0 = torch.aten.ne.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
|
%0 = torch.aten.ne.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
|
||||||
return %0 : !torch.bool
|
return %0 : !torch.bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.ne.str$same_value() -> !torch.bool {
|
// CHECK-LABEL: func.func @torch.aten.ne.str$same_value() -> !torch.bool {
|
||||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool false
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
// CHECK: return %[[TRUE]] : !torch.bool
|
// CHECK: return %[[FALSE]] : !torch.bool
|
||||||
func.func @torch.aten.ne.str$same_value() -> !torch.bool {
|
func.func @torch.aten.ne.str$same_value() -> !torch.bool {
|
||||||
%str4 = torch.constant.str "4"
|
%str4 = torch.constant.str "4"
|
||||||
%str4_0 = torch.constant.str "4"
|
%str4_0 = torch.constant.str "4"
|
||||||
|
@ -568,6 +568,30 @@ func.func @torch.aten.len.str$empty() -> !torch.int {
|
||||||
return %2 : !torch.int
|
return %2 : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: return %[[FALSE]] : !torch.bool
|
||||||
|
func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
|
||||||
|
%str = torch.constant.str "c"
|
||||||
|
%str_0 = torch.constant.str "b"
|
||||||
|
%str_1 = torch.constant.str "a"
|
||||||
|
%1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
|
||||||
|
%2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
|
||||||
|
return %2 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
|
||||||
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: return %[[TRUE]] : !torch.bool
|
||||||
|
func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
|
||||||
|
%str = torch.constant.str "aa"
|
||||||
|
%str_0 = torch.constant.str "aa"
|
||||||
|
%str_1 = torch.constant.str "ccc"
|
||||||
|
%1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
|
||||||
|
%2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
|
||||||
|
return %2 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.__not__
|
// CHECK-LABEL: func.func @torch.aten.__not__
|
||||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
// CHECK: return %[[TRUE]] : !torch.bool
|
// CHECK: return %[[TRUE]] : !torch.bool
|
||||||
|
|
Loading…
Reference in New Issue