mirror of https://github.com/llvm/torch-mlir
Add a new python script to auto-generate ATen op ODS definitions. (#43)
* Add a new python script to auto-generate ATen op ODS definitions.
* There is still some work on some of the ops to annotate correct types.
* The ODS is not actually included into the dialect yet, but I'd like to commit it so that we can track changes.
* Will reconcile this with the ops produced by the existing script in a followup. Still need to do some more iteration to reach parity.
pull/42/head
parent
d62f8227c2
commit
a74a98094b
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,199 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
"""Generates ODS for a registry of ops."""
|
||||||
|
|
||||||
|
from typing import TextIO
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from contextlib import contextmanager
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
from .registry import *
|
||||||
|
|
||||||
|
_INDENT = " "
|
||||||
|
|
||||||
|
|
||||||
|
class OdsEmitter:
    """Writes TableGen ODS op definitions for the mappings of a registry.

    One ``def`` record is written to ``out`` for every ``SimpleOpMapping``
    found in the registry; other mapping types are reported with a warning
    and skipped.
    """

    # Prefix/suffix wrapped around the camel-cased op name to form the record
    # identifier.
    ods_prefix = "ATen_"
    ods_suffix = "Op"
    # ODS template class used for value-form and outref-form ops respectively.
    ods_value_template = "ATen_ImmutableTensorOp"
    ods_ref_template = "ATen_RefTensorOp"
    op_prefix = ""

    def __init__(self, r: OpRegistry, out: TextIO):
        """Creates an emitter.

        Args:
          r: Registry whose ``mappings`` are emitted.
          out: Text stream that receives the generated ODS.
        """
        super().__init__()
        self.r = r
        self.out = out
        self.indent_level = 0

    def emit_ods(self):
        """Emits an ODS record for every recognized mapping in the registry."""
        for op_m in self.r.mappings:
            if isinstance(op_m, SimpleOpMapping):
                self._emit_simple_op_mapping(op_m)
            else:
                # logging.warn is a deprecated alias of logging.warning.
                logging.warning("Unrecognized op mapping type: %r", op_m)

    def _emit_simple_op_mapping(self, op_m: SimpleOpMapping):
        """Emits the ``def`` record for a single simple op mapping."""
        identifier = (f"{self.ods_prefix}"
                      f"{_snakecase_to_camelcase(op_m.mlir_operation_name)}"
                      f"{self.ods_suffix}")
        traits = []

        if op_m.is_outref_form:
            template_name = self.ods_ref_template
            summary = "See non-inplace op variant."
            description = ""
        else:
            template_name = self.ods_value_template
            summary, description = self._summary_and_description(op_m)
            # Only the value (non-outref) form is marked side-effect free.
            traits.append("NoSideEffect")

        self.print(f"def {identifier}: {template_name}"
                   f"<{_quote(op_m.mlir_operation_name)}, ["
                   f"{', '.join(traits)}"
                   f"]> {{")

        with self.indent():
            # Summary.
            self.print(f"let summary = {_quote(summary)};")

            # Arguments.
            self._emit_value_list("let arguments = (ins", op_m.operand_map)

            # Results (omitted if an outref/inplace form).
            if op_m.is_outref_form:
                self.print("let results = (outs);")
            else:
                self._emit_value_list("let results = (outs", op_m.result_map)

            # Description and extra class declarations.
            if description:
                quoted_description = _quote_multiline_docstring(
                    description, indent_level=self.indent_level)
                self.print(f"let description = {quoted_description};")

        self.print("}\n")

    @staticmethod
    def _summary_and_description(op_m):
        """Returns ``(summary, description)`` taken from the op's docstring.

        A missing or blank docstring yields two empty strings instead of an
        error.
        """
        doc = op_m.op_f.__doc__
        if not doc or not doc.strip():
            return "", ""
        parts = _split_docstring(doc)
        # Tolerate a bare summary string as well as a (summary, description)
        # pair so that a short docstring cannot break tuple unpacking.
        if isinstance(parts, str):
            return parts, ""
        summary, description = parts
        return summary, description

    def _emit_value_list(self, header, value_map):
        """Emits ``header``, one ``predicate:$name`` line per value, then ``);``.

        Args:
          header: Opening line, e.g. ``let arguments = (ins``.
          value_map: Sequence of ``(key, value_spec)`` pairs; only the spec's
            ``mlir_ods_predicate`` and ``name`` are used.
        """
        self.print(header)
        with self.indent():
            last_index = len(value_map) - 1
            for index, (_, value_spec) in enumerate(value_map):
                # Every entry but the last is followed by a comma.
                self.print(
                    f"{value_spec.mlir_ods_predicate}:${value_spec.name}",
                    end="\n" if index == last_index else ",\n")
        self.print(");")

    @contextmanager
    def indent(self, level=1):
        """Context manager that increases the indent level inside its body."""
        self.indent_level += level
        try:
            yield
        finally:
            # Restore the level even when the body raises so that later output
            # is not left permanently over-indented.
            self.indent_level -= level
        assert self.indent_level >= 0, "Unbalanced indentation"

    def print(self, s, *, end="\n", indent=True):
        """Writes ``s`` followed by ``end``, prefixed by the current indent.

        Args:
          s: Text to write.
          end: Terminator written after ``s``.
          indent: When False, the indentation prefix is suppressed.
        """
        if indent and self.indent_level:
            self.out.write(_INDENT * self.indent_level)
        self.out.write(s)
        self.out.write(end)
|
||||||
|
|
||||||
|
|
||||||
|
def _snakecase_to_camelcase(ident: str):
|
||||||
|
return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
|
||||||
|
|
||||||
|
|
||||||
|
def _quote(s: str):
|
||||||
|
s = s.replace(r'"', r'\\"')
|
||||||
|
return f'"{s}"'
|
||||||
|
|
||||||
|
|
||||||
|
def _quote_multiline_docstring(s: str, indent_level: int = 0):
    """Wraps ``s`` in a ``[{ ... }]`` multi-line literal.

    Trailing whitespace is removed from every line, the body is indented one
    step deeper than ``indent_level``, and the closing delimiter is placed on
    its own line at ``indent_level``.
    """
    # TODO: Possibly find a python module to markdown the docstring for better
    # document generation.
    outer_pad = _INDENT * indent_level
    # Unlikely to contain the delimiter and since just a docstring, be safe.
    sanitized = s.replace("}]", "")
    # Strip each line.
    trimmed_lines = (line.rstrip() for line in sanitized.splitlines())
    body = textwrap.indent("\n".join(trimmed_lines), outer_pad + _INDENT)
    return f"[{{\n{body}\n{outer_pad}}}]"
|
||||||
|
|
||||||
|
|
||||||
|
def _split_docstring(docstring: str):
|
||||||
|
"""Splits the docstring into a summary and description."""
|
||||||
|
lines = docstring.splitlines()
|
||||||
|
# Skip leading blank lines.
|
||||||
|
while lines and not lines[0]:
|
||||||
|
lines = lines[1:]
|
||||||
|
if len(lines) > 2:
|
||||||
|
return lines[0], "\n".join(lines[2:])
|
||||||
|
else:
|
||||||
|
return lines[0]
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Populates an op registry and writes its ODS to stdout.

    Args:
      args: Parsed command-line arguments; ``args.op_module`` names the module
        whose ``populate()`` function fills the registry.
    """
    registry = OpRegistry()

    # Populate from modules that provide a populate() function.
    for m_name in (args.op_module,):
        logging.info(f"Populating from module: {m_name}")
        op_module = importlib.import_module(m_name, package=__package__)
        op_module.populate(registry)

    stream = sys.stdout

    # Write file header.
    module_name = sys.modules["__main__"].__loader__.name
    command_args = " ".join(sys.argv[1:])
    rule = "//===----------------------------------------------------------------------===//"
    banner = (
        "//===-------------------------------------------------------*- tablegen -*-===//",
        "//",
        "// This file is licensed under the Apache License v2.0 with LLVM Exceptions.",
        "// See https://llvm.org/LICENSE.txt for license information.",
        "// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception",
        "//",
        "// Operation summaries and descriptions were systematically derived from public",
        "// API docstrings and are licensed accordingly:",
        "// https://github.com/pytorch/pytorch/blob/master/LICENSE",
        rule,
        "// This file is automatically generated. Please do not edit.",
        "// Generated via:",
        f"// python -m {module_name} {command_args}",
        rule,
        "",
        "",
    )
    stream.write("\n".join(line.strip() for line in banner))

    OdsEmitter(registry, out=stream).emit_ods()
|
||||||
|
|
||||||
|
|
||||||
|
def _create_argparse():
|
||||||
|
parser = argparse.ArgumentParser(prog="generate_ods")
|
||||||
|
parser.add_argument(
|
||||||
|
"--op_module",
|
||||||
|
default=".aten_ops",
|
||||||
|
help="Name of a python module for populating the registry")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: enable debug logging, parse the command line and
    # run the generator.
    logging.basicConfig(level=logging.DEBUG)
    main(_create_argparse().parse_args())
|
|
@ -22,7 +22,7 @@ Example usage (fully automatic discovery):
|
||||||
alpha=ScalarValue()).with_outref_variant()
|
alpha=ScalarValue()).with_outref_variant()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
@ -92,6 +92,10 @@ class ValueSpec:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
|
@property
def mlir_ods_predicate(self) -> str:
    """Name of the ODS type predicate for this value; here ``AnyType``."""
    return "AnyType"
|
||||||
|
|
||||||
def generate_example(self, index=0):
|
def generate_example(self, index=0):
|
||||||
"""Generates an example value."""
|
"""Generates an example value."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
@ -109,6 +113,10 @@ class TensorValue(ValueSpec):
|
||||||
example_size = (2, 3, 7) # No significance.
|
example_size = (2, 3, 7) # No significance.
|
||||||
self.example_size = example_size
|
self.example_size = example_size
|
||||||
|
|
||||||
|
@property
def mlir_ods_predicate(self) -> str:
    """Name of the ODS type predicate for this value; here ``ATen_AnyTensor``."""
    return "ATen_AnyTensor"
|
||||||
|
|
||||||
def generate_example(self, index=0):
|
def generate_example(self, index=0):
|
||||||
return torch.rand(*self.example_size)
|
return torch.rand(*self.example_size)
|
||||||
|
|
||||||
|
@ -122,6 +130,10 @@ class TensorOutRef(ValueSpec):
|
||||||
example_size = (2, 3, 7) # No significance.
|
example_size = (2, 3, 7) # No significance.
|
||||||
self.example_size = example_size
|
self.example_size = example_size
|
||||||
|
|
||||||
|
@property
def mlir_ods_predicate(self) -> str:
    """Name of the ODS type predicate for this value; here ``ATen_AnyRefTensor``."""
    return "ATen_AnyRefTensor"
|
||||||
|
|
||||||
def generate_example(self, index=0):
|
def generate_example(self, index=0):
|
||||||
return torch.rand(*self.example_size)
|
return torch.rand(*self.example_size)
|
||||||
|
|
||||||
|
@ -133,13 +145,22 @@ class ScalarValue(ValueSpec):
|
||||||
super().__init__(name=name)
|
super().__init__(name=name)
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
|
@property
def mlir_ods_predicate(self) -> str:
    """Name of the ODS type predicate for this value; here ``ATen_AnyScalar``."""
    return "ATen_AnyScalar"
|
||||||
|
|
||||||
def generate_example(self, index=0):
|
def generate_example(self, index=0):
|
||||||
if self.value is not None:
|
if self.value is not None:
|
||||||
return self.value
|
return self.value
|
||||||
return 1.0 + index # Generates a stable value.
|
return 1.0 + index # Generates a stable value.
|
||||||
|
|
||||||
|
|
||||||
class SimpleOpMapping:
|
class OpMapping:
    """Common base type for objects that describe how an operation is mapped.

    The class carries no behavior of its own.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleOpMapping(OpMapping):
|
||||||
"""Maps a PyTorch invocation to its MLIR representation."""
|
"""Maps a PyTorch invocation to its MLIR representation."""
|
||||||
|
|
||||||
def __init__(self, op_f, *op_args, **op_kwargs):
|
def __init__(self, op_f, *op_args, **op_kwargs):
|
||||||
|
@ -228,10 +249,17 @@ class SimpleOpMapping:
|
||||||
|
|
||||||
def _set_default_mlir_operation_name(self):
|
def _set_default_mlir_operation_name(self):
|
||||||
op_ns, op_name = self.op_kind.split("::", maxsplit=1)
|
op_ns, op_name = self.op_kind.split("::", maxsplit=1)
|
||||||
default_name = op_ns + "." + op_name
|
# Since these are emitted into the "aten" dialect namespace, alias them
|
||||||
|
# to omit the prefix to distinguish from custom ops and others (which will
|
||||||
|
# have a prefix).
|
||||||
|
default_name = op_name if op_ns == "aten" else op_ns + "." + op_name
|
||||||
|
if op_ns == "aten":
|
||||||
|
default_name = op_name
|
||||||
|
else:
|
||||||
|
default_name = op_ns + "." + op_name
|
||||||
|
|
||||||
if self.is_outref_form:
|
if self.is_outref_form:
|
||||||
default_name += "_outref"
|
default_name += ".inplace"
|
||||||
self.mlir_operation_name = default_name
|
self.mlir_operation_name = default_name
|
||||||
|
|
||||||
def _configure_from_example(self):
|
def _configure_from_example(self):
|
||||||
|
@ -352,8 +380,12 @@ class OpRegistry:
|
||||||
return m
|
return m
|
||||||
|
|
||||||
@property
def mappings(self) -> Sequence[OpMapping]:
    """Returns the list of OpMapping.

    ``_finalize_pending()`` is invoked on every access before the stored
    mappings are returned.

    Returns:
      Sequence of OpMapping concrete classes (most commonly SimpleOpMapping).
    """
    self._finalize_pending()
    return self._mappings
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue