Expand pytype coverage for torch_signature_ods_gen.py

pull/130/head
meadowlark@google.com 2020-11-24 12:42:56 -06:00 committed by Stella Laurenzo
parent 0b7c443256
commit 959c0a79cb
2 changed files with 159 additions and 121 deletions

1
.gitignore vendored
View File

@ -6,3 +6,4 @@ build-mlir
install-mlir
__pycache__
.pytype

View File

@ -4,7 +4,7 @@
"""Queries the pytorch op registry and generates ODS and CC sources for the ops.
"""
from typing import Dict, List, Optional, TextIO, Sequence, Tuple
from typing import Any, Dict, List, Optional, TextIO, Sequence, Tuple, Union
import argparse
from contextlib import contextmanager
@ -16,10 +16,22 @@ import textwrap
import traceback
# Note that this utility exists only in the c-extension.
from _torch_mlir import get_registered_ops
from _torch_mlir import get_registered_ops # pytype: disable=import-error
# A Dist[str, _] mapping 'aten::OpName' to:
# - bool (e.g. {'is_mutable': False} )
# - Tuple[str] (e.g. {'name': ('aten::size', 'int')} )
# - SIGLIST_TYPE (e.g. {'arguments': [...], 'returns': [...]} )
REG_OP_TYPE = Dict[str, Union[bool, Tuple[str], "SIGLIST_TYPE"]]
# A List[Dict[str, _]] mapping attribute names to:
# - str (e.g. {'name': 'dim'} )
# - int (e.g. {'N': 1} )
# - Dict[str, List[str]]
# (e.g. {'alias_info': {'before': ['alias::a'], 'after': ['alias::a']}} )
SIGLIST_TYPE = List[Dict[str, Union[str, int, Dict[str, List[str]]]]]
def _create_argparse():
def _create_argparse() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(prog="generate_ods")
parser.add_argument("--ods_td_file",
required=True,
@ -32,7 +44,7 @@ def _create_argparse():
return parser
def main(args):
def main(args: argparse.Namespace):
reg_ops = _load_ops_as_dict()
if args.debug_op_reg_file:
with open(args.debug_op_reg_file, "w") as debug_ops_file:
@ -97,8 +109,8 @@ def generate_ops(g: "OpGenerator"):
# These do return as None but are not coded optional in the registry :(
override_return_types=["Tensor?", "Tensor?", "Tensor?"])
g.ordinary_immutable_op("aten::_log_softmax(Tensor,int,bool)",
"LogSoftmaxOp", "log_softmax")
g.ordinary_immutable_op("aten::_log_softmax(Tensor,int,bool)", "LogSoftmaxOp",
"log_softmax")
g.ordinary_immutable_op(
"aten::_log_softmax_backward_data(Tensor,Tensor,int,Tensor)",
"LogSoftmaxBackwardDataOp", "log_softmax_backward_data")
@ -130,7 +142,7 @@ def generate_ops(g: "OpGenerator"):
drop_arg_indices=[2])
def dump_registered_ops(outfile, reg_ops_dict):
def dump_registered_ops(outfile: TextIO, reg_ops_dict: Dict[str, REG_OP_TYPE]):
for k in sorted(reg_ops_dict.keys()):
attr_dict = reg_ops_dict[k]
outfile.write(f"OP '{k}':\n")
@ -141,24 +153,23 @@ def dump_registered_ops(outfile, reg_ops_dict):
class OpGenerator:
def __init__(self, reg_ops, ods_emitter, impl_emitter):
def __init__(self, reg_ops: Dict[str, REG_OP_TYPE], ods_emitter: "OdsEmitter",
impl_emitter: "CCImplEmitter"):
super().__init__()
self.reg_ops = reg_ops
self.ods_emitter = ods_emitter
self.impl_emitter = impl_emitter
def print_banner(self, text):
def print_banner(self, text: str):
seperator = f"// {'-' * 77}"
for em in (self.ods_emitter, self.impl_emitter):
em.print(
"// -----------------------------------------------------------------------------"
)
em.print(seperator)
em.print(f"// {text}")
em.print(
"// -----------------------------------------------------------------------------"
)
em.print(seperator)
em.print("")
def define_op(self, kernel_sig, ods_name, op_name, **kwargs):
def define_op(self, kernel_sig: str, ods_name: str, op_name: str,
**kwargs) -> "InflightOpDef":
return InflightOpDef(self,
kernel_sig=kernel_sig,
ods_name=ods_name,
@ -166,11 +177,11 @@ class OpGenerator:
**kwargs)
def ordinary_binary_op(self,
kernel_sig,
ods_name,
op_name,
promote_trailing_out_tensor=True,
traits=(),
kernel_sig: str,
ods_name: str,
op_name: str,
promote_trailing_out_tensor: bool = True,
traits: Sequence[str] = (),
**kwargs):
""""Binary"-ops. These ops typically have:
- '.Tensor' variant where the second arg is a Tensor
@ -198,7 +209,8 @@ class OpGenerator:
promote_trailing_out_tensor=promote_trailing_out_tensor,
traits=list(traits) + ["NoSideEffect"],
**kwargs)
opdef.arg_transforms(type_transforms={
opdef.arg_transforms(
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
"Tensor:1": "AnyTorchImmutableTensor",
"Scalar:1": "AnyTorchImmutableTensor",
@ -207,21 +219,24 @@ class OpGenerator:
flag_transforms={
":0": ["kImmutableTensor"],
":1": ["kImmutableTensor", "kPromoteScalar"],
})
opdef.return_transforms(type_transforms={
},
)
opdef.return_transforms(
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
},
)
opdef.emit()
def ordinary_immutable_op(self,
kernel_sig,
ods_name,
op_name,
promote_trailing_out_tensor=True,
traits=(),
kernel_sig: str,
ods_name: str,
op_name: str,
promote_trailing_out_tensor: bool = True,
traits: Sequence[str] = (),
**kwargs):
""""An ordinary immutable-tensor based op."""
opdef = self.define_op(
@ -231,7 +246,8 @@ class OpGenerator:
promote_trailing_out_tensor=promote_trailing_out_tensor,
traits=list(traits) + ["NoSideEffect"],
**kwargs)
opdef.transforms(type_transforms={
opdef.transforms(
type_transforms={
"Tensor": "AnyTorchImmutableTensor",
"Tensor?": "AnyTorchOptionalImmutableTensor",
"int": "AnyTorchIntType",
@ -242,10 +258,12 @@ class OpGenerator:
flag_transforms={
"Tensor": ["kImmutableTensor"],
"Tensor?": ["kImmutableTensor"],
})
},
)
opdef.emit()
def ordinary_inplace_op(self, kernel_sig, ods_name, op_name, **kwargs):
def ordinary_inplace_op(self, kernel_sig: str, ods_name: str, op_name: str,
**kwargs):
"""In-place ops (ending in '_').
These ops take a mutable first argument and then standard immutable
@ -256,7 +274,8 @@ class OpGenerator:
ods_name=ods_name,
op_name=op_name,
**kwargs)
opdef.arg_transforms(type_transforms={
opdef.arg_transforms(
type_transforms={
":0": "AnyTorchMutableTensor",
"Tensor": "AnyTorchImmutableTensor",
"Tensor?": "AnyTorchOptionalImmutableTensor",
@ -269,24 +288,26 @@ class OpGenerator:
":0": [],
"Tensor": ["kImmutableTensor"],
"Tensor?": ["kImmutableTensor"],
})
},
)
opdef.return_transforms(
type_transforms={
":0": "DROP_UNUSED", # Ignored because we clear the outs below.
},
flag_transforms={
":0": ["kDropReturnAndAliasArg0"],
})
},
)
opdef.map_signatures()
opdef.ods_outs = [] # Clear the computed outs.
opdef.emit()
def ordinary_unary_op(self,
kernel_sig,
ods_name,
op_name,
promote_trailing_out_tensor=True,
traits=(),
kernel_sig: str,
ods_name: str,
op_name: str,
promote_trailing_out_tensor: bool = True,
traits: Sequence[str] = (),
**kwargs):
"""Unary ops.
@ -300,21 +321,25 @@ class OpGenerator:
promote_trailing_out_tensor=promote_trailing_out_tensor,
traits=list(traits) + ["NoSideEffect"],
**kwargs)
opdef.arg_transforms(type_transforms={
opdef.arg_transforms(
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
opdef.return_transforms(type_transforms={
},
)
opdef.return_transforms(
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
},
)
opdef.emit()
def get_reg_record(self, kernel_sig):
def get_reg_record(self, kernel_sig: str) -> REG_OP_TYPE:
"""Gets the op-dict for a given registered op name.
Args:
@ -336,7 +361,7 @@ class OpGenerator:
def _map_sigtypes(
self,
siglist: List[Dict],
siglist: SIGLIST_TYPE,
type_transforms: Dict[str, str],
flag_transforms: Dict[str, List[str]],
drop_indices: Sequence[int] = (),
@ -405,16 +430,16 @@ class InflightOpDef:
def __init__(self,
g: OpGenerator,
kernel_sig,
ods_name,
op_name,
traits=(),
alias_kernel_names=(),
promote_trailing_out_tensor=False,
override_arg_types=None,
override_return_types=None,
drop_arg_indices=(),
drop_return_indices=()):
kernel_sig: str,
ods_name: str,
op_name: str,
traits: Sequence[str] = (),
alias_kernel_names: Sequence[str] = (),
promote_trailing_out_tensor: bool = False,
override_arg_types: Sequence[str] = None,
override_return_types: Sequence[str] = None,
drop_arg_indices: Sequence[int] = (),
drop_return_indices: Sequence[int] = ()):
super().__init__()
self.g = g
self.kernel_sig = kernel_sig
@ -435,7 +460,7 @@ class InflightOpDef:
self.arg_type_transforms = dict()
self.arg_flag_transforms = dict()
self.return_type_transforms = dict()
self.return_flag_trasforms = dict()
self.return_flag_transforms = dict()
# Signature mapping.
self._sigs_mapped = False
@ -450,14 +475,20 @@ class InflightOpDef:
for line in traceback.format_list(self._traceback):
sys.stderr.write(line)
def transforms(self, type_transforms=None, flag_transforms=None):
def transforms(
self,
type_transforms: Dict[str, str] = None,
flag_transforms: Dict[str, List[str]] = None) -> "InflightOpDef":
self.arg_transforms(type_transforms=type_transforms,
flag_transforms=flag_transforms)
self.return_transforms(type_transforms=type_transforms,
flag_transforms=flag_transforms)
return self
def arg_transforms(self, type_transforms=None, flag_transforms=None):
def arg_transforms(
self,
type_transforms: Dict[str, str] = None,
flag_transforms: Dict[str, List[str]] = None) -> "InflightOpDef":
"""Adds arg type and flag transforms dicts."""
if type_transforms:
self.arg_type_transforms.update(type_transforms)
@ -465,15 +496,18 @@ class InflightOpDef:
self.arg_flag_transforms.update(flag_transforms)
return self
def return_transforms(self, type_transforms=None, flag_transforms=None):
def return_transforms(
self,
type_transforms: Dict[str, str] = None,
flag_transforms: Dict[str, List[str]] = None) -> "InflightOpDef":
"""Adds return type and flag transform dicts."""
if type_transforms:
self.return_type_transforms.update(type_transforms)
if flag_transforms:
self.return_flag_trasforms.update(flag_transforms)
self.return_flag_transforms.update(flag_transforms)
return self
def map_signatures(self):
def map_signatures(self) -> "InflightOpDef":
assert not self._sigs_mapped, "Signatures already mapped"
self._sigs_mapped = True
self.ods_ins, self.arg_type_flags = self.g._map_sigtypes(
@ -485,7 +519,7 @@ class InflightOpDef:
self.ods_outs, self.return_type_flags = self.g._map_sigtypes(
self.reg_record["returns"],
type_transforms=self.return_type_transforms,
flag_transforms=self.return_flag_trasforms,
flag_transforms=self.return_flag_transforms,
override_types=self.override_return_types,
drop_indices=self.drop_return_indices)
return self
@ -519,23 +553,23 @@ class EmitterBase:
self.indent_level = 0
@contextmanager
def indent(self, level=1):
def indent(self, level: int = 1):
self.indent_level += level
yield
self.indent_level -= level
assert self.indent_level >= 0, "Unbalanced indentation"
def print(self, s, *, end="\n", indent=True):
def print(self, s: str, *, end: str = "\n", indent: bool = True):
if indent and self.indent_level:
self.out.write(self._INDENT * self.indent_level)
self.out.write(s)
self.out.write(end)
def quote(self, s: str):
def quote(self, s: str) -> str:
s = s.replace(r'"', r'\\"')
return f'"{s}"'
def quote_multiline_docstring(self, s: str, indent_level: int = 0):
def quote_multiline_docstring(self, s: str, indent_level: int = 0) -> str:
# TODO: Possibly find a python module to markdown the docstring for better
# document generation.
# Unlikely to contain the delimitter and since just a docstring, be safe.
@ -555,7 +589,7 @@ class OdsEmitter(EmitterBase):
def emit_opdef(self,
ods_def_name: str,
mnemonic: str,
reg_record: Dict,
reg_record: REG_OP_TYPE,
ods_ins: List[Tuple[str, str]],
ods_outs: List[Tuple[str, str]],
traits: Sequence[str] = (),
@ -590,7 +624,7 @@ class OdsEmitter(EmitterBase):
# Def last-line.
self.print("}\n")
def _emit_dag_list_body(self, items):
def _emit_dag_list_body(self, items: List[Tuple[str, str]]):
"""Emits a dag of (name, type) pairs."""
for index, (ods_name, ods_type) in enumerate(items):
is_last = index == len(items) - 1
@ -602,10 +636,10 @@ class CCImplEmitter(EmitterBase):
def emit_kernel_methods(self,
ods_def_name: str,
reg_record,
reg_record: REG_OP_TYPE,
arg_type_flags: List[Tuple[str, List[Tuple[str]]]],
return_type_flags: List[Tuple[str, List[Tuple[str]]]],
promote_trailing_out_tensor=False,
promote_trailing_out_tensor: bool = False,
alias_kernel_names: Sequence[str] = ()):
# getTorchKernelMetadata() method.
self.print(
@ -649,14 +683,15 @@ class CCImplEmitter(EmitterBase):
self.print("return metadata;")
self.print("}\n")
def _format_cpp_str_initlist(self, strings):
def _format_cpp_str_initlist(self, strings: Sequence[str]) -> str:
quoted = [self.quote(s) for s in strings]
joined = ", ".join(quoted)
return "{" + joined + "}"
def _format_cpp_kvc_initlist(self, const_name_lists):
def _format_cpp_kvc_initlist(self,
const_name_lists: List[List[Tuple[str]]]) -> str:
def or_flags(flag_names):
def or_flags(flag_names: List[Tuple[str]]):
if not flag_names:
return "KVC::kNone"
return "|".join([f"KVC::{n}" for n in flag_names])
@ -666,18 +701,18 @@ class CCImplEmitter(EmitterBase):
return "{" + joined + "}"
def snakecase_to_camelcase(ident: str):
def snakecase_to_camelcase(ident: str) -> str:
return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
def _first_non_none(*args):
def _first_non_none(*args) -> Union[None, Any]:
for arg in args:
if arg is not None:
return arg
return None
def _load_ops_as_dict():
def _load_ops_as_dict() -> Dict[str, REG_OP_TYPE]:
# Returns a list of dicts, each with a name that is a tuple of the form:
# (kernel_signature, variant)
# The kernel signature is a reified form of the argument type signature
@ -692,8 +727,10 @@ def _load_ops_as_dict():
return {reify_signature(reg_op): reg_op for reg_op in reg_ops_list}
def _get_main_module_name():
def _get_main_module_name() -> str:
# pytype: disable=attribute-error
return sys.modules["__main__"].__loader__.name
# pytype: enable=attribute-error
ODS_BANNER = "\n".join([