Rewrite ATen ODS code generator to be based on new op registry and new signature recognition system.

* Deletes prior code generator from previous attempt (moved some of it into this one).
* Renames old generated tablegen source to "Legacy".
* Generates ODS and import rules for most binary and unary arithmetic ops.
* Removes old generated ops and integration tests that were testing details of the prior setup.
pull/98/head
Stella Laurenzo 2020-10-27 18:13:23 -07:00
parent 94ea6f7c92
commit c08935a418
16 changed files with 2259 additions and 1717 deletions


@@ -0,0 +1,14 @@
#!/bin/bash
# Updates the ATen dialect generated code from the PyTorch op registry.
# Requires that the project has been built and that PyTorch support is enabled.
set -e
src_dir="$(realpath $(dirname $0)/..)"
build_dir="$(realpath "${NPCOMP_BUILD_DIR:-$src_dir/build}")"
aten_dir="${src_dir}/include/npcomp/Dialect/ATen/IR"
export PYTHONPATH="${build_dir}/python"
python -m torch_mlir_utils.codegen.torch_signature_ods_gen \
--ods_td_file="${aten_dir}/GeneratedATenOps.td" \
--ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc"


@@ -1,202 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Populates an op registry for ATen ops.
Typically callers will import and use the 'populate' function to add known
ops to the OpRegistry. When run interactively as a main module, it simply
prints all registered ops.
"""
from .registry import *
import torch
import torch.nn.functional as F
def populate(r: OpRegistry):
# Unary pointwise ops (ordinary ops that take out refs).
for f in [
torch.abs, torch.acos, torch.angle, torch.asin, torch.atan, torch.ceil,
torch.conj, torch.cos, torch.cosh, torch.digamma, torch.erf, torch.erfc,
torch.erfinv, torch.exp, torch.expm1, torch.floor, torch.frac,
torch.lgamma, torch.log, torch.log10, torch.log1p, torch.log2, torch.neg,
torch.reciprocal, torch.round, torch.rsqrt, torch.sigmoid, torch.sign,
torch.sin, torch.sinh, torch.sqrt, torch.tan, torch.tanh, torch.trunc
]:
r.op(f, TensorValue("input")).with_outref_variant()
# Unary pointwise ops (that do not take out refs).
for f in [torch.relu]:
r.op(f, TensorValue("input"))
# Binary pointwise ops.
r.op(torch.add,
TensorValue("input"),
TensorValue("other"),
alpha=LiteralValue()).with_outref_variant()
r.op(torch.atan2, TensorValue("input"),
TensorValue("other")).with_outref_variant()
r.op(torch.div, TensorValue("input"),
TensorValue("other")).with_outref_variant()
r.op(torch.floor_divide, TensorValue("input"),
TensorValue("other")).with_outref_variant()
r.op(torch.mul, TensorValue("input"),
TensorValue("other")).with_outref_variant()
r.op(torch.remainder, TensorValue("input"),
TensorValue("other")).with_outref_variant()
r.op(torch.true_divide, TensorValue("dividend"),
TensorValue("divisor")).with_outref_variant()
# Aggregate operations.
# TODO: Support the optional dtype= parameter.
r.op(torch.cumsum, TensorValue("input", example_size=(10, 3)),
LiteralValue("dim", value=1)).with_outref_variant()
r.op(
torch.mean,
TensorValue("input", example_size=(10, 3)), #
LiteralValue("dim", value=[0], mlir_ods_predicate="ATen_IntList"), #
LiteralValue("keep_dim",
value=False,
mlir_ods_predicate="ATen_BoolScalar") #
).with_outref_variant()
r.op(
torch.sum,
TensorValue("input", example_size=(10, 3)), #
LiteralValue("dim", value=[0], mlir_ods_predicate="ATen_IntList"), #
LiteralValue("keep_dim",
value=False,
mlir_ods_predicate="ATen_BoolScalar") #
).with_outref_variant()
# Gather.
(r.op(
torch.gather, #
TensorValue("input", example_size=(2, 2)), #
dim=LiteralValue(value=1, mlir_ods_predicate="ATen_IntScalar"), #
index=LiteralValue(value=torch.tensor([[0, 0], [1, 0]]),
mlir_ods_predicate="ATen_AnyTensor"), #
sparse_grad=LiteralValue(value=False,
mlir_ods_predicate="ATen_BoolScalar") #
).with_outref_variant())
# Non-view methods on Tensor.
(r.op((TensorValue(name="input", example_size=(3, 1)), "@T")))
# BLAS and LAPACK ops.
r.op(torch.addmm,
TensorValue("input", example_size=(2, 3)),
TensorValue("mat1", example_size=(2, 3)),
TensorValue("mat2", example_size=(3, 3)),
beta=LiteralValue(),
alpha=LiteralValue()).with_outref_variant()
r.op(torch.dot, TensorValue("input", example_size=(10,)),
TensorValue("tensor", example_size=(10,)))
r.op(torch.matmul, TensorValue("input", example_size=(10, 3, 4)),
TensorValue("other", example_size=(4, 5))).with_outref_variant()
r.op(torch.mm, TensorValue("input", example_size=(3, 4)),
TensorValue("mat2", example_size=(4, 6))).with_outref_variant()
# NN Functional.
# Note that _convolution is a special case and is manually coded.
r.op(
F.hardtanh,
TensorValue("input"), #
min_val=LiteralValue(value=-1.0,
mlir_ods_predicate="ATen_FloatScalar"), #
max_val=LiteralValue(value=1.0, mlir_ods_predicate="ATen_FloatScalar") #
)
r.op(F.avg_pool1d,
TensorValue("input", example_size=(1, 1, 7)),
kernel_size=LiteralValue(value=[3], mlir_ods_predicate="ATen_IntList"),
stride=LiteralValue(value=[5], mlir_ods_predicate="ATen_IntList"),
padding=LiteralValue(value=[1], mlir_ods_predicate="ATen_IntList"),
ceil_mode=LiteralValue(value=True, mlir_ods_predicate="ATen_BoolScalar"),
count_include_pad=LiteralValue(value=False,
mlir_ods_predicate="ATen_BoolScalar"))
# MaxPool1D is split into two ops based on whether return_indices is True:
# aten::max_pool1d -> tensor<float>
# aten::max_pool1d_with_indices -> (tensor<float>, tensor<long>)
# Both have odd signatures and are hand-mapped.
# TODO: Implement max_pool1d(..., with_indices=True)
(r.op(F.max_pool1d,
TensorValue("input", example_size=(1, 1, 7)),
kernel_size=LiteralValue(value=[3], mlir_ods_predicate="ATen_IntList"),
stride=LiteralValue(value=[5], mlir_ods_predicate="ATen_IntList"),
padding=LiteralValue(value=[1], mlir_ods_predicate="ATen_IntList"),
dilation=LiteralValue(value=[3], mlir_ods_predicate="ATen_IntList"),
ceil_mode=LiteralValue(value=True,
mlir_ods_predicate="ATen_BoolScalar"),
return_indices=LiteralValue(value=False)) #
.with_torch_op_kind("aten::max_pool1d") #
.with_operand_map("input", "kernel_size", "stride", "padding", "dilation",
"ceil_mode") #
)
# View ops.
# TODO: All of these need special analysis and should be parameterized
# on mutable tensors and have a proper design thought through. For now,
# even having them in the inventory (badly) increases visibility.
(r.op(torch.as_strided,
TensorValue("input"),
size=LiteralValue(value=[2, 2], mlir_ods_predicate="ATen_IntList"),
stride=LiteralValue(value=[1, 2], mlir_ods_predicate="ATen_IntList"),
storage_offset=LiteralValue(value=4,
mlir_ods_predicate="ATen_IntScalar")) #
.with_append_description(r"""
MLIR Specific Notes
-------------------
In PyTorch proper, this op creates a view that may internally alias, and
the documentation has explicit warnings about avoiding in-place updates
on such a view (without first cloning). For the moment, this op is
formulated with value semantics that imply a copy instead of a view, and
it is expected that any sharing can be recovered later by the compiler.
In-place updates of such a result should be treated as UB when compiled.
"""))
(r.op((TensorValue(name="input", example_size=(3, 1)), "expand"),
LiteralValue("sizes", value=torch.Size([3, 4]))) #
.with_operand_map("input", "sizes",
LiteralValue("implicit",
mlir_ods_predicate="ATen_BoolScalar")) #
.with_append_description(r"""
MLIR Specific Notes
-------------------
See notes for the 'as_strided' op.
"""))
(r.op(
torch.squeeze,
TensorValue("input"), #
LiteralValue("dim", value=1, mlir_ods_predicate="ATen_IntScalar") #
).with_append_description(r"""
MLIR Specific Notes
-------------------
See notes for the 'as_strided' op.
"""))
(r.op((TensorValue(name="input", example_size=(3, 1)), "view"),
LiteralValue(name="size",
value=[3, 1],
mlir_ods_predicate="ATen_IntList"))
.with_append_description(r"""
MLIR Specific Notes
-------------------
See notes for the 'as_strided' op.
"""))
if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.DEBUG)
registry = OpRegistry()
populate(registry)
print("Registered operations:")
for m in registry.mappings:
print(" ", m)


@@ -1,205 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Generates ODS for a registry of ops."""
from typing import TextIO
import argparse
from contextlib import contextmanager
import importlib
import logging
import re
import sys
import textwrap
from .registry import *
_INDENT = " "
class OdsEmitter:
ods_prefix = "ATen_"
ods_suffix = "Op"
ods_value_template = "ATen_ImmutableTensorOp"
ods_ref_template = "ATen_RefTensorOp"
op_prefix = ""
def __init__(self, r: OpRegistry, out: TextIO):
super().__init__()
self.r = r
self.out = out
self.indent_level = 0
def emit_ods(self):
for op_m in self.r.mappings:
if isinstance(op_m, SimpleOpMapping):
self._emit_simple_op_mapping(op_m)
else:
logging.warning(f"Unrecognized op mapping type: {op_m!r}")
def _emit_simple_op_mapping(self, op_m: SimpleOpMapping):
identifier = (f"{self.ods_prefix}"
f"{_snakecase_to_camelcase(op_m.mlir_operation_name)}"
f"{self.ods_suffix}")
traits = []
if op_m.is_outref_form:
template_name = self.ods_ref_template
summary = "See non-inplace op variant."
description = ""
else:
template_name = self.ods_value_template
summary, description = _split_docstring(op_m.op_f.__doc__)
if not op_m.is_outref_form:
traits.append("NoSideEffect")
self.print(f"def {identifier}: {template_name}"
f"<{_quote(op_m.mlir_operation_name)}, ["
f"{', '.join(traits)}"
f"]> {{")
# Summary.
with self.indent():
self.print(f"let summary = {_quote(summary)};")
# Arguments.
with self.indent():
self.print("let arguments = (ins")
with self.indent():
operand_len = len(op_m.operand_map)
for index, (_, value_spec) in enumerate(op_m.operand_map):
is_last = index == operand_len - 1
self.print(f"{value_spec.mlir_ods_predicate}:${value_spec.name}",
end="\n" if is_last else ",\n")
self.print(");")
# Results (omitted if an outref/inplace form).
with self.indent():
if op_m.is_outref_form:
self.print("let results = (outs);")
else:
self.print("let results = (outs")
with self.indent():
result_len = len(op_m.result_map)
for index, (_, value_spec) in enumerate(op_m.result_map):
is_last = index == result_len - 1
self.print(f"{value_spec.mlir_ods_predicate}:${value_spec.name}",
end="\n" if is_last else ",\n")
self.print(");")
# Description and extra class declarations.
with self.indent():
if description:
quoted_description = _quote_multiline_docstring(
description + op_m.append_description,
indent_level=self.indent_level)
self.print(f"let description = {quoted_description};")
self.print("}\n")
@contextmanager
def indent(self, level=1):
self.indent_level += level
yield
self.indent_level -= level
assert self.indent_level >= 0, "Unbalanced indentation"
def print(self, s, *, end="\n", indent=True):
if indent and self.indent_level:
self.out.write(_INDENT * self.indent_level)
self.out.write(s)
self.out.write(end)
def _snakecase_to_camelcase(ident: str):
return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
def _quote(s: str):
s = s.replace(r'"', r'\"')
return f'"{s}"'
def _quote_multiline_docstring(s: str, indent_level: int = 0):
# TODO: Possibly find a python module to markdown the docstring for better
# document generation.
# Unlikely to contain the delimiter, and since this is just a docstring, be safe.
s = s.replace("}]", "")
# Strip each line.
s = "\n".join([l.rstrip() for l in s.splitlines()])
indent = _INDENT * indent_level
s = textwrap.indent(s, indent + _INDENT)
return "[{\n" + s + "\n" + indent + "}]"
def _split_docstring(docstring: str):
"""Splits the docstring into a summary and description."""
if not docstring:
docstring = ""
lines = docstring.splitlines()
if not lines:
return "", ""
# Skip leading blank lines.
while lines and not lines[0]:
lines = lines[1:]
if len(lines) > 2:
return lines[0], "\n".join(lines[2:])
else:
return lines[0], ""
def main(args):
r = OpRegistry()
# Populate from modules that provide a populate() function.
op_modules = [args.op_module]
for m_name in op_modules:
logging.info(f"Populating from module: {m_name}")
m = importlib.import_module(m_name, package=__package__)
f = getattr(m, "populate")
f(r)
out = sys.stdout
# Write file header.
module_name = sys.modules["__main__"].__loader__.name
banner_lines = [
"//===-------------------------------------------------------*- tablegen -*-===//",
"//",
"// This file is licensed under the Apache License v2.0 with LLVM Exceptions.",
"// See https://llvm.org/LICENSE.txt for license information.",
"// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception",
"//",
"// Operation summaries and descriptions were systematically derived from public",
"// API docstrings and are licensed accordingly:",
"// https://github.com/pytorch/pytorch/blob/master/LICENSE",
"//===----------------------------------------------------------------------===//",
"// This file is automatically generated. Please do not edit.",
"// Generated via:",
f"// python -m {module_name} {' '.join(sys.argv[1:])}",
"//===----------------------------------------------------------------------===//",
"",
"",
]
banner_lines = [l.strip() for l in banner_lines]
out.write("\n".join(banner_lines))
emitter = OdsEmitter(r, out=out)
emitter.emit_ods()
def _create_argparse():
parser = argparse.ArgumentParser(prog="generate_ods")
parser.add_argument(
"--op_module",
default=".aten_ops",
help="Name of a python module for populating the registry")
return parser
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
parser = _create_argparse()
args = parser.parse_args()
main(args)


@@ -1,495 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Base classes and interfaces for mapping ops to MLIR.
The goal of this facility is to define the majority of op mappings by example,
from the actual Python invocation, using the PyTorch tracer to extract its
Graph IR, and introspecting that to determine mappings, metadata and types
for corresponding MLIR definitions. This is not meant to cover everything,
and it is expected that a number of hard ops should be mapped by hand.
The result of building such an OpRegistry should be a data structure that
can be used to generate ODS and tables for doing systematic imports from
a PyTorch Graph to corresponding MLIR module.
Example usage (fully automatic discovery):
>>> r = OpRegistry()
>>> r.op(torch.add,
TensorValue("input"),
TensorValue("other"),
alpha=LiteralValue()).with_outref_variant()
"""
from typing import Dict, List, Optional, Sequence, Tuple
import logging
import random
import textwrap
import torch
__all__ = [
"SimpleOpMapping",
"OpRegistry",
"LiteralValue",
"TensorOutRef",
"TensorValue",
]
def _is_same_value(value1, value2):
# Tensors are considered via reference equality.
value1_tensor = isinstance(value1, torch.Tensor)
value2_tensor = isinstance(value2, torch.Tensor)
if value1_tensor or value2_tensor:
return value1 is value2
# Everything else is value equality.
return value1 == value2
def _extract_immediate(node):
"""Extracts an immediate value from a node.
Supported node types:
prim::Constant
prim::ListConstruct
"""
# Constant
if node.kind() == "prim::Constant":
# Try to extract as different types.
try:
return node.t("value")
except RuntimeError:
pass
try:
return node.f("value")
except RuntimeError:
pass
try:
return node.i("value")
except RuntimeError:
pass
try:
return node.s("value")
except RuntimeError:
pass
return None
# List
elif node.kind() == "prim::ListConstruct":
return [_extract_immediate(i.node()) for i in node.inputs()]
else:
raise ValueError("Unrecognized immediate input node: {!r}".format(node))
def _read_attr_callback(instance, attr_name):
def f():
return getattr(instance, attr_name)
f.__doc__ = getattr(type(instance), attr_name).__doc__
return f
class ValueSpec:
"""Base class for inputs to operations.
This binds information about how the input is mapped to the MLIR operation.
"""
def __init__(self, name=None, mlir_ods_predicate: str = "AnyType"):
super().__init__()
self.name = name
self.mlir_ods_predicate = mlir_ods_predicate
def generate_example(self, index=0):
"""Generates an example value."""
raise NotImplementedError()
def __repr__(self):
return "{}({!r})".format(self.__class__.__name__, self.name)
class TensorValue(ValueSpec):
"""An input that is a tensor."""
def __init__(self,
name=None,
*,
example_size=None,
mlir_ods_predicate="ATen_AnyTensor",
**kwargs):
super().__init__(name=name, mlir_ods_predicate=mlir_ods_predicate, **kwargs)
if example_size is None:
example_size = (2, 3, 7) # No significance.
self.example_size = example_size
def generate_example(self, index=0):
return torch.rand(*self.example_size)
class TensorOutRef(ValueSpec):
"""A tensor that is passed by ref as an out parameter."""
def __init__(self,
name=None,
*,
example_size=None,
mlir_ods_predicate="ATen_AnyRefTensor",
**kwargs):
super().__init__(name=name, mlir_ods_predicate=mlir_ods_predicate, **kwargs)
if example_size is None:
example_size = (2, 3, 7) # No significance.
self.example_size = example_size
def generate_example(self, index=0):
return torch.rand(*self.example_size)
class LiteralValue(ValueSpec):
"""An input that is a literal value."""
def __init__(self,
name=None,
value=None,
mlir_ods_predicate="ATen_AnyScalar",
**kwargs):
super().__init__(name=name, mlir_ods_predicate=mlir_ods_predicate, **kwargs)
self.value = value
def generate_example(self, index=0):
if self.value is not None:
return self.value
return 1.0 + index # Generates a stable value.
class OpMapping:
"""Base class for things purporting to map an operation."""
__slots__ = []
class SimpleOpMapping(OpMapping):
"""Maps a PyTorch invocation to its MLIR representation."""
__slots__ = [
"append_description",
"mlir_operation_name",
"operand_map",
"op_args",
"op_arity",
"op_f",
"op_kind",
"op_kwargs",
"op_self",
"op_self_value_spec",
"outref_variant_value",
"result_map",
]
def __init__(self, op_f, *op_args, **op_kwargs):
super().__init__()
self.op_f = op_f
self.op_self = None
self.op_self_value_spec = None
self.op_args = op_args
self.op_kwargs = op_kwargs
self.outref_variant_value = None # type: Optional[TensorOutRef]
self.append_description = ""
# Set after finalize.
self.op_kind = None # type: Optional[str]
self.op_arity = -1 # type: int
self.operand_map = None # type: Optional[List[Tuple[int, ValueSpec]]]
self.result_map = None # type: Optional[List[Tuple[int, ValueSpec]]]
self.mlir_operation_name = None # type: Optional[str]
# Fixup self-calls.
# These are specified as tuples of (ValueSpec, "method_name")
if isinstance(self.op_f, tuple):
assert len(self.op_f) == 2, "Self-calls must be a 2-tuple"
self.op_self_value_spec, method_name = self.op_f
self.op_self = self.op_self_value_spec.generate_example(index=-1)
if method_name.startswith("@"):
# Create a pseudo-function to resolve the attribute.
self.op_f = _read_attr_callback(self.op_self, method_name[1:])
else:
self.op_f = getattr(self.op_self, method_name)
def __repr__(self):
return ("SimpleOp({kind!r}[{arity}] -> {name!s}, operands={operands!r}, "
"results={results!r})".format(kind=self.op_kind,
arity=self.op_arity,
name=self.mlir_operation_name,
operands=self.operand_map,
results=self.result_map))
def clone(self) -> "SimpleOpMapping":
copy = SimpleOpMapping(self.op_f, *self.op_args, **self.op_kwargs)
for name in [
"outref_variant_value", "op_kind", "op_arity", "operand_map",
"result_map", "mlir_operation_name"
]:
setattr(copy, name, getattr(self, name))
return copy
def with_outref_variant(self, value=None) -> "SimpleOpMapping":
"""Instructs the registry to also generate an outref variant.
This is done by cloning the op prior to finalizing and adding an out=
parameter.
"""
self.outref_variant_value = TensorOutRef() if value is None else value
return self
def with_torch_op_kind(self, op_kind: str) -> "SimpleOpMapping":
"""Uses a manually specified op kind (i.e. "aten::some_op")."""
self.op_kind = op_kind
return self
def with_operand_map(self, *mlir_operand_names):
"""Manually maps torch IR operands to mlir operand names."""
operand_map = []
for i, name_or_value_spec in enumerate(mlir_operand_names):
if isinstance(name_or_value_spec, ValueSpec):
value_spec = name_or_value_spec
else:
value_spec = self.find_value_spec_by_name(name_or_value_spec)
operand_map.append((i, value_spec))
self.operand_map = operand_map
return self
def with_append_description(self, description) -> "SimpleOpMapping":
"""Appends a description to the ODS."""
self.append_description += textwrap.dedent(description)
return self
@property
def all_arg_values(self) -> List[ValueSpec]:
"""Returns all arg values (either positional or kw)."""
args = list(self.op_args) + list(self.op_kwargs.values())
if self.op_self_value_spec:
args = [self.op_self_value_spec] + args
return args
@property
def is_outref_form(self) -> bool:
"""Whether the op contains an out parameter that aliases to the result."""
return any(isinstance(a, TensorOutRef) for a in self.all_arg_values)
def find_value_spec_by_name(self, name) -> ValueSpec:
"""Finds an argument ValueSpec by name.
Raises:
ValueError if not found.
"""
if self.op_self_value_spec and self.op_self_value_spec.name == name:
return self.op_self_value_spec
value_spec = self.op_kwargs.get(name)
if value_spec:
return value_spec
for value_spec in self.op_args:
if value_spec.name == name:
return value_spec
raise ValueError(f"Unknown value spec: {name}")
def generate_example(self) -> Tuple[Tuple, Dict]:
"""Generates an example signature for invoking the op.
Returns:
(tuple, dict) of positional and keyword args.
"""
index = 0
positional = list()
kw = dict()
for op_arg in self.op_args:
positional.append(op_arg.generate_example(index))
index += 1
for kw_name, kw_value in self.op_kwargs.items():
kw[kw_name] = kw_value.generate_example(index)
index += 1
return positional, kw
def finalize(self):
"""Finalizes the mapping once all hints have been applied."""
# Update the name on all args if undefined.
for index, op_arg in enumerate(self.op_args):
if op_arg.name is None:
op_arg.name = "arg%d".format(index)
for key, op_arg in self.op_kwargs.items():
if op_arg.name is None:
op_arg.name = key
# Create an example graph and configure from it.
self._configure_from_example()
# Determine default operation name.
if self.mlir_operation_name is None:
self._set_default_mlir_operation_name()
def _set_default_mlir_operation_name(self):
op_ns, op_name = self.op_kind.split("::", maxsplit=1)
# Since these are emitted into the "aten" dialect namespace, alias them
# to omit the prefix to distinguish from custom ops and others (which will
# have a prefix).
if op_ns == "aten":
default_name = op_name
else:
default_name = op_ns + "." + op_name
if self.is_outref_form:
default_name += ".inplace"
self.mlir_operation_name = default_name
def _configure_from_example(self):
# Trace the op so that we get an example graph like this:
# %0 : Float(2, 3, 7) = prim::Constant[value=<Tensor>]()
# %1 : Float(2, 3, 7) = prim::Constant[value=<Tensor>]()
# %2 : float = prim::Constant[value=3.]()
# %3 : Float(2, 3, 7) = aten::add(%0, %1, %2)
# return (%3)
# The def of the return value is expected to be the modeled op. The
# inputs to that op are expected to be captured constants that can be
# re-associated to the example inputs.
example_args, example_kwargs = self.generate_example()
def forward():
# logging.debug(
# f"Invoke {self.op_f} with *{example_args}, **{example_kwargs}")
return self.op_f(*example_args, **example_kwargs)
trace = torch.jit.trace(forward, tuple())
graph = trace.graph
logging.debug("Graph for op %r: %s", self.op_f, graph)
# Track up from the return node and assume this is our op.
return_node = graph.return_node()
return_inputs = list(return_node.inputs())
assert len(return_inputs) == 1, "Expected one input return"
op_node = return_inputs[0].node()
op_inputs = list(op_node.inputs())
logging.debug("Found op node: %r", op_node)
# Meta-data about the source op.
self.op_kind = op_node.kind()
self.op_arity = len(op_inputs)
if self.operand_map is None:
self.operand_map = self._associate_inputs(op_inputs, example_args,
example_kwargs)
# Results.
op_outputs = list(op_node.outputs())
if self.result_map is None:
if self.is_outref_form:
# Only support single outref results.
assert len(op_outputs) == 1, (
"For outref ops, only a single output is supported")
self.result_map = [(0, TensorOutRef("result"))]
else:
# Map results in order.
self.result_map = []
def result_name(i):
if len(op_outputs) == 1:
return "result"
else:
return "result%d" % i
for i, op_output in enumerate(op_outputs):
op_output_type = op_output.type()
if issubclass(type(op_output_type), torch.TensorType):
self.result_map.append((i, TensorValue(result_name(i))))
else:
raise ValueError(
"Unsupported op output type: {!r}".format(op_output_type))
return self
def _associate_inputs(self, op_inputs, example_args, example_kwargs):
"""Given inputs to a graph op node, associates to the input args.
This will match up example arguments with what was produced in the graph,
setting the operand_map.
Returns:
List of (input_index, ValueSpec) mapping inputs to the graph node to
provided values in the op definition.
"""
assert len(example_args) == len(self.op_args)
assert example_kwargs.keys() == self.op_kwargs.keys()
def find_arg(value):
if self.op_self is not None and _is_same_value(self.op_self, value):
return self.op_self_value_spec
for i, arg in enumerate(example_args):
if _is_same_value(arg, value):
return self.op_args[i]
for key, arg in example_kwargs.items():
if _is_same_value(arg, value):
return self.op_kwargs[key]
raise KeyError("Op input not in arguments: {!r} -> {!r}".format(
value, op_inputs))
operand_map = []
for i, op_input in enumerate(op_inputs):
input_node = op_input.node()
immediate_value = _extract_immediate(input_node)
if immediate_value is not None:
operand_map.append((i, find_arg(immediate_value)))
return operand_map
class OpRegistry:
"""Maintains a registry of op mappings."""
def __init__(self):
super().__init__()
self._mappings = []
self._pending_mapping = None
def op(self, op_f, *op_args, **op_kwargs):
"""Forwards to the SimpleOpMapping constructor and adds it.
The mapping is not finalized until either the registry is finalized or the
next op mapping is added. This allows tweaks to the mapping to be done
inline prior to performing detailed introspection.
Returns:
The SimpleOpMapping instance.
"""
self._finalize_pending()
m = SimpleOpMapping(op_f, *op_args, **op_kwargs)
self._pending_mapping = m
return m
@property
def mappings(self) -> Sequence[OpMapping]:
"""Returns the list of OpMapping.
Returns:
Sequence of OpMapping concrete classes (most commonly SimpleOpMapping).
"""
self._finalize_pending()
return self._mappings
def _finalize_pending(self):
if not self._pending_mapping:
return
outref_mapping = None
pending_mapping = self._pending_mapping
self._pending_mapping = None
if pending_mapping.outref_variant_value:
# Generate a variant (with an out= form).
outref_mapping = pending_mapping.clone()
outref_mapping.op_kwargs["out"] = outref_mapping.outref_variant_value
outref_mapping.outref_variant_value = None
# Finalize the original.
pending_mapping.finalize()
self._mappings.append(pending_mapping)
# Finalize the outref form if generated.
if outref_mapping:
outref_mapping.finalize()
self._mappings.append(outref_mapping)


@@ -0,0 +1,477 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Queries the pytorch op registry and generates ODS and CC sources for the ops.
"""
from typing import Dict, List, Optional, TextIO, Sequence, Tuple
import argparse
from contextlib import contextmanager
import importlib
import logging
import re
import sys
import textwrap
# Note that this utility exists only in the c-extension.
from _torch_mlir import get_registered_ops
def _create_argparse():
parser = argparse.ArgumentParser(prog="generate_ods")
parser.add_argument("--ods_td_file",
required=True,
help="File to write the generated ODS to")
parser.add_argument("--ods_impl_file",
required=True,
help="CC include file to include in ops implementation")
parser.add_argument("--debug_op_reg_file",
help="Write out a file of op registrations")
return parser
def main(args):
reg_ops = _load_ops_as_dict()
if args.debug_op_reg_file:
with open(args.debug_op_reg_file, "w") as debug_ops_file:
dump_registered_ops(debug_ops_file, reg_ops)
with open(args.ods_td_file, "w") as ods_file, open(args.ods_impl_file,
"w") as impl_file:
ods_emitter = OdsEmitter(ods_file)
ods_emitter.print(ODS_BANNER)
impl_emitter = CCImplEmitter(impl_file)
impl_emitter.print(CC_IMPL_BANNER)
generator = OpGenerator(reg_ops, ods_emitter, impl_emitter)
generate_ops(generator)
def generate_ops(g: "OpGenerator"):
# "Binary"-ops.
# There is some variation in these, so we spell them out in case they
# need individual customization.
g.print_banner("Binary arithmetic ops")
g.ordinary_binary_op("aten::add(Tensor,Tensor,Scalar)", "AddOp", "add")
g.ordinary_binary_op("aten::atan2(Tensor,Tensor)", "Atan2Op", "atan2")
g.ordinary_binary_op("aten::div(Tensor,Tensor)", "DivOp", "div")
g.ordinary_binary_op("aten::floor_divide(Tensor,Tensor)", "FloorDivideOp",
"floor_divide")
g.ordinary_binary_op("aten::mul(Tensor,Tensor)", "MulOp", "mul")
g.ordinary_binary_op("aten::remainder(Tensor,Tensor)", "RemainderOp",
"remainder")
g.ordinary_binary_op("aten::true_divide(Tensor,Tensor)", "TrueDivideOp",
"true_divide")
# Unary-ops. These are all the same so just name munge them.
g.print_banner("Unary arithmetic ops")
for uname in [
"abs", "acos", "angle", "asin", "atan", "ceil", "conj", "cos", "cosh",
"digamma", "erf", "erfc", "erfinv", "exp", "expm1", "floor", "frac",
"lgamma", "log", "log10", "log1p", "log2", "neg", "relu", "reciprocal",
"round", "rsqrt", "sigmoid", "sign", "sin", "sinh", "sqrt", "tan", "tanh",
"trunc"
]:
g.ordinary_unary_op(f"aten::{uname}(Tensor)",
f"{snakecase_to_camelcase(uname)}Op", uname)
def dump_registered_ops(outfile, reg_ops_dict):
for k in sorted(reg_ops_dict.keys()):
attr_dict = reg_ops_dict[k]
outfile.write(f"OP '{k}':\n")
for attr_name, attr_value in attr_dict.items():
outfile.write(f" {attr_name} = {attr_value!r}\n")
outfile.write("\n")
class OpGenerator:
def __init__(self, reg_ops, ods_emitter, impl_emitter):
super().__init__()
self.reg_ops = reg_ops
self.ods_emitter = ods_emitter
self.impl_emitter = impl_emitter
def print_banner(self, text):
for em in (self.ods_emitter, self.impl_emitter):
em.print(
"// -----------------------------------------------------------------------------"
)
em.print(f"// {text}")
em.print(
"// -----------------------------------------------------------------------------"
)
em.print("")
def ordinary_binary_op(self, kernel_sig, ods_name, op_name):
""""Binary"-ops. These ops typically have:
- '.Tensor' variant where the second arg is a Tensor
- '.Scalar' variant where the second arg is a Scalar
- An '.out' variant which contains a final "outref" argument
Actual suffixes vary and the rules are set up to match anything
that comes in one of these forms.
In addition, most of these have in-place versions, possibly of the
'.Tensor' and '.Scalar' variants above.
Note that many of these have more than two arguments (i.e. alpha/beta
scalars trailing and such), but they are not relevant for
matching/conversions (if they are more than pass-through, they need a
dedicated rule).
We generally canonicalize all of these forms to a single recognized op
by:
- Enabling the promoteTrailingOutTensor flag
- Enabling the matchInplaceVariant flag
- Setting all arguments and returns to kImmutableTensor
- Enabling kPromoteScalarToTensor on the second argument.
"""
reg_record = self._get_reg_record(kernel_sig)
ods_ins, arg_type_flags = self._map_sigtypes(
reg_record["arguments"],
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
"Tensor:1": "AnyTorchImmutableTensor",
"Scalar:1": "AnyTorchImmutableTensor",
"Scalar": "AnyTorchScalarType",
},
flag_transforms={
":0": ["kImmutableTensor"],
":1": ["kImmutableTensor", "kPromoteScalar"],
})
ods_outs, return_type_flags = self._map_sigtypes(
reg_record["returns"],
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
self.ods_emitter.emit_opdef(ods_name,
op_name,
reg_record,
ods_ins=ods_ins,
ods_outs=ods_outs,
traits=["NoSideEffect"])
self.impl_emitter.emit_kernel_methods(ods_name,
reg_record,
arg_type_flags=arg_type_flags,
return_type_flags=return_type_flags,
promote_trailing_out_tensor=True)
def ordinary_unary_op(self, kernel_sig, ods_name, op_name):
"""Unary ops.
These take and return a tensor and typically have an out and inplace
variant (they may not, but we generate patterns to match anyway).
"""
reg_record = self._get_reg_record(kernel_sig)
ods_ins, arg_type_flags = self._map_sigtypes(
reg_record["arguments"],
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
ods_outs, return_type_flags = self._map_sigtypes(
reg_record["returns"],
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
self.ods_emitter.emit_opdef(ods_name,
op_name,
reg_record,
ods_ins=ods_ins,
ods_outs=ods_outs,
traits=["NoSideEffect"])
self.impl_emitter.emit_kernel_methods(ods_name,
reg_record,
arg_type_flags=arg_type_flags,
return_type_flags=return_type_flags,
promote_trailing_out_tensor=True)
def _get_reg_record(self, kernel_sig):
"""Gets the op-dict for a given registered op name.
Args:
kernel_sig: Signature of the kernel to find.
Returns:
Dict of the registration record.
"""
record = self.reg_ops.get(kernel_sig)
if record:
return record
# Try to give a nice "did you mean" style error, since this happens
# so much.
kernel_name, *rest = kernel_sig.split("(", maxsplit=1)
dym_list = [k for k in self.reg_ops.keys() if k.startswith(kernel_name)]
dym_message = '\n '.join(dym_list)
raise ValueError(f"Could not find registry op matching '{kernel_sig}'. "
f"Possible matches:\n {dym_message}")
def _map_sigtypes(self, siglist: List[Dict], type_transforms: Dict[str, str],
flag_transforms: Dict[str, List[str]]) -> List[Tuple[str]]:
"""Maps a list of signature entries to ods dags and flag lists.
The torch signature list contains dicts that minimally have keys 'name' and
'type', representing torch names and types. Returns a corresponding
list of 2-tuples of (ods_name, ods_type).
The provided type_transforms is a dict of type substitutions, one of which
must match for each entry in the list. The keys can be a verbatim torch
type (e.g. "Tensor"), an index in the list (e.g. ":0"), or a type and
index (e.g. "Tensor:0").
Similarly, flag_transforms matches its keys in the same way and maps to
a list of KernelValueConversion constants that make up arg/return specific
conversion flags.
Returns:
- An ods dag list of (ods_name, ods_type) tuples
- List of (torch_type, [conversion_flag]) for specifying conversions.
"""
# Generate the ods dag list.
ods_dag_list = []
for i, sigitem in enumerate(siglist):
torch_name = sigitem["name"]
torch_type = sigitem["type"]
# Look up the type transform.
ods_type = (type_transforms.get(f"{torch_type}:{i}") or
type_transforms.get(f":{i}") or
type_transforms.get(torch_type))
if not ods_type:
raise ValueError(f"Signature item {i}, type {torch_type} did not match "
f"a type transform {type_transforms}")
ods_dag_list.append((torch_name, ods_type))
# Generate the type conversion flags.
type_flag_list = []
for i, sigitem in enumerate(siglist):
torch_type = sigitem["type"]
# Look up the type transform.
flags = (flag_transforms.get(f"{torch_type}:{i}") or
flag_transforms.get(f":{i}") or flag_transforms.get(torch_type))
if not flags:
flags = []
type_flag_list.append((torch_type, flags))
return ods_dag_list, type_flag_list
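# Worked example for _map_sigtypes above, using the aten::add arguments
# (argument names assumed from the registry record; they mirror the manually
# written AddOp this commit replaces):
#   siglist = [{"name": "self", "type": "Tensor"},
#              {"name": "other", "type": "Tensor"},
#              {"name": "alpha", "type": "Scalar"}]
# with the binary-op type_transforms/flag_transforms above yields
#   ods_dag_list = [("self", "AnyTorchImmutableTensor"),
#                   ("other", "AnyTorchImmutableTensor"),
#                   ("alpha", "AnyTorchScalarType")]
#   type_flag_list = [("Tensor", ["kImmutableTensor"]),
#                     ("Tensor", ["kImmutableTensor", "kPromoteScalar"]),
#                     ("Scalar", [])]
# which is what produces the "kImmutableTensor|kPromoteScalar" and "kNone"
# conversions visible in the generated .cpp.inc below.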
class EmitterBase:
_INDENT = " "
def __init__(self, out: TextIO):
super().__init__()
self.out = out
self.indent_level = 0
@contextmanager
def indent(self, level=1):
self.indent_level += level
yield
self.indent_level -= level
assert self.indent_level >= 0, "Unbalanced indentation"
def print(self, s, *, end="\n", indent=True):
if indent and self.indent_level:
self.out.write(self._INDENT * self.indent_level)
self.out.write(s)
self.out.write(end)
def quote(self, s: str):
s = s.replace(r'"', r'\"')
return f'"{s}"'
def quote_multiline_docstring(self, s: str, indent_level: int = 0):
# TODO: Possibly find a python module to markdown the docstring for better
# document generation.
# Unlikely to contain the delimiter, and since this is just a docstring, be safe.
s = s.replace("}]", "")
# Strip each line.
s = "\n".join([l.rstrip() for l in s.splitlines()])
indent = self._INDENT * indent_level
s = textwrap.indent(s, indent + self._INDENT)
return "[{\n" + s + "\n" + indent + "}]"
class OdsEmitter(EmitterBase):
ods_def_prefix = "aten_"
ods_def_suffix = ""
ods_template_name = "aten_Op"
def emit_opdef(self,
ods_def_name: str,
mnemonic: str,
reg_record: Dict,
ods_ins: List[Tuple[str, str]],
ods_outs: List[Tuple[str, str]],
traits: Sequence[str] = (),
summary: Optional[str] = None):
# Def first-line.
full_traits = list(traits)
full_traits.append(
"DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>")
full_traits.append("DeclareOpInterfaceMethods<TorchKernelOpInterface>")
identifier = f"{self.ods_def_prefix}{ods_def_name}{self.ods_def_suffix}"
self.print(f"def {identifier}: {self.ods_template_name}"
f"<{self.quote(mnemonic)}, ["
f"{', '.join(full_traits)}"
f"]> {{")
with self.indent():
# Summary.
if not summary:
summary = f"Recognized op for kernel {reg_record['name'][0]}"
self.print(f"let summary = {self.quote(summary)};")
# Arguments.
self.print("let arguments = (ins")
with self.indent():
self._emit_dag_list_body(ods_ins)
self.print(");")
# Results.
self.print("let results = (outs")
with self.indent():
self._emit_dag_list_body(ods_outs)
self.print(");")
# Def last-line.
self.print("}\n")
def _emit_dag_list_body(self, items):
"""Emits a dag of (name, type) pairs."""
for index, (ods_name, ods_type) in enumerate(items):
is_last = index == len(items) - 1
ods_namespec = f":${ods_name}" if ods_name else ""
self.print(f"{ods_type}{ods_namespec}", end="\n" if is_last else ",\n")
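# For the AddOp mapping registered in generate_ops(), the OdsEmitter above
# emits roughly the following TableGen (argument/result names come from the
# registry record; sketch only, formatting abridged):
#   def aten_AddOp: aten_Op<"add", [NoSideEffect,
#       DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>,
#       DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
#     let summary = "Recognized op for kernel aten::add";
#     let arguments = (ins
#       AnyTorchImmutableTensor:$self,
#       AnyTorchImmutableTensor:$other,
#       AnyTorchScalarType:$alpha
#     );
#     let results = (outs AnyTorchImmutableTensor);
#   }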
class CCImplEmitter(EmitterBase):
def emit_kernel_methods(self,
ods_def_name: str,
reg_record,
arg_type_flags: List[Tuple[str, List[Tuple[str]]]],
return_type_flags: List[Tuple[str, List[Tuple[str]]]],
promote_trailing_out_tensor=False):
# getTorchKernelMetadata() method.
self.print(
f"Torch::KernelMetadata {ods_def_name}::getTorchKernelMetadata() {{")
with self.indent():
self.print("return getTorchBuildKernelMetadata();")
self.print("}\n")
# getTorchBuildKernelMetadata() method.
kernel_name = reg_record["name"][0]
self.print(
f"const Torch::BuildKernelMetadata &{ods_def_name}::getTorchBuildKernelMetadata() {{"
)
with self.indent():
self.print("using KVC = Torch::KernelValueConversion::BitMask;")
self.print("static Torch::BuildKernelMetadata metadata = ([]() {")
with self.indent():
self.print("Torch::BuildKernelMetadata m;")
self.print(f"m.kernelName = {self.quote(kernel_name)};")
if promote_trailing_out_tensor:
self.print("m.promoteTrailingOutTensor = true;")
# Arg types/flags.
arg_types = self._format_cpp_str_initlist(
[t[0] for t in arg_type_flags])
self.print(f"m.addArgTypes({arg_types});")
arg_flags = self._format_cpp_kvc_initlist(
[t[1] for t in arg_type_flags])
self.print(f"m.addArgConversions({arg_flags});")
# Returns types/flags.
ret_types = self._format_cpp_str_initlist(
[t[0] for t in return_type_flags])
self.print(f"m.addReturnTypes({ret_types});")
ret_flags = self._format_cpp_kvc_initlist(
[t[1] for t in return_type_flags])
self.print(f"m.addReturnConversions({ret_flags});")
self.print("return m;")
self.print("})();")
self.print("return metadata;")
self.print("}")
def _format_cpp_str_initlist(self, strings):
quoted = [self.quote(s) for s in strings]
joined = ", ".join(quoted)
return "{" + joined + "}"
def _format_cpp_kvc_initlist(self, const_name_lists):
def or_flags(flag_names):
if not flag_names:
return "KVC::kNone"
return "|".join([f"KVC::{n}" for n in flag_names])
or_d = [or_flags(l) for l in const_name_lists]
joined = ", ".join(or_d)
return "{" + joined + "}"
def snakecase_to_camelcase(ident: str):
return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
def _load_ops_as_dict():
# get_registered_ops() returns a list of dicts, each with a "name" that is a
# tuple of the form:
#   (kernel_name, variant)
# This indexes those records by a reified kernel signature of the form used
# throughout PyTorch for argument type signatures:
#   namespace::kernel_name(type1,type2)
def reify_signature(reg_op):
kernel_name, unused_variant = reg_op["name"]
arg_types = [arg["type"] for arg in reg_op["arguments"]]
return f"{kernel_name}({','.join(arg_types)})"
reg_ops_list = get_registered_ops()
return {reify_signature(reg_op): reg_op for reg_op in reg_ops_list}
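# Example: the registered op whose name is ("aten::add", <variant>) and whose
# argument types are Tensor, Tensor, Scalar is keyed as
# "aten::add(Tensor,Tensor,Scalar)", which is exactly the kernel_sig string
# passed to OpGenerator.ordinary_binary_op in generate_ops() above.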
def _get_main_module_name():
return sys.modules["__main__"].__loader__.name
ODS_BANNER = "\n".join([
"//===-------------------------------------------------------*- tablegen -*-===//",
"//",
"// This file is licensed under the Apache License v2.0 with LLVM Exceptions.",
"// See https://llvm.org/LICENSE.txt for license information.",
"// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception",
"//",
"// Operation summaries and descriptions were systematically derived from public",
"// API docstrings and are licensed accordingly:",
"// https://github.com/pytorch/pytorch/blob/master/LICENSE",
"//===----------------------------------------------------------------------===//",
"// This file is automatically generated. Please do not edit.",
"// Generated via:",
f"// python -m {_get_main_module_name()}",
"//===----------------------------------------------------------------------===//",
"",
"",
])
CC_IMPL_BANNER = "\n".join([
"//===-------------------------------------------------------------*- cc -*-===//",
"//",
"// This file is licensed under the Apache License v2.0 with LLVM Exceptions.",
"// See https://llvm.org/LICENSE.txt for license information.",
"// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception", "//",
"// Operation summaries and descriptions were systematically derived from public",
"// API docstrings and are licensed accordingly:",
"// https://github.com/pytorch/pytorch/blob/master/LICENSE",
"//===----------------------------------------------------------------------===//",
"// This file is automatically generated. Please do not edit.",
"// Generated via:", f"// python -m {_get_main_module_name()}",
"//===----------------------------------------------------------------------===//",
"", "", "// clang-format off"
])
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
parser = _create_argparse()
args = parser.parse_args()
main(args)


@@ -25,44 +25,7 @@ class aten_Op<string mnemonic, list<OpTrait> traits = [StatisticsOpInterface]> :
// Most ops are automatically generated from pytorch specs.
include "npcomp/Dialect/ATen/IR/GeneratedATenOps.td"
def aten_AddOp: aten_Op<"add", [
NoSideEffect, TorchBuildableKernelOpInterface, TorchKernelOpInterface,
StatisticsOpInterface]> {
let arguments = (
ins AnyTorchImmutableTensor:$self,
AnyTorchImmutableTensor:$other,
AnyTorchScalarType:$alpha
);
let results = (outs AnyTorchImmutableTensor);
let summary = "aten add operator";
let description = [{
AddOp
aten add operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
Torch::KernelMetadata getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
static const Torch::BuildKernelMetadata &getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::add";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Scalar"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
}];
}
include "npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td"
def aten_BatchNormOp: aten_Op<"batch_norm", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor:$output, AnyTensor:$save_mean, AnyTensor:$save_invstd)> {


@@ -0,0 +1,781 @@
//===-------------------------------------------------------------*- cc -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Operation summaries and descriptions were systematically derived from public
// API docstrings and are licensed accordingly:
// https://github.com/pytorch/pytorch/blob/master/LICENSE
//===----------------------------------------------------------------------===//
// This file is automatically generated. Please do not edit.
// Generated via:
// python -m torch_mlir_utils.codegen.torch_signature_ods_gen
//===----------------------------------------------------------------------===//
// clang-format off
// -----------------------------------------------------------------------------
// Binary arithmetic ops
// -----------------------------------------------------------------------------
Torch::KernelMetadata AddOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &AddOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::add";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Scalar"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar, KVC::kNone});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata Atan2Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &Atan2Op::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::atan2";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata DivOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &DivOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::div";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata FloorDivideOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &FloorDivideOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::floor_divide";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata MulOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &MulOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::mul";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata RemainderOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &RemainderOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::remainder";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata TrueDivideOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::true_divide";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
// -----------------------------------------------------------------------------
// Unary arithmetic ops
// -----------------------------------------------------------------------------
Torch::KernelMetadata AbsOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &AbsOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::abs";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata AcosOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &AcosOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::acos";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata AngleOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &AngleOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::angle";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata AsinOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &AsinOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::asin";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata AtanOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &AtanOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::atan";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata CeilOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &CeilOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::ceil";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata ConjOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &ConjOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::conj";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata CosOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &CosOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::cos";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata CoshOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &CoshOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::cosh";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata DigammaOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &DigammaOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::digamma";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata ErfOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &ErfOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::erf";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata ErfcOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &ErfcOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::erfc";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata ErfinvOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &ErfinvOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::erfinv";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata ExpOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &ExpOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::exp";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata Expm1Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &Expm1Op::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::expm1";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata FloorOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &FloorOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::floor";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata FracOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &FracOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::frac";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata LgammaOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &LgammaOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::lgamma";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata LogOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &LogOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::log";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata Log10Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &Log10Op::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::log10";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata Log1pOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &Log1pOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::log1p";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata Log2Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &Log2Op::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::log2";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata NegOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &NegOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::neg";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata ReluOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &ReluOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::relu";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata ReciprocalOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &ReciprocalOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::reciprocal";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata RoundOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &RoundOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::round";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata RsqrtOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &RsqrtOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::rsqrt";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata SigmoidOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &SigmoidOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::sigmoid";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata SignOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &SignOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::sign";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata SinOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &SinOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::sin";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata SinhOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &SinhOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::sinh";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata SqrtOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &SqrtOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::sqrt";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata TanOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &TanOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::tan";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata TanhOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &TanhOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::tanh";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata TruncOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::trunc";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
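Each generated builder above follows one pattern: a function-local static Torch::BuildKernelMetadata is initialized exactly once with the Torch kernel name, the Torch-level argument/return types, the conversion flags, and promoteTrailingOutTensor = true, presumably so the out= overload of the same kernel can be recognized by the same op. A standalone C++ sketch of that trailing-out-tensor matching idea (the struct and helper below are illustrative stand-ins, not the actual npcomp Torch::BuildKernelMetadata API):

#include <string>
#include <vector>

// Illustrative analogue of the generated metadata: one Torch kernel name plus
// the Torch-level argument types of the canonical (non-out) overload.
struct KernelSignatureSketch {
  std::string kernelName;
  std::vector<std::string> argTypes; // e.g. {"Tensor"}
};

// Returns true if the observed call matches the canonical signature directly,
// or matches it after peeling off a trailing "Tensor" out argument (the out=
// overload), which is the behavior promoteTrailingOutTensor suggests.
bool matchesWithOutPromotion(const KernelSignatureSketch &expected,
                             const std::string &callName,
                             std::vector<std::string> callArgTypes,
                             bool promoteTrailingOutTensor) {
  if (callName != expected.kernelName)
    return false;
  if (callArgTypes == expected.argTypes)
    return true;
  if (promoteTrailingOutTensor && !callArgTypes.empty() &&
      callArgTypes.back() == "Tensor") {
    callArgTypes.pop_back(); // drop the trailing out tensor
    return callArgTypes == expected.argTypes;
  }
  return false;
}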

File diff suppressed because it is too large.


@ -0,0 +1,654 @@
// This file is (mostly) automatically generated. Please do not edit.
//===- ATenOps.td ------------------------------------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS
#define NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS
def aten_AddUnderOp: aten_Op<"add_", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$other,
AnyScalar:$alpha
);
let summary = "aten add_ operator";
let description = [{
AddUnderOp
aten add_ operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_AsStridedOp: aten_Op<"as_strided", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyType:$size,
AnyType:$stride
);
let summary = "aten as_strided operator";
let description = [{
AsStridedOp
aten as_strided operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_ConvolutionOverrideableOp: aten_Op<"convolution_overrideable", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$input,
AnyTensor:$weight,
AnyTensor:$bias,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation,
AnyScalar:$transposed,
AnyType:$output_padding,
AnyScalar:$groups
);
let summary = "aten convolution_overrideable operator";
let description = [{
ConvolutionOverrideableOp
aten convolution_overrideable operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_ConvolutionBackwardOverrideableOp: aten_Op<"convolution_backward_overrideable", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor, AnyTensor, AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$input,
AnyTensor:$weight,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation,
AnyScalar:$transposed,
AnyType:$output_padding,
AnyScalar:$groups
);
let summary = "aten convolution_backward_overrideable operator";
let description = [{
ConvolutionBackwardOverrideableOp
aten convolution_backward_overrideable operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_DivUnderOp: aten_Op<"div_", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$other
);
let summary = "aten div_ operator";
let description = [{
DivUnderOp
aten div_ operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_ExpandOp: aten_Op<"expand", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyType:$size,
AnyScalar:$implicit
);
let summary = "aten expand operator";
let description = [{
ExpandOp
aten expand operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_LogSoftmaxOp: aten_Op<"_log_softmax", [NoSideEffect]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyScalar:$dim,
AnyScalar:$half_to_float
);
let summary = "aten _log_softmax operator";
let description = [{
LogSoftmaxOp
aten _log_softmax operator
}];
}
def aten_LogSoftmaxBackwardDataOp: aten_Op<"_log_softmax_backward_data", [NoSideEffect]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$output,
AnyScalar:$dim,
AnyTensor:$self
);
let summary = "aten _log_softmax_backward_data operator";
let description = [{
LogSoftmaxBackwardDataOp
aten _log_softmax_backward_data operator
}];
}
def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self
);
let summary = "aten mean operator";
let description = [{
MeanOp
aten mean operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_MmOp: aten_Op<"mm", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$mat2
);
let summary = "aten mm operator";
let description = [{
MmOp
aten mm operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_MulUnderOp: aten_Op<"mul_", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$other
);
let summary = "aten mul_ operator";
let description = [{
MulUnderOp
aten mul_ operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NativeBatchNormOp: aten_Op<"native_batch_norm", [NoSideEffect, StatisticsOpInterface]>,
// FIXME: not quite automatically generated names
Results<(outs AnyTensor:$output, AnyTensor:$save_mean, AnyTensor:$save_invstd)> {
let arguments = (
ins AnyTensor:$input,
AnyTensor:$weight,
AnyTensor:$bias,
AnyTensor:$running_mean,
AnyTensor:$running_var,
AnyScalar:$training,
AnyScalar:$momentum,
AnyScalar:$eps
);
let summary = "aten native_batch_norm operator";
let description = [{
NativeBatchNormOp
aten native_batch_norm operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NativeBatchNormBackwardOp: aten_Op<"native_batch_norm_backward", [NoSideEffect, StatisticsOpInterface]>,
// FIXME: not quite automatically generated
Results<(outs AnyTensor:$dx, AnyTensor:$dm, AnyTensor:$dv)> {
let arguments = (
ins AnyTensor:$grad_out,
AnyTensor:$input,
AnyTensor:$weight,
AnyTensor:$running_mean,
AnyTensor:$running_var,
AnyTensor:$save_mean,
AnyTensor:$save_invstd,
AnyScalar:$train,
AnyScalar:$eps
);
let summary = "aten native_batch_norm_backward operator";
let description = [{
NativeBatchNormBackwardOp
aten native_batch_norm_backward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_ReluUnderOp: aten_Op<"relu_", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self
);
let summary = "aten relu_ operator";
let description = [{
ReluUnderOp
aten relu_ operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_SizeOp: aten_Op<"size", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyScalar:$dim
);
let summary = "aten size operator";
let description = [{
SizeOp
aten size operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_SqueezeOp: aten_Op<"squeeze", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyScalar:$dim
);
let summary = "aten squeeze operator";
let description = [{
SqueezeOp
aten squeeze operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_SumOp: aten_Op<"sum", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyType:$dim,
AnyScalar:$keepdim
);
let summary = "aten sum operator";
let description = [{
SumOp
aten sum operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_TOp: aten_Op<"t", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self
);
let summary = "aten t operator";
let description = [{
TOp
aten t operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_ThresholdBackwardOp: aten_Op<"threshold_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$self,
AnyScalar:$threshold
);
let summary = "aten threshold_backward operator";
let description = [{
ThresholdBackwardOp
aten threshold_backward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_UnsqueezeOp: aten_Op<"unsqueeze", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyScalar:$dim
);
let summary = "aten unsqueeze operator";
let description = [{
UnsqueezeOp
aten unsqueeze operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_SubOp: aten_Op<"sub", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$other,
AnyScalar:$alpha
);
let summary = "aten sub operator";
let description = [{
SubOp
aten sub operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_SubUnderOp: aten_Op<"sub_", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$other,
AnyScalar:$alpha
);
let summary = "aten sub_ operator";
let description = [{
SubUnderOp
aten sub_ operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_AddmmOp: aten_Op<"addmm", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$mat1,
AnyTensor:$mat2,
AnyScalar:$beta,
AnyScalar:$alpha
);
let summary = "aten addmm operator";
let description = [{
AddmmOp
aten addmm operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_ViewOp: aten_Op<"view", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyType:$size
);
let summary = "aten view operator";
let description = [{
ViewOp
aten view operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_GatherOp: aten_Op<"gather", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyScalar:$dim,
AnyTensor:$index,
AnyScalar:$sparse_grad
);
let summary = "aten gather operator";
let description = [{
GatherOp
aten gather operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor, AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index
);
let summary = "aten nll_loss_forward operator";
let description = [{
NllLossForwardOp
aten nll_loss_forward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index,
AnyTensor:$total_weight
);
let summary = "aten nll_loss_backward operator";
let description = [{
NllLossBackwardOp
aten nll_loss_backward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor, AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index
);
let summary = "aten nll_loss2d_forward operator";
let description = [{
NllLoss2dForwardOp
aten nll_loss2d_forward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index,
AnyTensor:$total_weight
);
let summary = "aten nll_loss2d_backward operator";
let description = [{
NllLoss2dBackwardOp
aten nll_loss2d_backward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_HardtanhOp: aten_Op<"hardtanh", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyScalar:$min_val,
AnyScalar:$max_val
);
let summary = "aten hardtanh operator";
let description = [{
HardtanhOp
aten hardtanh operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_HardtanhBackwardOp: aten_Op<"hardtanh_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$self,
AnyScalar:$min_val,
AnyScalar:$max_val
);
let summary = "aten hardtanh_backward operator";
let description = [{
HardtanhBackwardOp
aten hardtanh_backward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_HardtanhUnderOp: aten_Op<"hardtanh_", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyScalar:$min_val,
AnyScalar:$max_val
);
let summary = "aten hardtanh_ operator";
let description = [{
HardtanhUnderOp
aten hardtanh_ operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_AdaptiveAvgPool2dOp: aten_Op<"_adaptive_avg_pool2d", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyType:$output_size
);
let summary = "aten _adaptive_avg_pool2d operator";
let description = [{
AdaptiveAvgPool2dOp
aten _adaptive_avg_pool2d operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_AdaptiveAvgPool2dBackwardOp: aten_Op<"_adaptive_avg_pool2d_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$self
);
let summary = "aten _adaptive_avg_pool2d_backward operator";
let description = [{
AdaptiveAvgPool2dBackwardOp
aten _adaptive_avg_pool2d_backward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_MaxPool2dWithIndicesOp: aten_Op<"max_pool2d_with_indices", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor, AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyType:$kernel_size,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation,
AnyScalar:$ceil_mode
);
let summary = "aten max_pool2d_with_indices operator";
let description = [{
MaxPool2dWithIndicesOp
aten max_pool2d_with_indices operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_MaxPool2dWithIndicesBackwardOp: aten_Op<"max_pool2d_with_indices_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$self,
AnyType:$kernel_size,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation,
AnyScalar:$ceil_mode,
AnyTensor:$indices
);
let summary = "aten max_pool2d_with_indices_backward operator";
let description = [{
MaxPool2dWithIndicesBackwardOp
aten max_pool2d_with_indices_backward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
#endif // NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS


@ -28,7 +28,11 @@ enum BitMask {
// Coerce/require a mutable tensor value.
kMutableTensor = 4,
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kMutableTensor)
// If the source is a Scalar and the target is a Tensor, promotes
// to a 0d tensor.
kPromoteScalar = 8,
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kPromoteScalar)
};
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
} // namespace KernelValueConversion
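Because these are LLVM-style bitmask enumerators, conversion flags compose with | and are queried with &, and the new kPromoteScalar bit rides alongside the existing tensor-conversion bits. A standalone usage sketch (only the values visible in the hunk above are reproduced; the enum below is a stand-in, not the npcomp header):

#include <cstdint>

// Stand-in flags mirroring the values shown in the diff above.
enum ConversionFlagSketch : uint32_t {
  kMutableTensor = 4,
  kPromoteScalar = 8,
};

// Flags combine with | and are tested with &, as with any bitmask enum.
inline bool wantsScalarPromotion(uint32_t flags) {
  return (flags & kPromoteScalar) != 0;
}

// Example: uint32_t flags = kMutableTensor | kPromoteScalar;
//          wantsScalarPromotion(flags) evaluates to true.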


@ -108,3 +108,5 @@ void ATenDialect::initialize() {
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
#include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc"
#include "npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc"


@ -55,33 +55,6 @@ std::map<std::string, uint64_t> AdaptiveAvgPool2dBackwardOp::getStatistics() {
return toReturn;
}
// add
std::map<std::string, uint64_t> AddOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
TensorType resultTy = getResult().getType().cast<TensorType>();
TensorType aType = getOperand(0).getType().cast<TensorType>();
Type bType = getOperand(1).getType();
uint64_t ofm_volume = getTensorVolume(resultTy);
toReturn["ops:+"] = ofm_volume;
toReturn["result:0:activation_out"] = ofm_volume;
// Find the size of the A and B operands
uint64_t a_volume = getTensorVolume(aType);
uint64_t b_volume = getTensorVolume(bType);
toReturn["operand:0:activation_in"] = a_volume;
toReturn["operand:1:activation_in"] = b_volume;
toReturn["reads"] = a_volume + b_volume;
toReturn["writes"] = ofm_volume;
return toReturn;
}
// add_
std::map<std::string, uint64_t> AddUnderOp::getStatistics() {
@ -231,33 +204,6 @@ ConvolutionBackwardOverrideableOp::getStatistics() {
return getConv2dBackwardStatistics(*this, groups);
}
// div
std::map<std::string, uint64_t> DivOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
TensorType resultTy = getResult().getType().cast<TensorType>();
TensorType aType = getOperand(0).getType().cast<TensorType>();
Type bType = getOperand(1).getType();
uint64_t ofm_volume = getTensorVolume(resultTy);
toReturn["ops:/"] = ofm_volume;
toReturn["result:0:activation_out"] = ofm_volume;
// Find the size of the A and B operands
uint64_t a_volume = getTensorVolume(aType);
uint64_t b_volume = getTensorVolume(bType);
toReturn["operand:0:activation_in"] = a_volume;
toReturn["operand:1:activation_in"] = b_volume;
toReturn["reads"] = a_volume + b_volume;
toReturn["writes"] = ofm_volume;
return toReturn;
}
// div_
std::map<std::string, uint64_t> DivUnderOp::getStatistics() {
@ -467,32 +413,6 @@ std::map<std::string, uint64_t> MmOp::getStatistics() {
return getMMOpStatistics(*this);
}
// mul
std::map<std::string, uint64_t> MulOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
TensorType resultTy = getResult().getType().cast<TensorType>();
TensorType aType = getOperand(0).getType().cast<TensorType>();
Type bType = getOperand(1).getType();
uint64_t ofm_volume = getTensorVolume(resultTy);
toReturn["ops:*"] = ofm_volume;
toReturn["result:0:activation_out"] = ofm_volume;
// Find the size of the A and B operands
uint64_t a_volume = getTensorVolume(aType);
uint64_t b_volume = getTensorVolume(bType);
toReturn["operand:0:activation_in"] = a_volume;
toReturn["operand:1:activation_in"] = b_volume;
toReturn["reads"] = a_volume + b_volume;
toReturn["writes"] = ofm_volume;
return toReturn;
}
// mul_
std::map<std::string, uint64_t> MulUnderOp::getStatistics() {
@ -668,23 +588,6 @@ std::map<std::string, uint64_t> NllLoss2dBackwardOp::getStatistics() {
return toReturn;
}
// neg op
std::map<std::string, uint64_t> NegOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
auto insize = getTensorVolume(getOperand().getType());
auto outsize = getTensorVolume(getResult().getType());
toReturn["reads"] = toReturn["operand:0:activation_in"] = insize;
toReturn["writes"] = toReturn["result:0:activation_out"] = outsize;
return toReturn;
}
// relu
// std::map<std::string, uint64_t> ReLUOp::getStatistics() {
// return getReLUOpStatistics(*this);
// }
std::map<std::string, uint64_t> ReluOp::getStatistics() {
return getReLUOpStatistics(*this);
}
// std::map<std::string, uint64_t> ReLUUnderOp::getStatistics() {
// return getReLUOpStatistics(*this);
// }


@ -47,6 +47,7 @@ convertTorchArgType(StringRef sourceTorchType, StringRef targetTorchType,
// Immutable tensor conversion.
if (flag & KVC::kImmutableTensor) {
// TODO: Support the kPromoteScalar flag.
if (sourceTorchType != "Tensor" || targetTorchType != "Tensor")
return None;
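As the TODO above notes, the kImmutableTensor path currently accepts only a Tensor-to-Tensor mapping. A hypothetical sketch of how a kPromoteScalar extension could relax that check (the function and types below are illustrative, not the project's convertTorchArgType implementation):

#include <optional>
#include <string>

// Result of a successful argument-type bridge in this sketch.
struct ArgConversionSketch {
  bool promoteScalarTo0dTensor = false;
};

// Accept Tensor->Tensor as today; with promoteScalar set, also accept a Scalar
// source feeding a Tensor target by promoting it to a 0-d tensor.
std::optional<ArgConversionSketch>
convertArgTypeSketch(const std::string &sourceTorchType,
                     const std::string &targetTorchType, bool immutableTensor,
                     bool promoteScalar) {
  if (!immutableTensor)
    return std::nullopt;
  if (sourceTorchType == "Tensor" && targetTorchType == "Tensor")
    return ArgConversionSketch{};
  if (promoteScalar && sourceTorchType == "Scalar" &&
      targetTorchType == "Tensor")
    return ArgConversionSketch{/*promoteScalarTo0dTensor=*/true};
  return std::nullopt;
}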


@ -1,15 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-add-0": {
// CHECK-NEXT: "activation_in": 12,
// CHECK-NEXT: "activation_out": 6,
// CHECK-NEXT: "ops:+": 6,
// CHECK-NEXT: "reads": 12,
// CHECK-NEXT: "writes": 6
// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION
// CHECK-CONVERSION-LABEL: @graph
func @graph(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
%1 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
%2 = "aten.add"(%arg0, %arg1, %1) : (tensor<1x2x3xf32>, tensor<1x2x3xf32>, i32) -> tensor<1x2x3xf32>
"std.return"(%2) : (tensor<1x2x3xf32>) -> ()
}


@ -1,15 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-relu-0": {
// CHECK-NEXT: "activation_in": 6,
// CHECK-NEXT: "activation_out": 6,
// CHECK-NEXT: "ops:>": 6,
// CHECK-NEXT: "reads": 6,
// CHECK-NEXT: "writes": 6
module {
func @graph(%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
%0 = "aten.relu"(%arg0) : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
"std.return"(%0) : (tensor<1x2x3xf32>) -> ()
}
}


@ -1,61 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-native_batch_norm-0": {
// CHECK-LABEL: "L1-relu-0": {
// CHECK-LABEL: "L2-_convolution-0": {
// CHECK-LABEL: "L3-native_batch_norm-1": {
// CHECK-LABEL: "L4-relu-1": {
// CHECK-LABEL: "L5-_convolution-1": {
// CHECK-LABEL: "L6-native_batch_norm-2": {
// CHECK-LABEL: "L7-relu-2": {
// CHECK-LABEL: "L8-_convolution-2": {
// CHECK-LABEL: "L9-add-0": {
module {
func @graph(%arg0: tensor<1x16x128x128xf32>, %arg1: tensor<1x16x128x128xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>, %arg5: tensor<16xf32>, %arg6: tensor<8x16x1x1xf32>, %arg7: tensor<8xf32>, %arg8: tensor<8xf32>, %arg9: tensor<8xf32>, %arg10: tensor<8xf32>, %arg11: tensor<8xf32>, %arg12: tensor<8x8x3x3xf32>, %arg13: tensor<8xf32>, %arg14: tensor<8xf32>, %arg15: tensor<8xf32>, %arg16: tensor<8xf32>, %arg17: tensor<8xf32>, %arg18: tensor<16x8x1x1xf32>, %arg19: tensor<16xf32>) -> tensor<1x16x128x128xf32> {
%0 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
%1 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32
%2 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32
%3:3 = "aten.native_batch_norm"(%arg1, %arg2, %arg3, %arg4, %arg5, %0, %1, %2) : (tensor<1x16x128x128xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i1, f32, f32) -> (tensor<1x16x128x128xf32>, tensor<16xf32>, tensor<16xf32>)
%4 = "aten.relu"(%3#0) : (tensor<1x16x128x128xf32>) -> tensor<1x16x128x128xf32>
%5 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
%6 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
%7 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
%8 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%9 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
%10 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
%11 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%12 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
%13 = "aten._convolution"(%4, %arg6, %arg7, %5, %6, %7) : (tensor<1x16x128x128xf32>, tensor<8x16x1x1xf32>, tensor<8xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<1x8x128x128xf32>
%14 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
%15 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32
%16 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32
%17:3 = "aten.native_batch_norm"(%13, %arg8, %arg9, %arg10, %arg11, %14, %15, %16) : (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, i1, f32, f32) -> (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>)
%18 = "aten.relu"(%17#0) : (tensor<1x8x128x128xf32>) -> tensor<1x8x128x128xf32>
%19 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
%20 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
%21 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
%22 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%23 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
%24 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
%25 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%26 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
%27 = "aten._convolution"(%18, %arg12, %arg13, %19, %20, %21) : (tensor<1x8x128x128xf32>, tensor<8x8x3x3xf32>, tensor<8xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<1x8x128x128xf32>
%28 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
%29 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32
%30 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32
%31:3 = "aten.native_batch_norm"(%27, %arg14, %arg15, %arg16, %arg17, %28, %29, %30) : (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, i1, f32, f32) -> (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>)
%32 = "aten.relu"(%31#0) : (tensor<1x8x128x128xf32>) -> tensor<1x8x128x128xf32>
%33 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
%34 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
%35 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
%36 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%37 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
%38 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
%39 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%40 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
%41 = "aten._convolution"(%32, %arg18, %arg19, %33, %34, %35) : (tensor<1x8x128x128xf32>, tensor<16x8x1x1xf32>, tensor<16xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<1x16x128x128xf32>
%42 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
%43 = "aten.add"(%arg0, %41, %42) : (tensor<1x16x128x128xf32>, tensor<1x16x128x128xf32>, i32) -> tensor<1x16x128x128xf32>
return %43 : tensor<1x16x128x128xf32>
}
}