[torch] Implement stronger verifiers for non-value semantic ops (#2519)

Attempt to solve https://github.com/llvm/torch-mlir/issues/2490

Changes for Non Value Semantic Ops having the
`IsTrailingUnderscoreInplaceVariant` trait :
- AnyTorchTensorType -> Torch_NonValueTensorType
- AnyTorchOptionalTensorType -> AnyTorchOptionalNonValueTensorType
- AnyTorchListOfOptionalTensorType ->
AnyTorchListOfOptionalNonValueTensorType
- AnyTorchListOfTensorType -> AnyTorchListOfNonValueTensorType

Created three new tensor types for optional and list non value tensors.
pull/2526/head snapshot-20231022.999
Sarthak Gupta 2023-10-21 21:39:55 +05:30 committed by GitHub
parent 0acbb264d4
commit 7633619ed2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 266 additions and 248 deletions

File diff suppressed because it is too large Load Diff

View File

@ -366,6 +366,8 @@ class OptionalOf<Type type, string descr> :
def AnyTorchOptionalTensorType :
OptionalOf<AnyTorchTensorType, "Optional torch tensor type">;
def AnyTorchOptionalNonValueTensorType :
OptionalOf<Torch_NonValueTensorType, "Optional torch non value tensor type">;
def AnyTorchOptionalIntType: OptionalOf<Torch_IntType, "Optional torch int type">;
def AnyTorchOptionalFloatType: OptionalOf<Torch_FloatType, "Optional torch float type">;
def AnyTorchOptionalBoolType:
@ -390,9 +392,14 @@ def AnyTorchListOfTorchFloatType : ListOf<[Torch_FloatType], "Float list type (f
def AnyTorchListOfTorchStringType : ListOf<[Torch_StringType], "Str list type (str[])">;
def AnyTorchListOfTensorType:
ListOf<[AnyTorchTensorType], "Any int list type (Tensor[])">;
def AnyTorchListOfNonValueTensorType:
ListOf<[Torch_NonValueTensorType], "Any int list type (Tensor[])">;
def AnyTorchListOfOptionalTensorType :
ListOf<[AnyTorchOptionalTensorType],
"Any optional tensor list type (Tensor?[])">;
def AnyTorchListOfOptionalNonValueTensorType :
ListOf<[AnyTorchOptionalNonValueTensorType],
"Any optional tensor list type (Tensor?[])">;
def AnyTorchListOfOptionalIntType :
ListOf<[AnyTorchOptionalIntType],
"List of optional ints type (int?[])">;

View File

@ -52,12 +52,22 @@ TORCH_TYPE_TO_ODS_TYPE = {
"__torch__.torch.classes.quantized.LinearPackedParamsBase": "Torch_LinearParamsType",
}
TORCH_NON_VALUE_TYPE_TO_ODS_TYPE = {
"Tensor": "Torch_NonValueTensorType",
"Tensor?": "AnyTorchOptionalNonValueTensorType",
"Tensor?[]": "AnyTorchListOfOptionalNonValueTensorType",
"Tensor[]": "AnyTorchListOfNonValueTensorType",
}
def get_ods_type(type: str):
def get_ods_type(type: str, non_value: bool):
# TODO: Increase precision on dict type modeling.
if type.startswith("Dict("):
type = "Dict"
ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type)
if non_value:
ods_type = TORCH_NON_VALUE_TYPE_TO_ODS_TYPE.get(type) or TORCH_TYPE_TO_ODS_TYPE.get(type)
else:
ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type)
if ods_type is None:
raise Exception(
f"{type!r} not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!")
@ -130,6 +140,7 @@ def raw_emit_op(operator: JitOperator,
p_td("]> {")
with emitter_td.indent():
summary = f"Generated op for `{operator.unique_key}`"
is_non_value_op = "IsTrailingUnderscoreInplaceVariant" in traits
p_td(f"let summary = {emitter_td.quote(summary)};")
p_td(f"let arguments = (ins")
with emitter_td.indent():
@ -137,7 +148,7 @@ def raw_emit_op(operator: JitOperator,
p_td("Variadic<AnyTorchType>:$operands")
else:
p_td(",\n".join([
f"""{get_ods_type(arg["type"])}:${arg["name"]}"""
f"""{get_ods_type(arg["type"], is_non_value_op)}:${arg["name"]}"""
for arg in operator.arguments
]))
p_td(");")
@ -147,7 +158,7 @@ def raw_emit_op(operator: JitOperator,
p_td("Variadic<AnyTorchType>:$results")
else:
p_td(",\n".join([
f"""{get_ods_type(ret["type"])}:${ret["name"] or generic_result_name(e)}"""
f"""{get_ods_type(ret["type"], is_non_value_op)}:${ret["name"] or generic_result_name(e)}"""
for e, ret in enumerate(operator.returns)
]))
p_td(");")