mirror of https://github.com/llvm/torch-mlir
[torch] Implement stronger verifiers for non-value semantic ops (#2519)
Attempt to solve https://github.com/llvm/torch-mlir/issues/2490 Changes for Non Value Semantic Ops having the `IsTrailingUnderscoreInplaceVariant` trait : - AnyTorchTensorType -> Torch_NonValueTensorType - AnyTorchOptionalTensorType -> AnyTorchOptionalNonValueTensorType - AnyTorchListOfOptionalTensorType -> AnyTorchListOfOptionalNonValueTensorType - AnyTorchListOfTensorType -> AnyTorchListOfNonValueTensorType Created three new tensor types for optional and list non value tensors.pull/2526/head snapshot-20231022.999
parent
0acbb264d4
commit
7633619ed2
File diff suppressed because it is too large
Load Diff
|
@ -366,6 +366,8 @@ class OptionalOf<Type type, string descr> :
|
|||
|
||||
def AnyTorchOptionalTensorType :
|
||||
OptionalOf<AnyTorchTensorType, "Optional torch tensor type">;
|
||||
def AnyTorchOptionalNonValueTensorType :
|
||||
OptionalOf<Torch_NonValueTensorType, "Optional torch non value tensor type">;
|
||||
def AnyTorchOptionalIntType: OptionalOf<Torch_IntType, "Optional torch int type">;
|
||||
def AnyTorchOptionalFloatType: OptionalOf<Torch_FloatType, "Optional torch float type">;
|
||||
def AnyTorchOptionalBoolType:
|
||||
|
@ -390,9 +392,14 @@ def AnyTorchListOfTorchFloatType : ListOf<[Torch_FloatType], "Float list type (f
|
|||
def AnyTorchListOfTorchStringType : ListOf<[Torch_StringType], "Str list type (str[])">;
|
||||
def AnyTorchListOfTensorType:
|
||||
ListOf<[AnyTorchTensorType], "Any int list type (Tensor[])">;
|
||||
def AnyTorchListOfNonValueTensorType:
|
||||
ListOf<[Torch_NonValueTensorType], "Any int list type (Tensor[])">;
|
||||
def AnyTorchListOfOptionalTensorType :
|
||||
ListOf<[AnyTorchOptionalTensorType],
|
||||
"Any optional tensor list type (Tensor?[])">;
|
||||
def AnyTorchListOfOptionalNonValueTensorType :
|
||||
ListOf<[AnyTorchOptionalNonValueTensorType],
|
||||
"Any optional tensor list type (Tensor?[])">;
|
||||
def AnyTorchListOfOptionalIntType :
|
||||
ListOf<[AnyTorchOptionalIntType],
|
||||
"List of optional ints type (int?[])">;
|
||||
|
|
|
@ -52,12 +52,22 @@ TORCH_TYPE_TO_ODS_TYPE = {
|
|||
"__torch__.torch.classes.quantized.LinearPackedParamsBase": "Torch_LinearParamsType",
|
||||
}
|
||||
|
||||
TORCH_NON_VALUE_TYPE_TO_ODS_TYPE = {
|
||||
"Tensor": "Torch_NonValueTensorType",
|
||||
"Tensor?": "AnyTorchOptionalNonValueTensorType",
|
||||
"Tensor?[]": "AnyTorchListOfOptionalNonValueTensorType",
|
||||
"Tensor[]": "AnyTorchListOfNonValueTensorType",
|
||||
}
|
||||
|
||||
def get_ods_type(type: str):
|
||||
|
||||
def get_ods_type(type: str, non_value: bool):
|
||||
# TODO: Increase precision on dict type modeling.
|
||||
if type.startswith("Dict("):
|
||||
type = "Dict"
|
||||
ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type)
|
||||
if non_value:
|
||||
ods_type = TORCH_NON_VALUE_TYPE_TO_ODS_TYPE.get(type) or TORCH_TYPE_TO_ODS_TYPE.get(type)
|
||||
else:
|
||||
ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type)
|
||||
if ods_type is None:
|
||||
raise Exception(
|
||||
f"{type!r} not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!")
|
||||
|
@ -130,6 +140,7 @@ def raw_emit_op(operator: JitOperator,
|
|||
p_td("]> {")
|
||||
with emitter_td.indent():
|
||||
summary = f"Generated op for `{operator.unique_key}`"
|
||||
is_non_value_op = "IsTrailingUnderscoreInplaceVariant" in traits
|
||||
p_td(f"let summary = {emitter_td.quote(summary)};")
|
||||
p_td(f"let arguments = (ins")
|
||||
with emitter_td.indent():
|
||||
|
@ -137,7 +148,7 @@ def raw_emit_op(operator: JitOperator,
|
|||
p_td("Variadic<AnyTorchType>:$operands")
|
||||
else:
|
||||
p_td(",\n".join([
|
||||
f"""{get_ods_type(arg["type"])}:${arg["name"]}"""
|
||||
f"""{get_ods_type(arg["type"], is_non_value_op)}:${arg["name"]}"""
|
||||
for arg in operator.arguments
|
||||
]))
|
||||
p_td(");")
|
||||
|
@ -147,7 +158,7 @@ def raw_emit_op(operator: JitOperator,
|
|||
p_td("Variadic<AnyTorchType>:$results")
|
||||
else:
|
||||
p_td(",\n".join([
|
||||
f"""{get_ods_type(ret["type"])}:${ret["name"] or generic_result_name(e)}"""
|
||||
f"""{get_ods_type(ret["type"], is_non_value_op)}:${ret["name"] or generic_result_name(e)}"""
|
||||
for e, ret in enumerate(operator.returns)
|
||||
]))
|
||||
p_td(");")
|
||||
|
|
Loading…
Reference in New Issue