mirror of https://github.com/llvm/torch-mlir
Expand pytype coverage for torch_signature_ods_gen.py
parent
0b7c443256
commit
959c0a79cb
|
@ -6,3 +6,4 @@ build-mlir
|
|||
install-mlir
|
||||
__pycache__
|
||||
|
||||
.pytype
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"""Queries the pytorch op registry and generates ODS and CC sources for the ops.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, TextIO, Sequence, Tuple
|
||||
from typing import Any, Dict, List, Optional, TextIO, Sequence, Tuple, Union
|
||||
|
||||
import argparse
|
||||
from contextlib import contextmanager
|
||||
|
@ -16,10 +16,22 @@ import textwrap
|
|||
import traceback
|
||||
|
||||
# Note that this utility exists only in the c-extension.
|
||||
from _torch_mlir import get_registered_ops
|
||||
from _torch_mlir import get_registered_ops # pytype: disable=import-error
|
||||
|
||||
# A Dist[str, _] mapping 'aten::OpName' to:
|
||||
# - bool (e.g. {'is_mutable': False} )
|
||||
# - Tuple[str] (e.g. {'name': ('aten::size', 'int')} )
|
||||
# - SIGLIST_TYPE (e.g. {'arguments': [...], 'returns': [...]} )
|
||||
REG_OP_TYPE = Dict[str, Union[bool, Tuple[str], "SIGLIST_TYPE"]]
|
||||
# A List[Dict[str, _]] mapping attribute names to:
|
||||
# - str (e.g. {'name': 'dim'} )
|
||||
# - int (e.g. {'N': 1} )
|
||||
# - Dict[str, List[str]]
|
||||
# (e.g. {'alias_info': {'before': ['alias::a'], 'after': ['alias::a']}} )
|
||||
SIGLIST_TYPE = List[Dict[str, Union[str, int, Dict[str, List[str]]]]]
|
||||
|
||||
|
||||
def _create_argparse():
|
||||
def _create_argparse() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="generate_ods")
|
||||
parser.add_argument("--ods_td_file",
|
||||
required=True,
|
||||
|
@ -32,7 +44,7 @@ def _create_argparse():
|
|||
return parser
|
||||
|
||||
|
||||
def main(args):
|
||||
def main(args: argparse.Namespace):
|
||||
reg_ops = _load_ops_as_dict()
|
||||
if args.debug_op_reg_file:
|
||||
with open(args.debug_op_reg_file, "w") as debug_ops_file:
|
||||
|
@ -97,8 +109,8 @@ def generate_ops(g: "OpGenerator"):
|
|||
# These do return as None but are not coded optional in the registry :(
|
||||
override_return_types=["Tensor?", "Tensor?", "Tensor?"])
|
||||
|
||||
g.ordinary_immutable_op("aten::_log_softmax(Tensor,int,bool)",
|
||||
"LogSoftmaxOp", "log_softmax")
|
||||
g.ordinary_immutable_op("aten::_log_softmax(Tensor,int,bool)", "LogSoftmaxOp",
|
||||
"log_softmax")
|
||||
g.ordinary_immutable_op(
|
||||
"aten::_log_softmax_backward_data(Tensor,Tensor,int,Tensor)",
|
||||
"LogSoftmaxBackwardDataOp", "log_softmax_backward_data")
|
||||
|
@ -130,7 +142,7 @@ def generate_ops(g: "OpGenerator"):
|
|||
drop_arg_indices=[2])
|
||||
|
||||
|
||||
def dump_registered_ops(outfile, reg_ops_dict):
|
||||
def dump_registered_ops(outfile: TextIO, reg_ops_dict: Dict[str, REG_OP_TYPE]):
|
||||
for k in sorted(reg_ops_dict.keys()):
|
||||
attr_dict = reg_ops_dict[k]
|
||||
outfile.write(f"OP '{k}':\n")
|
||||
|
@ -141,24 +153,23 @@ def dump_registered_ops(outfile, reg_ops_dict):
|
|||
|
||||
class OpGenerator:
|
||||
|
||||
def __init__(self, reg_ops, ods_emitter, impl_emitter):
|
||||
def __init__(self, reg_ops: Dict[str, REG_OP_TYPE], ods_emitter: "OdsEmitter",
|
||||
impl_emitter: "CCImplEmitter"):
|
||||
super().__init__()
|
||||
self.reg_ops = reg_ops
|
||||
self.ods_emitter = ods_emitter
|
||||
self.impl_emitter = impl_emitter
|
||||
|
||||
def print_banner(self, text):
|
||||
def print_banner(self, text: str):
|
||||
seperator = f"// {'-' * 77}"
|
||||
for em in (self.ods_emitter, self.impl_emitter):
|
||||
em.print(
|
||||
"// -----------------------------------------------------------------------------"
|
||||
)
|
||||
em.print(seperator)
|
||||
em.print(f"// {text}")
|
||||
em.print(
|
||||
"// -----------------------------------------------------------------------------"
|
||||
)
|
||||
em.print(seperator)
|
||||
em.print("")
|
||||
|
||||
def define_op(self, kernel_sig, ods_name, op_name, **kwargs):
|
||||
def define_op(self, kernel_sig: str, ods_name: str, op_name: str,
|
||||
**kwargs) -> "InflightOpDef":
|
||||
return InflightOpDef(self,
|
||||
kernel_sig=kernel_sig,
|
||||
ods_name=ods_name,
|
||||
|
@ -166,11 +177,11 @@ class OpGenerator:
|
|||
**kwargs)
|
||||
|
||||
def ordinary_binary_op(self,
|
||||
kernel_sig,
|
||||
ods_name,
|
||||
op_name,
|
||||
promote_trailing_out_tensor=True,
|
||||
traits=(),
|
||||
kernel_sig: str,
|
||||
ods_name: str,
|
||||
op_name: str,
|
||||
promote_trailing_out_tensor: bool = True,
|
||||
traits: Sequence[str] = (),
|
||||
**kwargs):
|
||||
""""Binary"-ops. These ops typically have:
|
||||
- '.Tensor' variant where the second arg is a Tensor
|
||||
|
@ -198,7 +209,8 @@ class OpGenerator:
|
|||
promote_trailing_out_tensor=promote_trailing_out_tensor,
|
||||
traits=list(traits) + ["NoSideEffect"],
|
||||
**kwargs)
|
||||
opdef.arg_transforms(type_transforms={
|
||||
opdef.arg_transforms(
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
"Tensor:1": "AnyTorchImmutableTensor",
|
||||
"Scalar:1": "AnyTorchImmutableTensor",
|
||||
|
@ -207,21 +219,24 @@ class OpGenerator:
|
|||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
":1": ["kImmutableTensor", "kPromoteScalar"],
|
||||
})
|
||||
opdef.return_transforms(type_transforms={
|
||||
},
|
||||
)
|
||||
opdef.return_transforms(
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
},
|
||||
)
|
||||
opdef.emit()
|
||||
|
||||
def ordinary_immutable_op(self,
|
||||
kernel_sig,
|
||||
ods_name,
|
||||
op_name,
|
||||
promote_trailing_out_tensor=True,
|
||||
traits=(),
|
||||
kernel_sig: str,
|
||||
ods_name: str,
|
||||
op_name: str,
|
||||
promote_trailing_out_tensor: bool = True,
|
||||
traits: Sequence[str] = (),
|
||||
**kwargs):
|
||||
""""An ordinary immutable-tensor based op."""
|
||||
opdef = self.define_op(
|
||||
|
@ -231,7 +246,8 @@ class OpGenerator:
|
|||
promote_trailing_out_tensor=promote_trailing_out_tensor,
|
||||
traits=list(traits) + ["NoSideEffect"],
|
||||
**kwargs)
|
||||
opdef.transforms(type_transforms={
|
||||
opdef.transforms(
|
||||
type_transforms={
|
||||
"Tensor": "AnyTorchImmutableTensor",
|
||||
"Tensor?": "AnyTorchOptionalImmutableTensor",
|
||||
"int": "AnyTorchIntType",
|
||||
|
@ -242,10 +258,12 @@ class OpGenerator:
|
|||
flag_transforms={
|
||||
"Tensor": ["kImmutableTensor"],
|
||||
"Tensor?": ["kImmutableTensor"],
|
||||
})
|
||||
},
|
||||
)
|
||||
opdef.emit()
|
||||
|
||||
def ordinary_inplace_op(self, kernel_sig, ods_name, op_name, **kwargs):
|
||||
def ordinary_inplace_op(self, kernel_sig: str, ods_name: str, op_name: str,
|
||||
**kwargs):
|
||||
"""In-place ops (ending in '_').
|
||||
|
||||
These ops take a mutable first argument and then standard immutable
|
||||
|
@ -256,7 +274,8 @@ class OpGenerator:
|
|||
ods_name=ods_name,
|
||||
op_name=op_name,
|
||||
**kwargs)
|
||||
opdef.arg_transforms(type_transforms={
|
||||
opdef.arg_transforms(
|
||||
type_transforms={
|
||||
":0": "AnyTorchMutableTensor",
|
||||
"Tensor": "AnyTorchImmutableTensor",
|
||||
"Tensor?": "AnyTorchOptionalImmutableTensor",
|
||||
|
@ -269,24 +288,26 @@ class OpGenerator:
|
|||
":0": [],
|
||||
"Tensor": ["kImmutableTensor"],
|
||||
"Tensor?": ["kImmutableTensor"],
|
||||
})
|
||||
},
|
||||
)
|
||||
opdef.return_transforms(
|
||||
type_transforms={
|
||||
":0": "DROP_UNUSED", # Ignored because we clear the outs below.
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kDropReturnAndAliasArg0"],
|
||||
})
|
||||
},
|
||||
)
|
||||
opdef.map_signatures()
|
||||
opdef.ods_outs = [] # Clear the computed outs.
|
||||
opdef.emit()
|
||||
|
||||
def ordinary_unary_op(self,
|
||||
kernel_sig,
|
||||
ods_name,
|
||||
op_name,
|
||||
promote_trailing_out_tensor=True,
|
||||
traits=(),
|
||||
kernel_sig: str,
|
||||
ods_name: str,
|
||||
op_name: str,
|
||||
promote_trailing_out_tensor: bool = True,
|
||||
traits: Sequence[str] = (),
|
||||
**kwargs):
|
||||
"""Unary ops.
|
||||
|
||||
|
@ -300,21 +321,25 @@ class OpGenerator:
|
|||
promote_trailing_out_tensor=promote_trailing_out_tensor,
|
||||
traits=list(traits) + ["NoSideEffect"],
|
||||
**kwargs)
|
||||
opdef.arg_transforms(type_transforms={
|
||||
opdef.arg_transforms(
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
opdef.return_transforms(type_transforms={
|
||||
},
|
||||
)
|
||||
opdef.return_transforms(
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
},
|
||||
)
|
||||
opdef.emit()
|
||||
|
||||
def get_reg_record(self, kernel_sig):
|
||||
def get_reg_record(self, kernel_sig: str) -> REG_OP_TYPE:
|
||||
"""Gets the op-dict for a given registered op name.
|
||||
|
||||
Args:
|
||||
|
@ -336,7 +361,7 @@ class OpGenerator:
|
|||
|
||||
def _map_sigtypes(
|
||||
self,
|
||||
siglist: List[Dict],
|
||||
siglist: SIGLIST_TYPE,
|
||||
type_transforms: Dict[str, str],
|
||||
flag_transforms: Dict[str, List[str]],
|
||||
drop_indices: Sequence[int] = (),
|
||||
|
@ -405,16 +430,16 @@ class InflightOpDef:
|
|||
|
||||
def __init__(self,
|
||||
g: OpGenerator,
|
||||
kernel_sig,
|
||||
ods_name,
|
||||
op_name,
|
||||
traits=(),
|
||||
alias_kernel_names=(),
|
||||
promote_trailing_out_tensor=False,
|
||||
override_arg_types=None,
|
||||
override_return_types=None,
|
||||
drop_arg_indices=(),
|
||||
drop_return_indices=()):
|
||||
kernel_sig: str,
|
||||
ods_name: str,
|
||||
op_name: str,
|
||||
traits: Sequence[str] = (),
|
||||
alias_kernel_names: Sequence[str] = (),
|
||||
promote_trailing_out_tensor: bool = False,
|
||||
override_arg_types: Sequence[str] = None,
|
||||
override_return_types: Sequence[str] = None,
|
||||
drop_arg_indices: Sequence[int] = (),
|
||||
drop_return_indices: Sequence[int] = ()):
|
||||
super().__init__()
|
||||
self.g = g
|
||||
self.kernel_sig = kernel_sig
|
||||
|
@ -435,7 +460,7 @@ class InflightOpDef:
|
|||
self.arg_type_transforms = dict()
|
||||
self.arg_flag_transforms = dict()
|
||||
self.return_type_transforms = dict()
|
||||
self.return_flag_trasforms = dict()
|
||||
self.return_flag_transforms = dict()
|
||||
|
||||
# Signature mapping.
|
||||
self._sigs_mapped = False
|
||||
|
@ -450,14 +475,20 @@ class InflightOpDef:
|
|||
for line in traceback.format_list(self._traceback):
|
||||
sys.stderr.write(line)
|
||||
|
||||
def transforms(self, type_transforms=None, flag_transforms=None):
|
||||
def transforms(
|
||||
self,
|
||||
type_transforms: Dict[str, str] = None,
|
||||
flag_transforms: Dict[str, List[str]] = None) -> "InflightOpDef":
|
||||
self.arg_transforms(type_transforms=type_transforms,
|
||||
flag_transforms=flag_transforms)
|
||||
self.return_transforms(type_transforms=type_transforms,
|
||||
flag_transforms=flag_transforms)
|
||||
return self
|
||||
|
||||
def arg_transforms(self, type_transforms=None, flag_transforms=None):
|
||||
def arg_transforms(
|
||||
self,
|
||||
type_transforms: Dict[str, str] = None,
|
||||
flag_transforms: Dict[str, List[str]] = None) -> "InflightOpDef":
|
||||
"""Adds arg type and flag transforms dicts."""
|
||||
if type_transforms:
|
||||
self.arg_type_transforms.update(type_transforms)
|
||||
|
@ -465,15 +496,18 @@ class InflightOpDef:
|
|||
self.arg_flag_transforms.update(flag_transforms)
|
||||
return self
|
||||
|
||||
def return_transforms(self, type_transforms=None, flag_transforms=None):
|
||||
def return_transforms(
|
||||
self,
|
||||
type_transforms: Dict[str, str] = None,
|
||||
flag_transforms: Dict[str, List[str]] = None) -> "InflightOpDef":
|
||||
"""Adds return type and flag transform dicts."""
|
||||
if type_transforms:
|
||||
self.return_type_transforms.update(type_transforms)
|
||||
if flag_transforms:
|
||||
self.return_flag_trasforms.update(flag_transforms)
|
||||
self.return_flag_transforms.update(flag_transforms)
|
||||
return self
|
||||
|
||||
def map_signatures(self):
|
||||
def map_signatures(self) -> "InflightOpDef":
|
||||
assert not self._sigs_mapped, "Signatures already mapped"
|
||||
self._sigs_mapped = True
|
||||
self.ods_ins, self.arg_type_flags = self.g._map_sigtypes(
|
||||
|
@ -485,7 +519,7 @@ class InflightOpDef:
|
|||
self.ods_outs, self.return_type_flags = self.g._map_sigtypes(
|
||||
self.reg_record["returns"],
|
||||
type_transforms=self.return_type_transforms,
|
||||
flag_transforms=self.return_flag_trasforms,
|
||||
flag_transforms=self.return_flag_transforms,
|
||||
override_types=self.override_return_types,
|
||||
drop_indices=self.drop_return_indices)
|
||||
return self
|
||||
|
@ -519,23 +553,23 @@ class EmitterBase:
|
|||
self.indent_level = 0
|
||||
|
||||
@contextmanager
|
||||
def indent(self, level=1):
|
||||
def indent(self, level: int = 1):
|
||||
self.indent_level += level
|
||||
yield
|
||||
self.indent_level -= level
|
||||
assert self.indent_level >= 0, "Unbalanced indentation"
|
||||
|
||||
def print(self, s, *, end="\n", indent=True):
|
||||
def print(self, s: str, *, end: str = "\n", indent: bool = True):
|
||||
if indent and self.indent_level:
|
||||
self.out.write(self._INDENT * self.indent_level)
|
||||
self.out.write(s)
|
||||
self.out.write(end)
|
||||
|
||||
def quote(self, s: str):
|
||||
def quote(self, s: str) -> str:
|
||||
s = s.replace(r'"', r'\\"')
|
||||
return f'"{s}"'
|
||||
|
||||
def quote_multiline_docstring(self, s: str, indent_level: int = 0):
|
||||
def quote_multiline_docstring(self, s: str, indent_level: int = 0) -> str:
|
||||
# TODO: Possibly find a python module to markdown the docstring for better
|
||||
# document generation.
|
||||
# Unlikely to contain the delimitter and since just a docstring, be safe.
|
||||
|
@ -555,7 +589,7 @@ class OdsEmitter(EmitterBase):
|
|||
def emit_opdef(self,
|
||||
ods_def_name: str,
|
||||
mnemonic: str,
|
||||
reg_record: Dict,
|
||||
reg_record: REG_OP_TYPE,
|
||||
ods_ins: List[Tuple[str, str]],
|
||||
ods_outs: List[Tuple[str, str]],
|
||||
traits: Sequence[str] = (),
|
||||
|
@ -590,7 +624,7 @@ class OdsEmitter(EmitterBase):
|
|||
# Def last-line.
|
||||
self.print("}\n")
|
||||
|
||||
def _emit_dag_list_body(self, items):
|
||||
def _emit_dag_list_body(self, items: List[Tuple[str, str]]):
|
||||
"""Emits a dag of (name, type) pairs."""
|
||||
for index, (ods_name, ods_type) in enumerate(items):
|
||||
is_last = index == len(items) - 1
|
||||
|
@ -602,10 +636,10 @@ class CCImplEmitter(EmitterBase):
|
|||
|
||||
def emit_kernel_methods(self,
|
||||
ods_def_name: str,
|
||||
reg_record,
|
||||
reg_record: REG_OP_TYPE,
|
||||
arg_type_flags: List[Tuple[str, List[Tuple[str]]]],
|
||||
return_type_flags: List[Tuple[str, List[Tuple[str]]]],
|
||||
promote_trailing_out_tensor=False,
|
||||
promote_trailing_out_tensor: bool = False,
|
||||
alias_kernel_names: Sequence[str] = ()):
|
||||
# getTorchKernelMetadata() method.
|
||||
self.print(
|
||||
|
@ -649,14 +683,15 @@ class CCImplEmitter(EmitterBase):
|
|||
self.print("return metadata;")
|
||||
self.print("}\n")
|
||||
|
||||
def _format_cpp_str_initlist(self, strings):
|
||||
def _format_cpp_str_initlist(self, strings: Sequence[str]) -> str:
|
||||
quoted = [self.quote(s) for s in strings]
|
||||
joined = ", ".join(quoted)
|
||||
return "{" + joined + "}"
|
||||
|
||||
def _format_cpp_kvc_initlist(self, const_name_lists):
|
||||
def _format_cpp_kvc_initlist(self,
|
||||
const_name_lists: List[List[Tuple[str]]]) -> str:
|
||||
|
||||
def or_flags(flag_names):
|
||||
def or_flags(flag_names: List[Tuple[str]]):
|
||||
if not flag_names:
|
||||
return "KVC::kNone"
|
||||
return "|".join([f"KVC::{n}" for n in flag_names])
|
||||
|
@ -666,18 +701,18 @@ class CCImplEmitter(EmitterBase):
|
|||
return "{" + joined + "}"
|
||||
|
||||
|
||||
def snakecase_to_camelcase(ident: str):
|
||||
def snakecase_to_camelcase(ident: str) -> str:
|
||||
return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
|
||||
|
||||
|
||||
def _first_non_none(*args):
|
||||
def _first_non_none(*args) -> Union[None, Any]:
|
||||
for arg in args:
|
||||
if arg is not None:
|
||||
return arg
|
||||
return None
|
||||
|
||||
|
||||
def _load_ops_as_dict():
|
||||
def _load_ops_as_dict() -> Dict[str, REG_OP_TYPE]:
|
||||
# Returns a list of dicts, each with a name that is a tuple of the form:
|
||||
# (kernel_signature, variant)
|
||||
# The kernel signature is a reified form of the argument type signature
|
||||
|
@ -692,8 +727,10 @@ def _load_ops_as_dict():
|
|||
return {reify_signature(reg_op): reg_op for reg_op in reg_ops_list}
|
||||
|
||||
|
||||
def _get_main_module_name():
|
||||
def _get_main_module_name() -> str:
|
||||
# pytype: disable=attribute-error
|
||||
return sys.modules["__main__"].__loader__.name
|
||||
# pytype: enable=attribute-error
|
||||
|
||||
|
||||
ODS_BANNER = "\n".join([
|
||||
|
|
Loading…
Reference in New Issue