mirror of https://github.com/llvm/torch-mlir
Add a new python script to auto-generate ATen op ODS definitions. (#43)
* Add a new python script to auto-generate ATen op ODS definitions. * There is still some work on some of the ops to annotate correct types. * The ODS is not actually included into the dialect yet, but I'd like to commit it so that we can track changes. * Will reconcile this with the ops produced by the existing script in a followup. Still need to do some more iteration to reach parity.pull/42/head
parent
d62f8227c2
commit
a74a98094b
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,199 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Generates ODS for a registry of ops."""
|
||||
|
||||
from typing import TextIO
|
||||
|
||||
import argparse
|
||||
from contextlib import contextmanager
|
||||
import importlib
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
from .registry import *
|
||||
|
||||
_INDENT = " "
|
||||
|
||||
|
||||
class OdsEmitter:
|
||||
ods_prefix = "ATen_"
|
||||
ods_suffix = "Op"
|
||||
ods_value_template = "ATen_ImmutableTensorOp"
|
||||
ods_ref_template = "ATen_RefTensorOp"
|
||||
op_prefix = ""
|
||||
|
||||
def __init__(self, r: OpRegistry, out: TextIO):
|
||||
super().__init__()
|
||||
self.r = r
|
||||
self.out = out
|
||||
self.indent_level = 0
|
||||
|
||||
def emit_ods(self):
|
||||
for op_m in self.r.mappings:
|
||||
if isinstance(op_m, SimpleOpMapping):
|
||||
self._emit_simple_op_mapping(op_m)
|
||||
else:
|
||||
logging.warn(f"Unrecognized op mapping type: {op_m!r}")
|
||||
|
||||
def _emit_simple_op_mapping(self, op_m: SimpleOpMapping):
|
||||
identifier = (f"{self.ods_prefix}"
|
||||
f"{_snakecase_to_camelcase(op_m.mlir_operation_name)}"
|
||||
f"{self.ods_suffix}")
|
||||
traits = []
|
||||
|
||||
if op_m.is_outref_form:
|
||||
template_name = self.ods_ref_template
|
||||
summary = "See non-inplace op variant."
|
||||
description = ""
|
||||
else:
|
||||
template_name = self.ods_value_template
|
||||
summary, description = _split_docstring(op_m.op_f.__doc__)
|
||||
|
||||
if not op_m.is_outref_form:
|
||||
traits.append("NoSideEffect")
|
||||
self.print(f"def {identifier}: {template_name}"
|
||||
f"<{_quote(op_m.mlir_operation_name)}, ["
|
||||
f"{', '.join(traits)}"
|
||||
f"]> {{")
|
||||
|
||||
# Summary.
|
||||
with self.indent():
|
||||
self.print(f"let summary = {_quote(summary)};")
|
||||
|
||||
# Arguments.
|
||||
with self.indent():
|
||||
self.print("let arguments = (ins")
|
||||
with self.indent():
|
||||
operand_len = len(op_m.operand_map)
|
||||
for index, (_, value_spec) in enumerate(op_m.operand_map):
|
||||
is_last = index == operand_len - 1
|
||||
self.print(f"{value_spec.mlir_ods_predicate}:${value_spec.name}",
|
||||
end="\n" if is_last else ",\n")
|
||||
self.print(");")
|
||||
|
||||
# Results (omitted if an outref/inplace form).
|
||||
with self.indent():
|
||||
if op_m.is_outref_form:
|
||||
self.print("let results = (outs);")
|
||||
else:
|
||||
self.print("let results = (outs")
|
||||
with self.indent():
|
||||
result_len = len(op_m.result_map)
|
||||
for index, (_, value_spec) in enumerate(op_m.result_map):
|
||||
is_last = index == result_len - 1
|
||||
self.print(f"{value_spec.mlir_ods_predicate}:${value_spec.name}",
|
||||
end="\n" if is_last else ",\n")
|
||||
self.print(");")
|
||||
|
||||
# Description and extra class declarations.
|
||||
with self.indent():
|
||||
if description:
|
||||
quoted_description = _quote_multiline_docstring(
|
||||
description, indent_level=self.indent_level)
|
||||
self.print(f"let description = {quoted_description};")
|
||||
|
||||
self.print("}\n")
|
||||
|
||||
@contextmanager
|
||||
def indent(self, level=1):
|
||||
self.indent_level += level
|
||||
yield
|
||||
self.indent_level -= level
|
||||
assert self.indent_level >= 0, "Unbalanced indentation"
|
||||
|
||||
def print(self, s, *, end="\n", indent=True):
|
||||
if indent and self.indent_level:
|
||||
self.out.write(_INDENT * self.indent_level)
|
||||
self.out.write(s)
|
||||
self.out.write(end)
|
||||
|
||||
|
||||
def _snakecase_to_camelcase(ident: str):
|
||||
return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
|
||||
|
||||
|
||||
def _quote(s: str):
|
||||
s = s.replace(r'"', r'\\"')
|
||||
return f'"{s}"'
|
||||
|
||||
|
||||
def _quote_multiline_docstring(s: str, indent_level: int = 0):
|
||||
# TODO: Possibly find a python module to markdown the docstring for better
|
||||
# document generation.
|
||||
# Unlikely to contain the delimitter and since just a docstring, be safe.
|
||||
s = s.replace("}]", "")
|
||||
# Strip each line.
|
||||
s = "\n".join([l.rstrip() for l in s.splitlines()])
|
||||
indent = _INDENT * indent_level
|
||||
s = textwrap.indent(s, indent + _INDENT)
|
||||
return "[{\n" + s + "\n" + indent + "}]"
|
||||
|
||||
|
||||
def _split_docstring(docstring: str):
|
||||
"""Splits the docstring into a summary and description."""
|
||||
lines = docstring.splitlines()
|
||||
# Skip leading blank lines.
|
||||
while lines and not lines[0]:
|
||||
lines = lines[1:]
|
||||
if len(lines) > 2:
|
||||
return lines[0], "\n".join(lines[2:])
|
||||
else:
|
||||
return lines[0]
|
||||
|
||||
|
||||
def main(args):
|
||||
r = OpRegistry()
|
||||
# Populate from modules that provide a populate() function.
|
||||
op_modules = [args.op_module]
|
||||
for m_name in op_modules:
|
||||
logging.info(f"Populating from module: {m_name}")
|
||||
m = importlib.import_module(m_name, package=__package__)
|
||||
f = getattr(m, "populate")
|
||||
f(r)
|
||||
|
||||
out = sys.stdout
|
||||
|
||||
# Write file header.
|
||||
module_name = sys.modules["__main__"].__loader__.name
|
||||
banner_lines = [
|
||||
"//===-------------------------------------------------------*- tablegen -*-===//",
|
||||
"//",
|
||||
"// This file is licensed under the Apache License v2.0 with LLVM Exceptions.",
|
||||
"// See https://llvm.org/LICENSE.txt for license information.",
|
||||
"// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception",
|
||||
"//",
|
||||
"// Operation summaries and descriptions were systematically derived from public",
|
||||
"// API docstrings and are licensed accordingly:",
|
||||
"// https://github.com/pytorch/pytorch/blob/master/LICENSE",
|
||||
"//===----------------------------------------------------------------------===//",
|
||||
"// This file is automatically generated. Please do not edit.",
|
||||
"// Generated via:",
|
||||
f"// python -m {module_name} {' '.join(sys.argv[1:])}",
|
||||
"//===----------------------------------------------------------------------===//",
|
||||
"",
|
||||
"",
|
||||
]
|
||||
banner_lines = [l.strip() for l in banner_lines]
|
||||
out.write("\n".join(banner_lines))
|
||||
|
||||
emitter = OdsEmitter(r, out=out)
|
||||
emitter.emit_ods()
|
||||
|
||||
|
||||
def _create_argparse():
|
||||
parser = argparse.ArgumentParser(prog="generate_ods")
|
||||
parser.add_argument(
|
||||
"--op_module",
|
||||
default=".aten_ops",
|
||||
help="Name of a python module for populating the registry")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _create_argparse()
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -22,7 +22,7 @@ Example usage (fully automatic discovery):
|
|||
alpha=ScalarValue()).with_outref_variant()
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
@ -92,6 +92,10 @@ class ValueSpec:
|
|||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def mlir_ods_predicate(self):
|
||||
return "AnyType"
|
||||
|
||||
def generate_example(self, index=0):
|
||||
"""Generates an example value."""
|
||||
raise NotImplementedError()
|
||||
|
@ -109,6 +113,10 @@ class TensorValue(ValueSpec):
|
|||
example_size = (2, 3, 7) # No significance.
|
||||
self.example_size = example_size
|
||||
|
||||
@property
|
||||
def mlir_ods_predicate(self):
|
||||
return "ATen_AnyTensor"
|
||||
|
||||
def generate_example(self, index=0):
|
||||
return torch.rand(*self.example_size)
|
||||
|
||||
|
@ -122,6 +130,10 @@ class TensorOutRef(ValueSpec):
|
|||
example_size = (2, 3, 7) # No significance.
|
||||
self.example_size = example_size
|
||||
|
||||
@property
|
||||
def mlir_ods_predicate(self):
|
||||
return "ATen_AnyRefTensor"
|
||||
|
||||
def generate_example(self, index=0):
|
||||
return torch.rand(*self.example_size)
|
||||
|
||||
|
@ -133,13 +145,22 @@ class ScalarValue(ValueSpec):
|
|||
super().__init__(name=name)
|
||||
self.value = value
|
||||
|
||||
@property
|
||||
def mlir_ods_predicate(self):
|
||||
return "ATen_AnyScalar"
|
||||
|
||||
def generate_example(self, index=0):
|
||||
if self.value is not None:
|
||||
return self.value
|
||||
return 1.0 + index # Generates a stable value.
|
||||
|
||||
|
||||
class SimpleOpMapping:
|
||||
class OpMapping:
|
||||
"""Base class for things purporting to map an operation."""
|
||||
pass
|
||||
|
||||
|
||||
class SimpleOpMapping(OpMapping):
|
||||
"""Maps a PyTorch invocation to its MLIR representation."""
|
||||
|
||||
def __init__(self, op_f, *op_args, **op_kwargs):
|
||||
|
@ -228,10 +249,17 @@ class SimpleOpMapping:
|
|||
|
||||
def _set_default_mlir_operation_name(self):
|
||||
op_ns, op_name = self.op_kind.split("::", maxsplit=1)
|
||||
# Since these are emitted into the "aten" dialect namespace, alias them
|
||||
# to omit the prefix to distinguish from custom ops and others (which will
|
||||
# have a prefix).
|
||||
default_name = op_name if op_ns == "aten" else op_ns + "." + op_name
|
||||
if op_ns == "aten":
|
||||
default_name = op_name
|
||||
else:
|
||||
default_name = op_ns + "." + op_name
|
||||
|
||||
if self.is_outref_form:
|
||||
default_name += "_outref"
|
||||
default_name += ".inplace"
|
||||
self.mlir_operation_name = default_name
|
||||
|
||||
def _configure_from_example(self):
|
||||
|
@ -352,8 +380,12 @@ class OpRegistry:
|
|||
return m
|
||||
|
||||
@property
|
||||
def mappings(self):
|
||||
"""Returns the list of SimpleOpMappings."""
|
||||
def mappings(self) -> Sequence[OpMapping]:
|
||||
"""Returns the list of OpMapping.
|
||||
|
||||
Returns:
|
||||
Sequence of OpMapping concrete classes (most commonly SimpleOpMapping).
|
||||
"""
|
||||
self._finalize_pending()
|
||||
return self._mappings
|
||||
|
||||
|
|
Loading…
Reference in New Issue