diff --git a/build_tools/update_aten_ods.sh b/build_tools/update_aten_ods.sh new file mode 100755 index 000000000..7ffda337a --- /dev/null +++ b/build_tools/update_aten_ods.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Updates the ATen dialect generated code from the PyTorch op registry. +# Requires that the project has been built and that PyTorch support is enabled. +set -e + +src_dir="$(realpath $(dirname $0)/..)" +build_dir="$(realpath "${NPCOMP_BUILD_DIR:-$src_dir/build}")" +aten_dir="${src_dir}/include/npcomp/Dialect/ATen/IR" + +export PYTHONPATH="${build_dir}/python" + +python -m torch_mlir_utils.codegen.torch_signature_ods_gen \ + --ods_td_file="${aten_dir}/GeneratedATenOps.td" \ + --ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc" diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/aten_ops.py b/frontends/pytorch/python/torch_mlir_utils/codegen/aten_ops.py deleted file mode 100644 index 03048ad51..000000000 --- a/frontends/pytorch/python/torch_mlir_utils/codegen/aten_ops.py +++ /dev/null @@ -1,202 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Populates an op registry for ATen ops. - -Typically callers will import and use the 'populate' function to add known -ops to the OpRegistry. When run interactively as a main module, it simply -prints all registered ops. -""" - -from .registry import * - -import torch -import torch.nn.functional as F - - -def populate(r: OpRegistry): - # Unary pointwise ops (ordinary that take out refs). - for f in [ - torch.abs, torch.acos, torch.angle, torch.asin, torch.atan, torch.ceil, - torch.conj, torch.cos, torch.cosh, torch.digamma, torch.erf, torch.erfc, - torch.erfinv, torch.exp, torch.expm1, torch.floor, torch.frac, - torch.lgamma, torch.log, torch.log10, torch.log1p, torch.log2, torch.neg, - torch.reciprocal, torch.round, torch.rsqrt, torch.sigmoid, torch.sign, - torch.sin, torch.sinh, torch.sqrt, torch.tan, torch.tanh, torch.trunc - ]: - r.op(f, TensorValue("input")).with_outref_variant() - - # Unary pointwise ops (that do not take out refs). - for f in [torch.relu]: - r.op(f, TensorValue("input")) - - # Binary pointwise ops. - r.op(torch.add, - TensorValue("input"), - TensorValue("other"), - alpha=LiteralValue()).with_outref_variant() - r.op(torch.atan2, TensorValue("input"), - TensorValue("other")).with_outref_variant() - r.op(torch.div, TensorValue("input"), - TensorValue("other")).with_outref_variant() - r.op(torch.floor_divide, TensorValue("input"), - TensorValue("other")).with_outref_variant() - r.op(torch.mul, TensorValue("input"), - TensorValue("other")).with_outref_variant() - r.op(torch.remainder, TensorValue("input"), - TensorValue("other")).with_outref_variant() - r.op(torch.true_divide, TensorValue("dividend"), - TensorValue("divisor")).with_outref_variant() - - # Aggregate operations. - # TODO: Support the optional dtype= parameter. - r.op(torch.cumsum, TensorValue("input", example_size=(10, 3)), - LiteralValue("dim", value=1)).with_outref_variant() - r.op( - torch.mean, - TensorValue("input", example_size=(10, 3)), # - LiteralValue("dim", value=[0], mlir_ods_predicate="ATen_IntList"), # - LiteralValue("keep_dim", - value=False, - mlir_ods_predicate="ATen_BoolScalar") # - ).with_outref_variant() - r.op( - torch.sum, - TensorValue("input", example_size=(10, 3)), # - LiteralValue("dim", value=[0], mlir_ods_predicate="ATen_IntList"), # - LiteralValue("keep_dim", - value=False, - mlir_ods_predicate="ATen_BoolScalar") # - ).with_outref_variant() - - # Gather. - (r.op( - torch.gather, # - TensorValue("input", example_size=(2, 2)), # - dim=LiteralValue(value=1, mlir_ods_predicate="ATen_IntScalar"), # - index=LiteralValue(value=torch.tensor([[0, 0], [1, 0]]), - mlir_ods_predicate="ATen_AnyTensor"), # - sparse_grad=LiteralValue(value=False, - mlir_ods_predicate="ATen_BoolScalar") # - ).with_outref_variant()) - - # Non-view methods on Tensor. - (r.op((TensorValue(name="input", example_size=(3, 1)), "@T"))) - - # BLAS and LAPACK ops. - r.op(torch.addmm, - TensorValue("input", example_size=(2, 3)), - TensorValue("mat1", example_size=(2, 3)), - TensorValue("mat2", example_size=(3, 3)), - beta=LiteralValue(), - alpha=LiteralValue()).with_outref_variant() - r.op(torch.dot, TensorValue("input", example_size=(10,)), - TensorValue("tensor", example_size=(10,))) - r.op(torch.matmul, TensorValue("input", example_size=(10, 3, 4)), - TensorValue("other", example_size=(4, 5))).with_outref_variant() - r.op(torch.mm, TensorValue("input", example_size=(3, 4)), - TensorValue("mat2", example_size=(4, 6))).with_outref_variant() - - # NN Functional. - # Note that _convolution is a special case and is manually coded. - r.op( - F.hardtanh, - TensorValue("input"), # - min_val=LiteralValue(value=-1.0, - mlir_ods_predicate="ATen_FloatScalar"), # - max_val=LiteralValue(value=1.0, mlir_ods_predicate="ATen_FloatScalar") # - ) - r.op(F.avg_pool1d, - TensorValue("input", example_size=(1, 1, 7)), - kernel_size=LiteralValue(value=[3], mlir_ods_predicate="ATen_IntList"), - stride=LiteralValue(value=[5], mlir_ods_predicate="ATen_IntList"), - padding=LiteralValue(value=[1], mlir_ods_predicate="ATen_IntList"), - ceil_mode=LiteralValue(value=True, mlir_ods_predicate="ATen_BoolScalar"), - count_include_pad=LiteralValue(value=False, - mlir_ods_predicate="ATen_BoolScalar")) - - # MaxPool1D is split into two ops based on whether return_indices is True: - # aten::max_pool1d -> tensor - # aten::max_pool1d_with_indices -> (tensor, tensor) - # Both have odd signatures and are hand-mapped. - # TODO: Implement max_pool1d(..., with_indices=True) - (r.op(F.max_pool1d, - TensorValue("input", example_size=(1, 1, 7)), - kernel_size=LiteralValue(value=[3], mlir_ods_predicate="ATen_IntList"), - stride=LiteralValue(value=[5], mlir_ods_predicate="ATen_IntList"), - padding=LiteralValue(value=[1], mlir_ods_predicate="ATen_IntList"), - dilation=LiteralValue(value=[3], mlir_ods_predicate="ATen_IntList"), - ceil_mode=LiteralValue(value=True, - mlir_ods_predicate="ATen_BoolScalar"), - return_indices=LiteralValue(value=False)) # - .with_torch_op_kind("aten::max_pool1d") # - .with_operand_map("input", "kernel_size", "stride", "padding", "dilation", - "ceil_mode") # - ) - - # View ops. - # TODO: All of these need special analysis and should be parameterized - # on mutable tensors and have a proper design thought through. For now, - # even having them in the inventory (badly) increases visibility. - (r.op(torch.as_strided, - TensorValue("input"), - size=LiteralValue(value=[2, 2], mlir_ods_predicate="ATen_IntList"), - stride=LiteralValue(value=[1, 2], mlir_ods_predicate="ATen_IntList"), - storage_offset=LiteralValue(value=4, - mlir_ods_predicate="ATen_IntScalar")) # - .with_append_description(r""" - - MLIR Specific Notes - ------------------- - In PyTorch proper, this op creates a view that may internally alias. And - have explicit warnings about avoiding inplace updates on such a - view (without first cloning). For the moment, this op is formulated with - value semantics that imply a copy instead of a view, and it is expected - that any sharing can be recovered later by the compiler. The warning - about not in-place updating of such a result should be treated as UB - when compiled. - """)) - - (r.op((TensorValue(name="input", example_size=(3, 1)), "expand"), - LiteralValue("sizes", value=torch.Size([3, 4]))) # - .with_operand_map("input", "sizes", - LiteralValue("implicit", - mlir_ods_predicate="ATen_BoolScalar")) # - .with_append_description(r""" - - MLIR Specific Notes - ------------------- - See notes for the 'as_strided' op. - """)) - - (r.op( - torch.squeeze, - TensorValue("input"), # - LiteralValue("dim", value=1, mlir_ods_predicate="ATen_IntScalar") # - ).with_append_description(r""" - - MLIR Specific Notes - ------------------- - See notes for the 'as_strided' op. - """)) - - (r.op((TensorValue(name="input", example_size=(3, 1)), "view"), - LiteralValue(name="size", - value=[3, 1], - mlir_ods_predicate="ATen_IntList")) - .with_append_description(r""" - - MLIR Specific Notes - ------------------- - See notes for the 'as_strided' op. - """)) - - -if __name__ == "__main__": - import logging - logging.basicConfig(level=logging.DEBUG) - registry = OpRegistry() - populate(registry) - print("Registered operations:") - for m in registry.mappings: - print(" ", m) diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/generate_ods.py b/frontends/pytorch/python/torch_mlir_utils/codegen/generate_ods.py deleted file mode 100644 index 0b88702d6..000000000 --- a/frontends/pytorch/python/torch_mlir_utils/codegen/generate_ods.py +++ /dev/null @@ -1,205 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Generates ODS for a registry of ops.""" - -from typing import TextIO - -import argparse -from contextlib import contextmanager -import importlib -import logging -import re -import sys -import textwrap - -from .registry import * - -_INDENT = " " - - -class OdsEmitter: - ods_prefix = "ATen_" - ods_suffix = "Op" - ods_value_template = "ATen_ImmutableTensorOp" - ods_ref_template = "ATen_RefTensorOp" - op_prefix = "" - - def __init__(self, r: OpRegistry, out: TextIO): - super().__init__() - self.r = r - self.out = out - self.indent_level = 0 - - def emit_ods(self): - for op_m in self.r.mappings: - if isinstance(op_m, SimpleOpMapping): - self._emit_simple_op_mapping(op_m) - else: - logging.warn(f"Unrecognized op mapping type: {op_m!r}") - - def _emit_simple_op_mapping(self, op_m: SimpleOpMapping): - identifier = (f"{self.ods_prefix}" - f"{_snakecase_to_camelcase(op_m.mlir_operation_name)}" - f"{self.ods_suffix}") - traits = [] - - if op_m.is_outref_form: - template_name = self.ods_ref_template - summary = "See non-inplace op variant." - description = "" - else: - template_name = self.ods_value_template - summary, description = _split_docstring(op_m.op_f.__doc__) - - if not op_m.is_outref_form: - traits.append("NoSideEffect") - self.print(f"def {identifier}: {template_name}" - f"<{_quote(op_m.mlir_operation_name)}, [" - f"{', '.join(traits)}" - f"]> {{") - - # Summary. - with self.indent(): - self.print(f"let summary = {_quote(summary)};") - - # Arguments. - with self.indent(): - self.print("let arguments = (ins") - with self.indent(): - operand_len = len(op_m.operand_map) - for index, (_, value_spec) in enumerate(op_m.operand_map): - is_last = index == operand_len - 1 - self.print(f"{value_spec.mlir_ods_predicate}:${value_spec.name}", - end="\n" if is_last else ",\n") - self.print(");") - - # Results (omitted if an outref/inplace form). - with self.indent(): - if op_m.is_outref_form: - self.print("let results = (outs);") - else: - self.print("let results = (outs") - with self.indent(): - result_len = len(op_m.result_map) - for index, (_, value_spec) in enumerate(op_m.result_map): - is_last = index == result_len - 1 - self.print(f"{value_spec.mlir_ods_predicate}:${value_spec.name}", - end="\n" if is_last else ",\n") - self.print(");") - - # Description and extra class declarations. - with self.indent(): - if description: - quoted_description = _quote_multiline_docstring( - description + op_m.append_description, - indent_level=self.indent_level) - self.print(f"let description = {quoted_description};") - - self.print("}\n") - - @contextmanager - def indent(self, level=1): - self.indent_level += level - yield - self.indent_level -= level - assert self.indent_level >= 0, "Unbalanced indentation" - - def print(self, s, *, end="\n", indent=True): - if indent and self.indent_level: - self.out.write(_INDENT * self.indent_level) - self.out.write(s) - self.out.write(end) - - -def _snakecase_to_camelcase(ident: str): - return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident)) - - -def _quote(s: str): - s = s.replace(r'"', r'\\"') - return f'"{s}"' - - -def _quote_multiline_docstring(s: str, indent_level: int = 0): - # TODO: Possibly find a python module to markdown the docstring for better - # document generation. - # Unlikely to contain the delimitter and since just a docstring, be safe. - s = s.replace("}]", "") - # Strip each line. - s = "\n".join([l.rstrip() for l in s.splitlines()]) - indent = _INDENT * indent_level - s = textwrap.indent(s, indent + _INDENT) - return "[{\n" + s + "\n" + indent + "}]" - - -def _split_docstring(docstring: str): - """Splits the docstring into a summary and description.""" - if not docstring: - docstring = "" - lines = docstring.splitlines() - if not lines: - return "", "" - - # Skip leading blank lines. - while lines and not lines[0]: - lines = lines[1:] - if len(lines) > 2: - return lines[0], "\n".join(lines[2:]) - else: - return lines[0], "" - - -def main(args): - r = OpRegistry() - # Populate from modules that provide a populate() function. - op_modules = [args.op_module] - for m_name in op_modules: - logging.info(f"Populating from module: {m_name}") - m = importlib.import_module(m_name, package=__package__) - f = getattr(m, "populate") - f(r) - - out = sys.stdout - - # Write file header. - module_name = sys.modules["__main__"].__loader__.name - banner_lines = [ - "//===-------------------------------------------------------*- tablegen -*-===//", - "//", - "// This file is licensed under the Apache License v2.0 with LLVM Exceptions.", - "// See https://llvm.org/LICENSE.txt for license information.", - "// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception", - "//", - "// Operation summaries and descriptions were systematically derived from public", - "// API docstrings and are licensed accordingly:", - "// https://github.com/pytorch/pytorch/blob/master/LICENSE", - "//===----------------------------------------------------------------------===//", - "// This file is automatically generated. Please do not edit.", - "// Generated via:", - f"// python -m {module_name} {' '.join(sys.argv[1:])}", - "//===----------------------------------------------------------------------===//", - "", - "", - ] - banner_lines = [l.strip() for l in banner_lines] - out.write("\n".join(banner_lines)) - - emitter = OdsEmitter(r, out=out) - emitter.emit_ods() - - -def _create_argparse(): - parser = argparse.ArgumentParser(prog="generate_ods") - parser.add_argument( - "--op_module", - default=".aten_ops", - help="Name of a python module for populating the registry") - return parser - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - parser = _create_argparse() - args = parser.parse_args() - main(args) diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/registry.py b/frontends/pytorch/python/torch_mlir_utils/codegen/registry.py deleted file mode 100644 index 4550e6df8..000000000 --- a/frontends/pytorch/python/torch_mlir_utils/codegen/registry.py +++ /dev/null @@ -1,495 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Base classes and interfaces for mapping ops to MLIR. - -The goal of this facility is to define the majority of op mappings by example, -from the actual Python invocation, using the PyTorch tracer to extract its -Graph IR, and introspecting that to determine mappings, metadata and types -for corresponding MLIR definitions. This is not meant to cover everything, -and it is expected that a number of hard ops should be mapped by hand. - -The result of building such an OpRegistry should be a data structure that -can be used to generate ODS and tables for doing systematic imports from -a PyTorch Graph to corresponding MLIR module. - -Example usage (fully automatic discovery): - - >>> r = OpRegistry() - >>> r.op(torch.add, - TensorValue("input"), - TensorValue("other"), - alpha=ScalarValue()).with_outref_variant() -""" - -from typing import Dict, List, Optional, Sequence, Tuple - -import logging -import random -import textwrap -import torch - -__all__ = [ - "SimpleOpMapping", - "OpRegistry", - "LiteralValue", - "TensorOutRef", - "TensorValue", -] - - -def _is_same_value(value1, value2): - # Tensors are considered via reference equality. - value1_tensor = isinstance(value1, torch.Tensor) - value2_tensor = isinstance(value2, torch.Tensor) - if value1_tensor or value2_tensor: - return value1 is value2 - - # Everything else is value equality. - return value1 == value2 - - -def _extract_immediate(node): - """Extracts an immediate value from a node. - - Supported node types: - prim::Constant - prim::ListConstruct - """ - # Constant - if node.kind() == "prim::Constant": - # Try to extract as different types. - try: - return node.t("value") - except RuntimeError: - pass - try: - return node.f("value") - except RuntimeError: - pass - try: - return node.i("value") - except RuntimeError: - pass - try: - return node.s("value") - except RuntimeError: - pass - return None - # List - elif node.kind() == "prim::ListConstruct": - return [_extract_immediate(i.node()) for i in node.inputs()] - else: - raise ValueError("Unrecognized immediate input node: {!r}".format(node)) - - -def _read_attr_callback(instance, attr_name): - - def f(): - return getattr(instance, attr_name) - - f.__doc__ = getattr(type(instance), attr_name).__doc__ - return f - - -class ValueSpec: - """Base class for inputs to operations. - - This binds information about how the input is mapped to the MLIR operation. - """ - - def __init__(self, name=None, mlir_ods_predicate: str = "AnyType"): - super().__init__() - self.name = name - self.mlir_ods_predicate = mlir_ods_predicate - - def generate_example(self, index=0): - """Generates an example value.""" - raise NotImplementedError() - - def __repr__(self): - return "{}({!r})".format(self.__class__.__name__, self.name) - - -class TensorValue(ValueSpec): - """An input that is a tensor.""" - - def __init__(self, - name=None, - *, - example_size=None, - mlir_ods_predicate="ATen_AnyTensor", - **kwargs): - super().__init__(name=name, mlir_ods_predicate=mlir_ods_predicate, **kwargs) - if example_size is None: - example_size = (2, 3, 7) # No significance. - self.example_size = example_size - - def generate_example(self, index=0): - return torch.rand(*self.example_size) - - -class TensorOutRef(ValueSpec): - """A tensor that is passed by ref as an out parameter.""" - - def __init__(self, - name=None, - *, - example_size=None, - mlir_ods_predicate="ATen_AnyRefTensor", - **kwargs): - super().__init__(name=name, mlir_ods_predicate=mlir_ods_predicate, **kwargs) - if example_size is None: - example_size = (2, 3, 7) # No significance. - self.example_size = example_size - - def generate_example(self, index=0): - return torch.rand(*self.example_size) - - -class LiteralValue(ValueSpec): - """An input that is a literal value.""" - - def __init__(self, - name=None, - value=None, - mlir_ods_predicate="ATen_AnyScalar", - **kwargs): - super().__init__(name=name, mlir_ods_predicate=mlir_ods_predicate, **kwargs) - self.value = value - - def generate_example(self, index=0): - if self.value is not None: - return self.value - return 1.0 + index # Generates a stable value. - - -class OpMapping: - """Base class for things purporting to map an operation.""" - __slots__ = [] - - -class SimpleOpMapping(OpMapping): - """Maps a PyTorch invocation to its MLIR representation.""" - __slots__ = [ - "append_description", - "mlir_operation_name", - "operand_map", - "op_args", - "op_arity", - "op_f", - "op_kind", - "op_kwargs", - "op_self", - "op_self_value_spec", - "outref_variant_value", - "result_map", - ] - - def __init__(self, op_f, *op_args, **op_kwargs): - super().__init__() - self.op_f = op_f - self.op_self = None - self.op_self_value_spec = None - self.op_args = op_args - self.op_kwargs = op_kwargs - self.outref_variant_value = None # type: Optional[TensorOutRef] - self.append_description = "" - - # Set after finalize. - self.op_kind = None # type: Optional[str] - self.op_arity = -1 # type: int - self.operand_map = None # type: Optional[List[Tuple[int, ValueSpec]]] - self.result_map = None # type: Optional[List[Tuple[int, ValueSpec]]] - self.mlir_operation_name = None # type: Optional[str] - - # Fixup self-calls. - # These are specified as tuples of (ValueSpec, "method_name") - if isinstance(self.op_f, tuple): - assert len(self.op_f) == 2, "Self-calls must be a 2-tuple" - self.op_self_value_spec, method_name = self.op_f - self.op_self = self.op_self_value_spec.generate_example(index=-1) - if method_name.startswith("@"): - # Create a pseudo-function to resolve the attribute. - self.op_f = _read_attr_callback(self.op_self, method_name[1:]) - else: - self.op_f = getattr(self.op_self, method_name) - - def __repr__(self): - return ("SimpleOp({kind!r}[{arity}] -> {name!s}, operands={operands!r}, " - "results={results!r})".format(kind=self.op_kind, - arity=self.op_arity, - name=self.mlir_operation_name, - operands=self.operand_map, - results=self.result_map)) - - def clone(self) -> "SimpleOpMapping": - copy = SimpleOpMapping(self.op_f, *self.op_args, **self.op_kwargs) - for name in [ - "outref_variant_value", "op_kind", "op_arity", "operand_map", - "result_map", "mlir_operation_name" - ]: - setattr(copy, name, getattr(self, name)) - return copy - - def with_outref_variant(self, value=None) -> "SimpleOpMapping": - """Instructs the registry to also generate an outref variant. - - This is done by cloning the op prior to finalizing and adding an out= - paramer. - """ - self.outref_variant_value = TensorOutRef() if value is None else value - return self - - def with_torch_op_kind(self, op_kind: str) -> "SimpleOpMapping": - """Uses a manually specified op kind (i.e. "aten::some_op").""" - self.op_kind = op_kind - return self - - def with_operand_map(self, *mlir_operand_names): - """Manually maps torch IR operands to mlir operand names.""" - operand_map = [] - for i, name_or_value_spec in enumerate(mlir_operand_names): - if isinstance(name_or_value_spec, ValueSpec): - value_spec = name_or_value_spec - else: - value_spec = self.find_value_spec_by_name(name_or_value_spec) - operand_map.append((i, value_spec)) - self.operand_map = operand_map - return self - - def with_append_description(self, description) -> "SimpleOpMapping": - """Appends a description to the ODS.""" - self.append_description += textwrap.dedent(description) - return self - - @property - def all_arg_values(self) -> List[ValueSpec]: - """Returns all arg values (either positional or kw).""" - args = list(self.op_args) + list(self.op_kwargs.values()) - if self.op_self_value_spec: - args = [self.op_self_value_spec] + args - return args - - @property - def is_outref_form(self) -> bool: - """Whether the op contains an out parameter that aliases to the result.""" - return any(isinstance(a, TensorOutRef) for a in self.all_arg_values) - - def find_value_spec_by_name(self, name) -> ValueSpec: - """Finds an argument ValueSpec by name. - - Raises: - ValueError if not found. - """ - if self.op_self_value_spec and self.op_self_value_spec.name == name: - return self.op_self_value_spec - value_spec = self.op_kwargs.get(name) - if value_spec: - return value_spec - for value_spec in self.op_args: - if value_spec.name == name: - return value_spec - raise ValueError(f"Unknown value spec: {name}") - - def generate_example(self) -> Tuple[Tuple, Dict]: - """Generates an example signature for invoking the op. - - Returns: - (tuple, dict) of positional and keyword args. - """ - index = 0 - positional = list() - kw = dict() - for op_arg in self.op_args: - positional.append(op_arg.generate_example(index)) - index += 1 - for kw_name, kw_value in self.op_kwargs.items(): - kw[kw_name] = kw_value.generate_example(index) - index += 1 - return positional, kw - - def finalize(self): - """Finalizes the mapping once all hints have been applied.""" - # Update the name on all args if undefined. - for index, op_arg in enumerate(self.op_args): - if op_arg.name is None: - op_arg.name = "arg%d".format(index) - for key, op_arg in self.op_kwargs.items(): - if op_arg.name is None: - op_arg.name = key - - # Create an example graph and configure from it. - self._configure_from_example() - - # Determine default operation name. - if self.mlir_operation_name is None: - self._set_default_mlir_operation_name() - - def _set_default_mlir_operation_name(self): - op_ns, op_name = self.op_kind.split("::", maxsplit=1) - # Since these are emitted into the "aten" dialect namespace, alias them - # to omit the prefix to distinguish from custom ops and others (which will - # have a prefix). - default_name = op_name if op_ns == "aten" else op_ns + "." + op_name - if op_ns == "aten": - default_name = op_name - else: - default_name = op_ns + "." + op_name - - if self.is_outref_form: - default_name += ".inplace" - self.mlir_operation_name = default_name - - def _configure_from_example(self): - # Trace the op so that we get an example graph like this: - # %0 : Float(2, 3, 7) = prim::Constant[value=]() - # %1 : Float(2, 3, 7) = prim::Constant[value=]() - # %2 : float = prim::Constant[value=3.]() - # %3 : Float(2, 3, 7) = aten::add(%0, %1, %2) - # return (%3) - # The def of the return value is expected to be the modeled op. The - # inputs to that op are expected to be captured constants that can be - # re-associated to the example inputs. - example_args, example_kwargs = self.generate_example() - - def forward(): - # logging.debug( - # f"Invoke {self.op_f} with *{example_args}, **{example_kwargs}") - return self.op_f(*example_args, **example_kwargs) - - trace = torch.jit.trace(forward, tuple()) - graph = trace.graph - logging.debug("Graph for op %r: %s", self.op_f, graph) - - # Track up from the return node and assume this is our op. - return_node = graph.return_node() - return_inputs = list(return_node.inputs()) - assert len(return_inputs) == 1, "Expected one input return" - op_node = return_inputs[0].node() - op_inputs = list(op_node.inputs()) - logging.debug("Found op node: %r", op_node) - - # Meta-data about the source op. - self.op_kind = op_node.kind() - self.op_arity = len(op_inputs) - if self.operand_map is None: - self.operand_map = self._associate_inputs(op_inputs, example_args, - example_kwargs) - - # Results. - op_outputs = list(op_node.outputs()) - if self.result_map is None: - if self.is_outref_form: - # Only support single outref results. - assert len(op_outputs) == 1, ( - "For outref ops, only a single output is supported") - self.result_map = [(0, TensorOutRef("result"))] - else: - # Map results in order. - self.result_map = [] - - def result_name(i): - if len(op_outputs) == 1: - return "result" - else: - return "result%d" % i - - for i, op_output in enumerate(op_outputs): - op_output_type = op_output.type() - if issubclass(type(op_output_type), torch.TensorType): - self.result_map.append((i, TensorValue(result_name(i)))) - else: - raise ValueError( - "Unsupported op output type: {!r}".format(op_output_type)) - return self - - def _associate_inputs(self, op_inputs, example_args, example_kwargs): - """Given inputs to a graph op node, associates to the input args. - - This will match up example arguments with what was produced in the graph, - setting the operand_map. - - Returns: - List of (input_index, ValueSpec) mapping inputs to the graph node to - provided values in the op definition. - """ - assert len(example_args) == len(self.op_args) - assert example_kwargs.keys() == self.op_kwargs.keys() - - def find_arg(value): - if self.op_self is not None and _is_same_value(self.op_self, value): - return self.op_self_value_spec - for i, arg in enumerate(example_args): - if _is_same_value(arg, value): - return self.op_args[i] - for key, arg in example_kwargs.items(): - if _is_same_value(arg, value): - return self.op_kwargs[key] - - raise KeyError("Op input not in arguments: {!r} -> {!r}".format( - value, op_inputs)) - - operand_map = [] - for i, op_input in enumerate(op_inputs): - input_node = op_input.node() - immediate_value = _extract_immediate(input_node) - if immediate_value is not None: - operand_map.append((i, find_arg(immediate_value))) - return operand_map - - -class OpRegistry: - """Maintains a registry of op mappings.""" - - def __init__(self): - super().__init__() - self._mappings = [] - self._pending_mapping = None - - def op(self, op_f, *op_args, **op_kwargs): - """Forwards to the SimpleOpMapping constructor and adds it. - - The mapping is not finalized until either the registry is finalized or the - next op mapping is added. This allows tweaks to the mapping to be done - inline prior to performing detailed introspection. - - Returns: - The SimpleOpMapping instance. - """ - self._finalize_pending() - m = SimpleOpMapping(op_f, *op_args, **op_kwargs) - self._pending_mapping = m - return m - - @property - def mappings(self) -> Sequence[OpMapping]: - """Returns the list of OpMapping. - - Returns: - Sequence of OpMapping concrete classes (most commonly SimpleOpMapping). - """ - self._finalize_pending() - return self._mappings - - def _finalize_pending(self): - if not self._pending_mapping: - return - - outref_mapping = None - pending_mapping = self._pending_mapping - self._pending_mapping = None - if pending_mapping.outref_variant_value: - # Generate a variant (with an out= form). - outref_mapping = pending_mapping.clone() - outref_mapping.op_kwargs["out"] = outref_mapping.outref_variant_value - outref_mapping.outref_variant_value = None - - # Finalize the original. - pending_mapping.finalize() - self._mappings.append(pending_mapping) - - # Finalize the outref form if generated. - if outref_mapping: - outref_mapping.finalize() - self._mappings.append(outref_mapping) diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py new file mode 100644 index 000000000..1e4b06349 --- /dev/null +++ b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py @@ -0,0 +1,477 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Queries the pytorch op registry and generates ODS and CC sources for the ops. +""" + +from typing import Dict, List, Optional, TextIO, Sequence, Tuple + +import argparse +from contextlib import contextmanager +import importlib +import logging +import re +import sys +import textwrap + +# Note that this utility exists only in the c-extension. +from _torch_mlir import get_registered_ops + + +def _create_argparse(): + parser = argparse.ArgumentParser(prog="generate_ods") + parser.add_argument("--ods_td_file", + required=True, + help="File to write the generated ODS to") + parser.add_argument("--ods_impl_file", + required=True, + help="CC include file to include in ops implementation") + parser.add_argument("--debug_op_reg_file", + help="Write out a file of op registrations") + return parser + + +def main(args): + reg_ops = _load_ops_as_dict() + if args.debug_op_reg_file: + with open(args.debug_op_reg_file, "w") as debug_ops_file: + dump_registered_ops(debug_ops_file, reg_ops) + + with open(args.ods_td_file, "w") as ods_file, open(args.ods_impl_file, + "w") as impl_file: + ods_emitter = OdsEmitter(ods_file) + ods_emitter.print(ODS_BANNER) + impl_emitter = CCImplEmitter(impl_file) + impl_emitter.print(CC_IMPL_BANNER) + generator = OpGenerator(reg_ops, ods_emitter, impl_emitter) + generate_ops(generator) + + +def generate_ops(g: "OpGenerator"): + # "Binary"-ops. + # There are some variation in these, so we spell them out in case if they + # need individual customization. + g.print_banner("Binary arithmetic ops") + g.ordinary_binary_op("aten::add(Tensor,Tensor,Scalar)", "AddOp", "add") + g.ordinary_binary_op("aten::atan2(Tensor,Tensor)", "Atan2Op", "atan2") + g.ordinary_binary_op("aten::div(Tensor,Tensor)", "DivOp", "div") + g.ordinary_binary_op("aten::floor_divide(Tensor,Tensor)", "FloorDivideOp", + "floor_divide") + g.ordinary_binary_op("aten::mul(Tensor,Tensor)", "MulOp", "mul") + g.ordinary_binary_op("aten::remainder(Tensor,Tensor)", "RemainderOp", + "remainder") + g.ordinary_binary_op("aten::true_divide(Tensor,Tensor)", "TrueDivideOp", + "true_divide") + + # Unary-ops. These are all the same so just name munge them. + g.print_banner("Unary arithmetic ops") + for uname in [ + "abs", "acos", "angle", "asin", "atan", "ceil", "conj", "cos", "cosh", + "digamma", "erf", "erfc", "erfinv", "exp", "expm1", "floor", "frac", + "lgamma", "log", "log10", "log1p", "log2", "neg", "relu", "reciprocal", + "round", "rsqrt", "sigmoid", "sign", "sin", "sinh", "sqrt", "tan", "tanh", + "trunc" + ]: + g.ordinary_unary_op(f"aten::{uname}(Tensor)", + f"{snakecase_to_camelcase(uname)}Op", uname) + + +def dump_registered_ops(outfile, reg_ops_dict): + for k in sorted(reg_ops_dict.keys()): + attr_dict = reg_ops_dict[k] + outfile.write(f"OP '{k}':\n") + for attr_name, attr_value in attr_dict.items(): + outfile.write(f" {attr_name} = {attr_value!r}\n") + outfile.write("\n") + + +class OpGenerator: + + def __init__(self, reg_ops, ods_emitter, impl_emitter): + super().__init__() + self.reg_ops = reg_ops + self.ods_emitter = ods_emitter + self.impl_emitter = impl_emitter + + def print_banner(self, text): + for em in (self.ods_emitter, self.impl_emitter): + em.print( + "// -----------------------------------------------------------------------------" + ) + em.print(f"// {text}") + em.print( + "// -----------------------------------------------------------------------------" + ) + em.print("") + + def ordinary_binary_op(self, kernel_sig, ods_name, op_name): + """"Binary"-ops. These ops typically have: + - '.Tensor' variant where the second arg is a Tensor + - '.Scalar' variant where the second arg is a Scalar + - An '.out' variant which contains a final "outref" argument + Actual suffixes vary and the rules are set up to match anything + that comes in in one of these forms. + In addition, most of these have in-place versions, possibly of the + '.Tensor' and '.Scalar' variants above. + Note that many of these have more than two arguments (i.e. alpha/beta + scalars trailing and such), but they are not relevant for + matching/conversions (if they are more than pass-through, then have a + dedicated rule). + We generally canonicalize all of these forms to a single recognized op + by: + - Enabling the flag to promotTrailingOutTensor + - Enabling the flag matchInplaceVariant + - Setting all arguments and returns to kImmutableTensor + - Enabling kPromoteScalarToTensor on the second argument. + """ + reg_record = self._get_reg_record(kernel_sig) + ods_ins, arg_type_flags = self._map_sigtypes( + reg_record["arguments"], + type_transforms={ + "Tensor:0": "AnyTorchImmutableTensor", + "Tensor:1": "AnyTorchImmutableTensor", + "Scalar:1": "AnyTorchImmutableTensor", + "Scalar": "AnyTorchScalarType", + }, + flag_transforms={ + ":0": ["kImmutableTensor"], + ":1": ["kImmutableTensor", "kPromoteScalar"], + }) + ods_outs, return_type_flags = self._map_sigtypes( + reg_record["returns"], + type_transforms={ + "Tensor:0": "AnyTorchImmutableTensor", + }, + flag_transforms={ + ":0": ["kImmutableTensor"], + }) + self.ods_emitter.emit_opdef(ods_name, + op_name, + reg_record, + ods_ins=ods_ins, + ods_outs=ods_outs, + traits=["NoSideEffect"]) + self.impl_emitter.emit_kernel_methods(ods_name, + reg_record, + arg_type_flags=arg_type_flags, + return_type_flags=return_type_flags, + promote_trailing_out_tensor=True) + + def ordinary_unary_op(self, kernel_sig, ods_name, op_name): + """Unary ops. + + These take and return a tensor and typically have an out and inplace + variant (they may not but we generate patterns to match anyway). + """ + reg_record = self._get_reg_record(kernel_sig) + ods_ins, arg_type_flags = self._map_sigtypes( + reg_record["arguments"], + type_transforms={ + "Tensor:0": "AnyTorchImmutableTensor", + }, + flag_transforms={ + ":0": ["kImmutableTensor"], + }) + ods_outs, return_type_flags = self._map_sigtypes( + reg_record["returns"], + type_transforms={ + "Tensor:0": "AnyTorchImmutableTensor", + }, + flag_transforms={ + ":0": ["kImmutableTensor"], + }) + self.ods_emitter.emit_opdef(ods_name, + op_name, + reg_record, + ods_ins=ods_ins, + ods_outs=ods_outs, + traits=["NoSideEffect"]) + self.impl_emitter.emit_kernel_methods(ods_name, + reg_record, + arg_type_flags=arg_type_flags, + return_type_flags=return_type_flags, + promote_trailing_out_tensor=True) + + def _get_reg_record(self, kernel_sig): + """Gets the op-dict for a given registered op name. + + Args: + kernel_sig: Signature of the kernel to find. + Returns: + Dict of the registration record. + """ + record = self.reg_ops.get(kernel_sig) + if record: + return record + + # Try to give a nice "did you mean" style error, since this happens + # so much. + kernel_name, *rest = kernel_sig.split("(", maxsplit=1) + dym_list = [k for k in self.reg_ops.keys() if k.startswith(kernel_name)] + dym_message = '\n '.join(dym_list) + raise ValueError(f"Could not find registry op matching '{kernel_sig}'. " + f"Possible matches:\n {dym_message}") + + def _map_sigtypes(self, siglist: List[Dict], type_transforms: Dict[str, str], + flag_transforms: Dict[str, List[str]]) -> List[Tuple[str]]: + """Maps a list of signature entries to ods dags and flag lists. + + The torch signature list contains dicts that minimally have keys 'name' and + 'type', representing torch names and types. Returns a corresponding + list of 2-tuples of (ods_name, ods_type). + + The provided type_transforms is a dict of type substitutions, one of which + must match for each entry in the list. The keys can be either a verbatim + torch type (i.e. "Tensor") an index in the list (i.e. ":0") or a type and + index (i.e. "Tensor:0"). + + Similarly, flag_transforms matches its keys in the same way and maps to + a list of KernelValueConversion constants that make up arg/return specific + conversion flags. + + Returns: + - An ods dag list of (ods_name, ods_type) tuples + - List of (torch_type, [conversion_flag]) for specifying conversions. + """ + # Generate to ods dag list. + ods_dag_list = [] + for i, sigitem in enumerate(siglist): + torch_name = sigitem["name"] + torch_type = sigitem["type"] + # Look up the type transform. + ods_type = (type_transforms.get(f"{torch_type}:{i}") or + type_transforms.get(f":{i}") or + type_transforms.get(torch_type)) + if not ods_type: + raise ValueError(f"Signature item {i}, type {torch_type} did not match " + f"a type transform {type_transforms}") + ods_dag_list.append((torch_name, ods_type)) + + # Generate the type conversion flags. + type_flag_list = [] + for i, sigitem in enumerate(siglist): + torch_type = sigitem["type"] + # Look up the type transform. + flags = (flag_transforms.get(f"{torch_type}:{i}") or + flag_transforms.get(f":{i}") or flag_transforms.get(torch_type)) + if not flags: + flags = [] + type_flag_list.append((torch_type, flags)) + return ods_dag_list, type_flag_list + + +class EmitterBase: + _INDENT = " " + + def __init__(self, out: TextIO): + super().__init__() + self.out = out + self.indent_level = 0 + + @contextmanager + def indent(self, level=1): + self.indent_level += level + yield + self.indent_level -= level + assert self.indent_level >= 0, "Unbalanced indentation" + + def print(self, s, *, end="\n", indent=True): + if indent and self.indent_level: + self.out.write(self._INDENT * self.indent_level) + self.out.write(s) + self.out.write(end) + + def quote(self, s: str): + s = s.replace(r'"', r'\\"') + return f'"{s}"' + + def quote_multiline_docstring(self, s: str, indent_level: int = 0): + # TODO: Possibly find a python module to markdown the docstring for better + # document generation. + # Unlikely to contain the delimitter and since just a docstring, be safe. + s = s.replace("}]", "") + # Strip each line. + s = "\n".join([l.rstrip() for l in s.splitlines()]) + indent = self._INDENT * indent_level + s = textwrap.indent(s, indent + self._INDENT) + return "[{\n" + s + "\n" + indent + "}]" + + +class OdsEmitter(EmitterBase): + ods_def_prefix = "aten_" + ods_def_suffix = "" + ods_template_name = "aten_Op" + + def emit_opdef(self, + ods_def_name: str, + mnemonic: str, + reg_record: Dict, + ods_ins: List[Tuple[str, str]], + ods_outs: List[Tuple[str, str]], + traits: Sequence[str] = (), + summary: Optional[str] = None): + # Def first-line. + full_traits = list(traits) + full_traits.append( + "DeclareOpInterfaceMethods") + full_traits.append("DeclareOpInterfaceMethods") + identifier = f"{self.ods_def_prefix}{ods_def_name}{self.ods_def_suffix}" + self.print(f"def {identifier}: {self.ods_template_name}" + f"<{self.quote(mnemonic)}, [" + f"{', '.join(full_traits)}" + f"]> {{") + with self.indent(): + # Summary. + if not summary: + summary = f"Recognized op for kernel {reg_record['name'][0]}" + self.print(f"let summary = {self.quote(summary)};") + # Arguments. + self.print("let arguments = (ins") + with self.indent(): + self._emit_dag_list_body(ods_ins) + self.print(");") + + # Results. + self.print("let results = (outs") + with self.indent(): + self._emit_dag_list_body(ods_outs) + self.print(");") + + # Def last-line. + self.print("}\n") + + def _emit_dag_list_body(self, items): + """Emits a dag of (name, type) pairs.""" + for index, (ods_name, ods_type) in enumerate(items): + is_last = index == len(items) - 1 + ods_namespec = f":${ods_name}" if ods_name else "" + self.print(f"{ods_type}{ods_namespec}", end="\n" if is_last else ",\n") + + +class CCImplEmitter(EmitterBase): + + def emit_kernel_methods(self, + ods_def_name: str, + reg_record, + arg_type_flags: List[Tuple[str, List[Tuple[str]]]], + return_type_flags: List[Tuple[str, List[Tuple[str]]]], + promote_trailing_out_tensor=False): + # getTorchKernelMetadata() method. + self.print( + f"Torch::KernelMetadata {ods_def_name}::getTorchKernelMetadata() {{") + with self.indent(): + self.print("return getTorchBuildKernelMetadata();") + self.print("}\n") + + # getTorchBuildKernelMetadata() method. + kernel_name = reg_record["name"][0] + self.print( + f"const Torch::BuildKernelMetadata &{ods_def_name}::getTorchBuildKernelMetadata() {{" + ) + with self.indent(): + self.print("using KVC = Torch::KernelValueConversion::BitMask;") + self.print("static Torch::BuildKernelMetadata metadata = ([]() {") + with self.indent(): + self.print("Torch::BuildKernelMetadata m;") + self.print(f"m.kernelName = {self.quote(kernel_name)};") + if promote_trailing_out_tensor: + self.print("m.promoteTrailingOutTensor = true;") + # Arg types/flags. + arg_types = self._format_cpp_str_initlist( + [t[0] for t in arg_type_flags]) + self.print(f"m.addArgTypes({arg_types});") + arg_flags = self._format_cpp_kvc_initlist( + [t[1] for t in arg_type_flags]) + self.print(f"m.addArgConversions({arg_flags});") + # Returns types/flags. + ret_types = self._format_cpp_str_initlist( + [t[0] for t in return_type_flags]) + self.print(f"m.addReturnTypes({ret_types});") + ret_flags = self._format_cpp_kvc_initlist( + [t[1] for t in return_type_flags]) + self.print(f"m.addReturnConversions({ret_flags});") + self.print("return m;") + self.print("})();") + self.print("return metadata;") + self.print("}") + + def _format_cpp_str_initlist(self, strings): + quoted = [self.quote(s) for s in strings] + joined = ", ".join(quoted) + return "{" + joined + "}" + + def _format_cpp_kvc_initlist(self, const_name_lists): + + def or_flags(flag_names): + if not flag_names: + return "KVC::kNone" + return "|".join([f"KVC::{n}" for n in flag_names]) + + or_d = [or_flags(l) for l in const_name_lists] + joined = ", ".join(or_d) + return "{" + joined + "}" + + +def snakecase_to_camelcase(ident: str): + return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident)) + + +def _load_ops_as_dict(): + # Returns a list of dicts, each with a name that is a tuple of the form: + # (kernel_signature, variant) + # The kernel signature is a reified form of the argument type signature + # used throughout PyTorch: + # namespace::kernel_name(type1,type2) + def reify_signature(reg_op): + kernel_name, unused_variant = reg_op["name"] + arg_types = [arg["type"] for arg in reg_op["arguments"]] + return f"{kernel_name}({','.join(arg_types)})" + + reg_ops_list = get_registered_ops() + return {reify_signature(reg_op): reg_op for reg_op in reg_ops_list} + + +def _get_main_module_name(): + return sys.modules["__main__"].__loader__.name + + +ODS_BANNER = "\n".join([ + "//===-------------------------------------------------------*- tablegen -*-===//", + "//", + "// This file is licensed under the Apache License v2.0 with LLVM Exceptions.", + "// See https://llvm.org/LICENSE.txt for license information.", + "// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception", + "//", + "// Operation summaries and descriptions were systematically derived from public", + "// API docstrings and are licensed accordingly:", + "// https://github.com/pytorch/pytorch/blob/master/LICENSE", + "//===----------------------------------------------------------------------===//", + "// This file is automatically generated. Please do not edit.", + "// Generated via:", + f"// python -m {_get_main_module_name()}", + "//===----------------------------------------------------------------------===//", + "", + "", +]) + +CC_IMPL_BANNER = "\n".join([ + "//===-------------------------------------------------------------*- cc -*-===//", + "//", + "// This file is licensed under the Apache License v2.0 with LLVM Exceptions.", + "// See https://llvm.org/LICENSE.txt for license information.", + "// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception", "//", + "// Operation summaries and descriptions were systematically derived from public", + "// API docstrings and are licensed accordingly:", + "// https://github.com/pytorch/pytorch/blob/master/LICENSE", + "//===----------------------------------------------------------------------===//", + "// This file is automatically generated. Please do not edit.", + "// Generated via:", f"// python -m {_get_main_module_name()}", + "//===----------------------------------------------------------------------===//", + "", "", "// clang-format off" +]) + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + parser = _create_argparse() + args = parser.parse_args() + main(args) diff --git a/include/npcomp/Dialect/ATen/IR/ATenOps.td b/include/npcomp/Dialect/ATen/IR/ATenOps.td index c8ba3a63d..380eae21e 100644 --- a/include/npcomp/Dialect/ATen/IR/ATenOps.td +++ b/include/npcomp/Dialect/ATen/IR/ATenOps.td @@ -25,44 +25,7 @@ class aten_Op traits = [StatisticsOpInterface]> : // Most ops are automatically generated from pytorch specs. include "npcomp/Dialect/ATen/IR/GeneratedATenOps.td" - -def aten_AddOp: aten_Op<"add", [ - NoSideEffect, TorchBuildableKernelOpInterface, TorchKernelOpInterface, - StatisticsOpInterface]> { - let arguments = ( - ins AnyTorchImmutableTensor:$self, - AnyTorchImmutableTensor:$other, - AnyTorchScalarType:$alpha - ); - let results = (outs AnyTorchImmutableTensor); - let summary = "aten add operator"; - let description = [{ - AddOp - aten add operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - - Torch::KernelMetadata getTorchKernelMetadata() { - return getTorchBuildKernelMetadata(); - } - - static const Torch::BuildKernelMetadata &getTorchBuildKernelMetadata() { - using KVC = Torch::KernelValueConversion::BitMask; - static Torch::BuildKernelMetadata metadata = ([]() { - Torch::BuildKernelMetadata m; - m.kernelName = "aten::add"; - m.promoteTrailingOutTensor = true; - m.addArgTypes({"Tensor", "Tensor", "Scalar"}); - m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone}); - m.addReturnTypes({"Tensor"}); - m.addReturnConversions({KVC::kImmutableTensor}); - return m; - })(); - return metadata; - } - }]; -} +include "npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td" def aten_BatchNormOp: aten_Op<"batch_norm", [NoSideEffect, StatisticsOpInterface]>, Results<(outs AnyTensor:$output, AnyTensor:$save_mean, AnyTensor:$save_invstd)> { diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc new file mode 100644 index 000000000..5896a0471 --- /dev/null +++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc @@ -0,0 +1,781 @@ +//===-------------------------------------------------------------*- cc -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Operation summaries and descriptions were systematically derived from public +// API docstrings and are licensed accordingly: +// https://github.com/pytorch/pytorch/blob/master/LICENSE +//===----------------------------------------------------------------------===// +// This file is automatically generated. Please do not edit. +// Generated via: +// python -m torch_mlir_utils.codegen.torch_signature_ods_gen +//===----------------------------------------------------------------------===// + + +// clang-format off +// ----------------------------------------------------------------------------- +// Binary arithmetic ops +// ----------------------------------------------------------------------------- + +Torch::KernelMetadata AddOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &AddOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::add"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor", "Scalar"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar, KVC::kNone}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata Atan2Op::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &Atan2Op::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::atan2"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata DivOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &DivOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::div"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata FloorDivideOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &FloorDivideOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::floor_divide"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata MulOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &MulOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::mul"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata RemainderOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &RemainderOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::remainder"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata TrueDivideOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::true_divide"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +// ----------------------------------------------------------------------------- +// Unary arithmetic ops +// ----------------------------------------------------------------------------- + +Torch::KernelMetadata AbsOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &AbsOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::abs"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata AcosOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &AcosOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::acos"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata AngleOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &AngleOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::angle"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata AsinOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &AsinOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::asin"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata AtanOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &AtanOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::atan"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata CeilOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &CeilOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::ceil"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata ConjOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &ConjOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::conj"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata CosOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &CosOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::cos"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata CoshOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &CoshOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::cosh"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata DigammaOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &DigammaOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::digamma"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata ErfOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &ErfOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::erf"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata ErfcOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &ErfcOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::erfc"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata ErfinvOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &ErfinvOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::erfinv"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata ExpOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &ExpOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::exp"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata Expm1Op::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &Expm1Op::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::expm1"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata FloorOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &FloorOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::floor"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata FracOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &FracOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::frac"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata LgammaOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &LgammaOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::lgamma"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata LogOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &LogOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::log"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata Log10Op::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &Log10Op::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::log10"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata Log1pOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &Log1pOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::log1p"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata Log2Op::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &Log2Op::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::log2"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata NegOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &NegOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::neg"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata ReluOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &ReluOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::relu"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata ReciprocalOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &ReciprocalOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::reciprocal"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata RoundOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &RoundOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::round"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata RsqrtOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &RsqrtOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::rsqrt"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata SigmoidOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &SigmoidOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::sigmoid"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata SignOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &SignOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::sign"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata SinOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &SinOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::sin"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata SinhOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &SinhOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::sinh"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata SqrtOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &SqrtOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::sqrt"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata TanOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &TanOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::tan"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata TanhOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &TanhOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::tanh"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} +Torch::KernelMetadata TruncOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::trunc"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td index b4c3e2234..d25a8d1ee 100644 --- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td +++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td @@ -1,716 +1,452 @@ -// This file is (mostly) automatically generated. Please do not edit. -//===- ATenOps.td ------------------------------------------*- tablegen -*-===// +//===-------------------------------------------------------*- tablegen -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // +// Operation summaries and descriptions were systematically derived from public +// API docstrings and are licensed accordingly: +// https://github.com/pytorch/pytorch/blob/master/LICENSE +//===----------------------------------------------------------------------===// +// This file is automatically generated. Please do not edit. +// Generated via: +// python -m torch_mlir_utils.codegen.torch_signature_ods_gen //===----------------------------------------------------------------------===// -#ifndef NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS -#define NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS -def aten_AddUnderOp: aten_Op<"add_", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$other, - AnyScalar:$alpha +// ----------------------------------------------------------------------------- +// Binary arithmetic ops +// ----------------------------------------------------------------------------- + +def aten_AddOp: aten_Op<"add", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::add"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$other, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten add_ operator"; - let description = [{ - AddUnderOp - aten add_ operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_AsStridedOp: aten_Op<"as_strided", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyType:$size, - AnyType:$stride +def aten_Atan2Op: aten_Op<"atan2", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::atan2"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$other + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten as_strided operator"; - let description = [{ - AsStridedOp - aten as_strided operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_ConvolutionOverrideableOp: aten_Op<"convolution_overrideable", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$input, - AnyTensor:$weight, - AnyTensor:$bias, - AnyType:$stride, - AnyType:$padding, - AnyType:$dilation, - AnyScalar:$transposed, - AnyType:$output_padding, - AnyScalar:$groups +def aten_DivOp: aten_Op<"div", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::div"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$other + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten convolution_overrideable operator"; - let description = [{ - ConvolutionOverrideableOp - aten convolution_overrideable operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_ConvolutionBackwardOverrideableOp: aten_Op<"convolution_backward_overrideable", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor, AnyTensor, AnyTensor)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$input, - AnyTensor:$weight, - AnyType:$stride, - AnyType:$padding, - AnyType:$dilation, - AnyScalar:$transposed, - AnyType:$output_padding, - AnyScalar:$groups +def aten_FloorDivideOp: aten_Op<"floor_divide", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::floor_divide"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$other + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten convolution_backward_overrideable operator"; - let description = [{ - ConvolutionBackwardOverrideableOp - aten convolution_backward_overrideable operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_DivOp: aten_Op<"div", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$other +def aten_MulOp: aten_Op<"mul", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::mul"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$other + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten div operator"; - let description = [{ - DivOp - aten div operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_DivUnderOp: aten_Op<"div_", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$other +def aten_RemainderOp: aten_Op<"remainder", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::remainder"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$other + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten div_ operator"; - let description = [{ - DivUnderOp - aten div_ operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_ExpandOp: aten_Op<"expand", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyType:$size, - AnyScalar:$implicit +def aten_TrueDivideOp: aten_Op<"true_divide", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::true_divide"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$other + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten expand operator"; - let description = [{ - ExpandOp - aten expand operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_LogSoftmaxOp: aten_Op<"_log_softmax", [NoSideEffect]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyScalar:$dim, - AnyScalar:$half_to_float +// ----------------------------------------------------------------------------- +// Unary arithmetic ops +// ----------------------------------------------------------------------------- + +def aten_AbsOp: aten_Op<"abs", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::abs"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten _log_softmax operator"; - let description = [{ - LogSoftmaxOp - aten _log_softmax operator - }]; } -def aten_LogSoftmaxBackwardDataOp: aten_Op<"_log_softmax_backward_data", [NoSideEffect]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$output, - AnyScalar:$dim, - AnyTensor:$self +def aten_AcosOp: aten_Op<"acos", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::acos"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten _log_softmax_backward_data operator"; - let description = [{ - LogSoftmaxBackwardDataOp - aten _log_softmax_backward_data operator - }]; } -def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self +def aten_AngleOp: aten_Op<"angle", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::angle"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten mean operator"; - let description = [{ - MeanOp - aten mean operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_MmOp: aten_Op<"mm", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$mat2 +def aten_AsinOp: aten_Op<"asin", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::asin"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten mm operator"; - let description = [{ - MmOp - aten mm operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_MulOp: aten_Op<"mul", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$other +def aten_AtanOp: aten_Op<"atan", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::atan"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten mul operator"; - let description = [{ - MulOp - aten mul operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_MulUnderOp: aten_Op<"mul_", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$other +def aten_CeilOp: aten_Op<"ceil", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::ceil"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten mul_ operator"; - let description = [{ - MulUnderOp - aten mul_ operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_NativeBatchNormOp: aten_Op<"native_batch_norm", [NoSideEffect, StatisticsOpInterface]>, - // FIXME: not quite automatically generated names - Results<(outs AnyTensor:$output, AnyTensor:$save_mean, AnyTensor:$save_invstd)> { - let arguments = ( - ins AnyTensor:$input, - AnyTensor:$weight, - AnyTensor:$bias, - AnyTensor:$running_mean, - AnyTensor:$running_var, - AnyScalar:$training, - AnyScalar:$momentum, - AnyScalar:$eps +def aten_ConjOp: aten_Op<"conj", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::conj"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten native_batch_norm operator"; - let description = [{ - NativeBatchNormOp - aten native_batch_norm operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_NativeBatchNormBackwardOp: aten_Op<"native_batch_norm_backward", [NoSideEffect, StatisticsOpInterface]>, - // FIXME: not quite automatically generated - Results<(outs AnyTensor:$dx, AnyTensor:$dm, AnyTensor:$dv)> { -let arguments = ( - ins AnyTensor:$grad_out, - AnyTensor:$input, - AnyTensor:$weight, - AnyTensor:$running_mean, - AnyTensor:$running_var, - AnyTensor:$save_mean, - AnyTensor:$save_invstd, - AnyScalar:$train, - AnyScalar:$eps +def aten_CosOp: aten_Op<"cos", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::cos"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten native_batch_norm_backward operator"; - let description = [{ - NativeBatchNormBackwardOp - aten native_batch_norm_backward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_NegOp: aten_Op<"neg", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self +def aten_CoshOp: aten_Op<"cosh", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::cosh"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten neg operator"; - let description = [{ - NegOp - aten neg operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_ReluOp: aten_Op<"relu", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self +def aten_DigammaOp: aten_Op<"digamma", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::digamma"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten relu operator"; - let description = [{ - ReluOp - aten relu operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_ReluUnderOp: aten_Op<"relu_", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self +def aten_ErfOp: aten_Op<"erf", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::erf"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten relu_ operator"; - let description = [{ - ReluUnderOp - aten relu_ operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_SizeOp: aten_Op<"size", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyScalar:$dim +def aten_ErfcOp: aten_Op<"erfc", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::erfc"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten size operator"; - let description = [{ - SizeOp - aten size operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_SqueezeOp: aten_Op<"squeeze", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyScalar:$dim +def aten_ErfinvOp: aten_Op<"erfinv", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::erfinv"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten squeeze operator"; - let description = [{ - SqueezeOp - aten squeeze operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_SumOp: aten_Op<"sum", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyType:$dim, - AnyScalar:$keepdim +def aten_ExpOp: aten_Op<"exp", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::exp"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten sum operator"; - let description = [{ - SumOp - aten sum operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_TOp: aten_Op<"t", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self +def aten_Expm1Op: aten_Op<"expm1", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::expm1"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten t operator"; - let description = [{ - TOp - aten t operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_ThresholdBackwardOp: aten_Op<"threshold_backward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$self, - AnyScalar:$threshold +def aten_FloorOp: aten_Op<"floor", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::floor"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten threshold_backward operator"; - let description = [{ - ThresholdBackwardOp - aten threshold_backward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_UnsqueezeOp: aten_Op<"unsqueeze", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyScalar:$dim +def aten_FracOp: aten_Op<"frac", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::frac"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten unsqueeze operator"; - let description = [{ - UnsqueezeOp - aten unsqueeze operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_SubOp: aten_Op<"sub", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$other, - AnyScalar:$alpha +def aten_LgammaOp: aten_Op<"lgamma", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::lgamma"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten sub operator"; - let description = [{ - SubOp - aten sub operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_SubUnderOp: aten_Op<"sub_", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$other, - AnyScalar:$alpha +def aten_LogOp: aten_Op<"log", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::log"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten sub_ operator"; - let description = [{ - SubUnderOp - aten sub_ operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_AddmmOp: aten_Op<"addmm", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$mat1, - AnyTensor:$mat2, - AnyScalar:$beta, - AnyScalar:$alpha +def aten_Log10Op: aten_Op<"log10", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::log10"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten addmm operator"; - let description = [{ - AddmmOp - aten addmm operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_ViewOp: aten_Op<"view", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyType:$size +def aten_Log1pOp: aten_Op<"log1p", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::log1p"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten view operator"; - let description = [{ - ViewOp - aten view operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_GatherOp: aten_Op<"gather", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyScalar:$dim, - AnyTensor:$index, - AnyScalar:$sparse_grad +def aten_Log2Op: aten_Op<"log2", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::log2"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten gather operator"; - let description = [{ - GatherOp - aten gather operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor, AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$target, - AnyTensor:$weight, - AnyScalar:$reduction, - AnyScalar:$ignore_index +def aten_NegOp: aten_Op<"neg", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::neg"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten nll_loss_forward operator"; - let description = [{ - NllLossForwardOp - aten nll_loss_forward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$self, - AnyTensor:$target, - AnyTensor:$weight, - AnyScalar:$reduction, - AnyScalar:$ignore_index, - AnyTensor:$total_weight +def aten_ReluOp: aten_Op<"relu", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::relu"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten nll_loss_backward operator"; - let description = [{ - NllLossBackwardOp - aten nll_loss_backward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor, AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$target, - AnyTensor:$weight, - AnyScalar:$reduction, - AnyScalar:$ignore_index +def aten_ReciprocalOp: aten_Op<"reciprocal", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::reciprocal"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten nll_loss2d_forward operator"; - let description = [{ - NllLoss2dForwardOp - aten nll_loss2d_forward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$self, - AnyTensor:$target, - AnyTensor:$weight, - AnyScalar:$reduction, - AnyScalar:$ignore_index, - AnyTensor:$total_weight +def aten_RoundOp: aten_Op<"round", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::round"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten nll_loss2d_backward operator"; - let description = [{ - NllLoss2dBackwardOp - aten nll_loss2d_backward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_HardtanhOp: aten_Op<"hardtanh", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyScalar:$min_val, - AnyScalar:$max_val +def aten_RsqrtOp: aten_Op<"rsqrt", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::rsqrt"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten hardtanh operator"; - let description = [{ - HardtanhOp - aten hardtanh operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_HardtanhBackwardOp: aten_Op<"hardtanh_backward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$self, - AnyScalar:$min_val, - AnyScalar:$max_val +def aten_SigmoidOp: aten_Op<"sigmoid", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::sigmoid"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten hardtanh_backward operator"; - let description = [{ - HardtanhBackwardOp - aten hardtanh_backward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_HardtanhUnderOp: aten_Op<"hardtanh_", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyScalar:$min_val, - AnyScalar:$max_val +def aten_SignOp: aten_Op<"sign", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::sign"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten hardtanh_ operator"; - let description = [{ - HardtanhUnderOp - aten hardtanh_ operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_AdaptiveAvgPool2dOp: aten_Op<"_adaptive_avg_pool2d", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyType:$output_size +def aten_SinOp: aten_Op<"sin", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::sin"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten _adaptive_avg_pool2d operator"; - let description = [{ - AdaptiveAvgPool2dOp - aten _adaptive_avg_pool2d operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_AdaptiveAvgPool2dBackwardOp: aten_Op<"_adaptive_avg_pool2d_backward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$self +def aten_SinhOp: aten_Op<"sinh", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::sinh"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten _adaptive_avg_pool2d_backward operator"; - let description = [{ - AdaptiveAvgPool2dBackwardOp - aten _adaptive_avg_pool2d_backward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_MaxPool2dWithIndicesOp: aten_Op<"max_pool2d_with_indices", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor, AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyType:$kernel_size, - AnyType:$stride, - AnyType:$padding, - AnyType:$dilation, - AnyScalar:$ceil_mode +def aten_SqrtOp: aten_Op<"sqrt", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::sqrt"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten max_pool2d_with_indices operator"; - let description = [{ - MaxPool2dWithIndicesOp - aten max_pool2d_with_indices operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -def aten_MaxPool2dWithIndicesBackwardOp: aten_Op<"max_pool2d_with_indices_backward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$self, - AnyType:$kernel_size, - AnyType:$stride, - AnyType:$padding, - AnyType:$dilation, - AnyScalar:$ceil_mode, - AnyTensor:$indices +def aten_TanOp: aten_Op<"tan", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::tan"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor + ); +} + +def aten_TanhOp: aten_Op<"tanh", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::tanh"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor + ); +} + +def aten_TruncOp: aten_Op<"trunc", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::trunc"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor ); - let summary = "aten max_pool2d_with_indices_backward operator"; - let description = [{ - MaxPool2dWithIndicesBackwardOp - aten max_pool2d_with_indices_backward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; } -#endif // NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS diff --git a/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td b/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td new file mode 100644 index 000000000..0e181f938 --- /dev/null +++ b/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td @@ -0,0 +1,654 @@ +// This file is (mostly) automatically generated. Please do not edit. +//===- ATenOps.td ------------------------------------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS +#define NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS + +def aten_AddUnderOp: aten_Op<"add_", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyTensor:$other, + AnyScalar:$alpha + ); + let summary = "aten add_ operator"; + let description = [{ + AddUnderOp + aten add_ operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_AsStridedOp: aten_Op<"as_strided", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyType:$size, + AnyType:$stride + ); + let summary = "aten as_strided operator"; + let description = [{ + AsStridedOp + aten as_strided operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_ConvolutionOverrideableOp: aten_Op<"convolution_overrideable", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$input, + AnyTensor:$weight, + AnyTensor:$bias, + AnyType:$stride, + AnyType:$padding, + AnyType:$dilation, + AnyScalar:$transposed, + AnyType:$output_padding, + AnyScalar:$groups + ); + let summary = "aten convolution_overrideable operator"; + let description = [{ + ConvolutionOverrideableOp + aten convolution_overrideable operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_ConvolutionBackwardOverrideableOp: aten_Op<"convolution_backward_overrideable", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor, AnyTensor, AnyTensor)> { + let arguments = ( + ins AnyTensor:$grad_output, + AnyTensor:$input, + AnyTensor:$weight, + AnyType:$stride, + AnyType:$padding, + AnyType:$dilation, + AnyScalar:$transposed, + AnyType:$output_padding, + AnyScalar:$groups + ); + let summary = "aten convolution_backward_overrideable operator"; + let description = [{ + ConvolutionBackwardOverrideableOp + aten convolution_backward_overrideable operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_DivUnderOp: aten_Op<"div_", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyTensor:$other + ); + let summary = "aten div_ operator"; + let description = [{ + DivUnderOp + aten div_ operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_ExpandOp: aten_Op<"expand", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyType:$size, + AnyScalar:$implicit + ); + let summary = "aten expand operator"; + let description = [{ + ExpandOp + aten expand operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_LogSoftmaxOp: aten_Op<"_log_softmax", [NoSideEffect]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyScalar:$dim, + AnyScalar:$half_to_float + ); + let summary = "aten _log_softmax operator"; + let description = [{ + LogSoftmaxOp + aten _log_softmax operator + }]; +} + +def aten_LogSoftmaxBackwardDataOp: aten_Op<"_log_softmax_backward_data", [NoSideEffect]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$grad_output, + AnyTensor:$output, + AnyScalar:$dim, + AnyTensor:$self + ); + let summary = "aten _log_softmax_backward_data operator"; + let description = [{ + LogSoftmaxBackwardDataOp + aten _log_softmax_backward_data operator + }]; +} + +def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self + ); + let summary = "aten mean operator"; + let description = [{ + MeanOp + aten mean operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_MmOp: aten_Op<"mm", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyTensor:$mat2 + ); + let summary = "aten mm operator"; + let description = [{ + MmOp + aten mm operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_MulUnderOp: aten_Op<"mul_", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyTensor:$other + ); + let summary = "aten mul_ operator"; + let description = [{ + MulUnderOp + aten mul_ operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_NativeBatchNormOp: aten_Op<"native_batch_norm", [NoSideEffect, StatisticsOpInterface]>, + // FIXME: not quite automatically generated names + Results<(outs AnyTensor:$output, AnyTensor:$save_mean, AnyTensor:$save_invstd)> { + let arguments = ( + ins AnyTensor:$input, + AnyTensor:$weight, + AnyTensor:$bias, + AnyTensor:$running_mean, + AnyTensor:$running_var, + AnyScalar:$training, + AnyScalar:$momentum, + AnyScalar:$eps + ); + let summary = "aten native_batch_norm operator"; + let description = [{ + NativeBatchNormOp + aten native_batch_norm operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_NativeBatchNormBackwardOp: aten_Op<"native_batch_norm_backward", [NoSideEffect, StatisticsOpInterface]>, + // FIXME: not quite automatically generated + Results<(outs AnyTensor:$dx, AnyTensor:$dm, AnyTensor:$dv)> { +let arguments = ( + ins AnyTensor:$grad_out, + AnyTensor:$input, + AnyTensor:$weight, + AnyTensor:$running_mean, + AnyTensor:$running_var, + AnyTensor:$save_mean, + AnyTensor:$save_invstd, + AnyScalar:$train, + AnyScalar:$eps + ); + let summary = "aten native_batch_norm_backward operator"; + let description = [{ + NativeBatchNormBackwardOp + aten native_batch_norm_backward operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_ReluUnderOp: aten_Op<"relu_", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self + ); + let summary = "aten relu_ operator"; + let description = [{ + ReluUnderOp + aten relu_ operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_SizeOp: aten_Op<"size", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyScalar:$dim + ); + let summary = "aten size operator"; + let description = [{ + SizeOp + aten size operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_SqueezeOp: aten_Op<"squeeze", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyScalar:$dim + ); + let summary = "aten squeeze operator"; + let description = [{ + SqueezeOp + aten squeeze operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_SumOp: aten_Op<"sum", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyType:$dim, + AnyScalar:$keepdim + ); + let summary = "aten sum operator"; + let description = [{ + SumOp + aten sum operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_TOp: aten_Op<"t", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self + ); + let summary = "aten t operator"; + let description = [{ + TOp + aten t operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_ThresholdBackwardOp: aten_Op<"threshold_backward", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$grad_output, + AnyTensor:$self, + AnyScalar:$threshold + ); + let summary = "aten threshold_backward operator"; + let description = [{ + ThresholdBackwardOp + aten threshold_backward operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_UnsqueezeOp: aten_Op<"unsqueeze", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyScalar:$dim + ); + let summary = "aten unsqueeze operator"; + let description = [{ + UnsqueezeOp + aten unsqueeze operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_SubOp: aten_Op<"sub", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyTensor:$other, + AnyScalar:$alpha + ); + let summary = "aten sub operator"; + let description = [{ + SubOp + aten sub operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_SubUnderOp: aten_Op<"sub_", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyTensor:$other, + AnyScalar:$alpha + ); + let summary = "aten sub_ operator"; + let description = [{ + SubUnderOp + aten sub_ operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_AddmmOp: aten_Op<"addmm", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyTensor:$mat1, + AnyTensor:$mat2, + AnyScalar:$beta, + AnyScalar:$alpha + ); + let summary = "aten addmm operator"; + let description = [{ + AddmmOp + aten addmm operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_ViewOp: aten_Op<"view", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyType:$size + ); + let summary = "aten view operator"; + let description = [{ + ViewOp + aten view operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_GatherOp: aten_Op<"gather", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyScalar:$dim, + AnyTensor:$index, + AnyScalar:$sparse_grad + ); + let summary = "aten gather operator"; + let description = [{ + GatherOp + aten gather operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor, AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyTensor:$target, + AnyTensor:$weight, + AnyScalar:$reduction, + AnyScalar:$ignore_index + ); + let summary = "aten nll_loss_forward operator"; + let description = [{ + NllLossForwardOp + aten nll_loss_forward operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$grad_output, + AnyTensor:$self, + AnyTensor:$target, + AnyTensor:$weight, + AnyScalar:$reduction, + AnyScalar:$ignore_index, + AnyTensor:$total_weight + ); + let summary = "aten nll_loss_backward operator"; + let description = [{ + NllLossBackwardOp + aten nll_loss_backward operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor, AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyTensor:$target, + AnyTensor:$weight, + AnyScalar:$reduction, + AnyScalar:$ignore_index + ); + let summary = "aten nll_loss2d_forward operator"; + let description = [{ + NllLoss2dForwardOp + aten nll_loss2d_forward operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$grad_output, + AnyTensor:$self, + AnyTensor:$target, + AnyTensor:$weight, + AnyScalar:$reduction, + AnyScalar:$ignore_index, + AnyTensor:$total_weight + ); + let summary = "aten nll_loss2d_backward operator"; + let description = [{ + NllLoss2dBackwardOp + aten nll_loss2d_backward operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_HardtanhOp: aten_Op<"hardtanh", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyScalar:$min_val, + AnyScalar:$max_val + ); + let summary = "aten hardtanh operator"; + let description = [{ + HardtanhOp + aten hardtanh operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_HardtanhBackwardOp: aten_Op<"hardtanh_backward", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$grad_output, + AnyTensor:$self, + AnyScalar:$min_val, + AnyScalar:$max_val + ); + let summary = "aten hardtanh_backward operator"; + let description = [{ + HardtanhBackwardOp + aten hardtanh_backward operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_HardtanhUnderOp: aten_Op<"hardtanh_", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyScalar:$min_val, + AnyScalar:$max_val + ); + let summary = "aten hardtanh_ operator"; + let description = [{ + HardtanhUnderOp + aten hardtanh_ operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_AdaptiveAvgPool2dOp: aten_Op<"_adaptive_avg_pool2d", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyType:$output_size + ); + let summary = "aten _adaptive_avg_pool2d operator"; + let description = [{ + AdaptiveAvgPool2dOp + aten _adaptive_avg_pool2d operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_AdaptiveAvgPool2dBackwardOp: aten_Op<"_adaptive_avg_pool2d_backward", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$grad_output, + AnyTensor:$self + ); + let summary = "aten _adaptive_avg_pool2d_backward operator"; + let description = [{ + AdaptiveAvgPool2dBackwardOp + aten _adaptive_avg_pool2d_backward operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_MaxPool2dWithIndicesOp: aten_Op<"max_pool2d_with_indices", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor, AnyTensor)> { + let arguments = ( + ins AnyTensor:$self, + AnyType:$kernel_size, + AnyType:$stride, + AnyType:$padding, + AnyType:$dilation, + AnyScalar:$ceil_mode + ); + let summary = "aten max_pool2d_with_indices operator"; + let description = [{ + MaxPool2dWithIndicesOp + aten max_pool2d_with_indices operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +def aten_MaxPool2dWithIndicesBackwardOp: aten_Op<"max_pool2d_with_indices_backward", [NoSideEffect, StatisticsOpInterface]>, + Results<(outs AnyTensor)> { + let arguments = ( + ins AnyTensor:$grad_output, + AnyTensor:$self, + AnyType:$kernel_size, + AnyType:$stride, + AnyType:$padding, + AnyType:$dilation, + AnyScalar:$ceil_mode, + AnyTensor:$indices + ); + let summary = "aten max_pool2d_with_indices_backward operator"; + let description = [{ + MaxPool2dWithIndicesBackwardOp + aten max_pool2d_with_indices_backward operator + }]; + let extraClassDeclaration = [{ + std::map getStatistics(); + }]; +} + +#endif // NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS diff --git a/include/npcomp/Dialect/Torch/IR/OpInterfaces.h b/include/npcomp/Dialect/Torch/IR/OpInterfaces.h index 4e968e385..0d338e2ee 100644 --- a/include/npcomp/Dialect/Torch/IR/OpInterfaces.h +++ b/include/npcomp/Dialect/Torch/IR/OpInterfaces.h @@ -28,7 +28,11 @@ enum BitMask { // Coerce/require a mutable tensor value. kMutableTensor = 4, - LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kMutableTensor) + // If the source is a Scalar and the target is a Tensor, promotes + // to a 0d tensor. + kPromoteScalar = 8, + + LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kPromoteScalar) }; LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE(); } // namespace KernelValueConversion diff --git a/lib/Dialect/ATen/IR/ATenDialect.cpp b/lib/Dialect/ATen/IR/ATenDialect.cpp index 50f76a7d7..87486f1fa 100644 --- a/lib/Dialect/ATen/IR/ATenDialect.cpp +++ b/lib/Dialect/ATen/IR/ATenDialect.cpp @@ -108,3 +108,5 @@ void ATenDialect::initialize() { #include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc" #include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc" + +#include "npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc" diff --git a/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp b/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp index 961f611cb..cbf4d1348 100644 --- a/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp +++ b/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp @@ -55,33 +55,6 @@ std::map AdaptiveAvgPool2dBackwardOp::getStatistics() { return toReturn; } -// add -std::map AddOp::getStatistics() { - - std::map toReturn; - - TensorType resultTy = getResult().getType().cast(); - TensorType aType = getOperand(0).getType().cast(); - Type bType = getOperand(1).getType(); - - uint64_t ofm_volume = getTensorVolume(resultTy); - - toReturn["ops:+"] = ofm_volume; - toReturn["result:0:activation_out"] = ofm_volume; - - // Find the size of the A and B operands - uint64_t a_volume = getTensorVolume(aType); - uint64_t b_volume = getTensorVolume(bType); - - toReturn["operand:0:activation_in"] = a_volume; - toReturn["operand:1:activation_in"] = b_volume; - - toReturn["reads"] = a_volume + b_volume; - toReturn["writes"] = ofm_volume; - - return toReturn; -} - // add_ std::map AddUnderOp::getStatistics() { @@ -231,33 +204,6 @@ ConvolutionBackwardOverrideableOp::getStatistics() { return getConv2dBackwardStatistics(*this, groups); } -// div -std::map DivOp::getStatistics() { - - std::map toReturn; - - TensorType resultTy = getResult().getType().cast(); - TensorType aType = getOperand(0).getType().cast(); - Type bType = getOperand(1).getType(); - - uint64_t ofm_volume = getTensorVolume(resultTy); - toReturn["ops:/"] = ofm_volume; - - toReturn["result:0:activation_out"] = ofm_volume; - - // Find the size of the A and B operands - uint64_t a_volume = getTensorVolume(aType); - uint64_t b_volume = getTensorVolume(bType); - - toReturn["operand:0:activation_in"] = a_volume; - toReturn["operand:1:activation_in"] = b_volume; - - toReturn["reads"] = a_volume + b_volume; - toReturn["writes"] = ofm_volume; - - return toReturn; -} - // div_ std::map DivUnderOp::getStatistics() { @@ -467,32 +413,6 @@ std::map MmOp::getStatistics() { return getMMOpStatistics(*this); } -// mul -std::map MulOp::getStatistics() { - - std::map toReturn; - - TensorType resultTy = getResult().getType().cast(); - TensorType aType = getOperand(0).getType().cast(); - Type bType = getOperand(1).getType(); - - uint64_t ofm_volume = getTensorVolume(resultTy); - toReturn["ops:*"] = ofm_volume; - toReturn["result:0:activation_out"] = ofm_volume; - - // Find the size of the A and B operands - uint64_t a_volume = getTensorVolume(aType); - uint64_t b_volume = getTensorVolume(bType); - - toReturn["operand:0:activation_in"] = a_volume; - toReturn["operand:1:activation_in"] = b_volume; - - toReturn["reads"] = a_volume + b_volume; - toReturn["writes"] = ofm_volume; - - return toReturn; -} - // mul_ std::map MulUnderOp::getStatistics() { @@ -668,23 +588,6 @@ std::map NllLoss2dBackwardOp::getStatistics() { return toReturn; } -// neg op -std::map NegOp::getStatistics() { - std::map toReturn; - auto insize = getTensorVolume(getOperand().getType()); - auto outsize = getTensorVolume(getResult().getType()); - toReturn["reads"] = toReturn["operand:0:activation_in"] = insize; - toReturn["writes"] = toReturn["result:0:activation_out"] = outsize; - return toReturn; -} - -// relu -// std::map ReLUOp::getStatistics() { -// return getReLUOpStatistics(*this); -// } -std::map ReluOp::getStatistics() { - return getReLUOpStatistics(*this); -} // std::map ReLUUnderOp::getStatistics() { // return getReLUOpStatistics(*this); // } diff --git a/lib/Dialect/ATen/Transforms/RecognizeKernelsPass.cpp b/lib/Dialect/ATen/Transforms/RecognizeKernelsPass.cpp index 730c6a6f6..9a450c5b7 100644 --- a/lib/Dialect/ATen/Transforms/RecognizeKernelsPass.cpp +++ b/lib/Dialect/ATen/Transforms/RecognizeKernelsPass.cpp @@ -47,6 +47,7 @@ convertTorchArgType(StringRef sourceTorchType, StringRef targetTorchType, // Immutable tensor conversion. if (flag & KVC::kImmutableTensor) { + // TODO: Support the kPromoteScalar flag. if (sourceTorchType != "Tensor" || targetTorchType != "Tensor") return None; diff --git a/test/Dialect/ATen/aten_add.mlir b/test/Dialect/ATen/aten_add.mlir deleted file mode 100644 index 0d2f3e0c8..000000000 --- a/test/Dialect/ATen/aten_add.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s -// CHECK-LABEL: "L0-add-0": { -// CHECK-NEXT: "activation_in": 12, -// CHECK-NEXT: "activation_out": 6, -// CHECK-NEXT: "ops:+": 6, -// CHECK-NEXT: "reads": 12, -// CHECK-NEXT: "writes": 6 - -// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION -// CHECK-CONVERSION-LABEL: @graph -func @graph(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> { - %1 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32 - %2 = "aten.add"(%arg0, %arg1, %1) : (tensor<1x2x3xf32>, tensor<1x2x3xf32>, i32) -> tensor<1x2x3xf32> - "std.return"(%2) : (tensor<1x2x3xf32>) -> () -} diff --git a/test/Dialect/ATen/aten_relu.mlir b/test/Dialect/ATen/aten_relu.mlir deleted file mode 100644 index 66bcc7276..000000000 --- a/test/Dialect/ATen/aten_relu.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s -// CHECK-LABEL: "L0-relu-0": { -// CHECK-NEXT: "activation_in": 6, -// CHECK-NEXT: "activation_out": 6, -// CHECK-NEXT: "ops:>": 6, -// CHECK-NEXT: "reads": 6, -// CHECK-NEXT: "writes": 6 - -module { - func @graph(%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> { - %0 = "aten.relu"(%arg0) : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> - "std.return"(%0) : (tensor<1x2x3xf32>) -> () - } -} - diff --git a/test/Dialect/ATen/aten_resA.mlir b/test/Dialect/ATen/aten_resA.mlir deleted file mode 100644 index 074a904b7..000000000 --- a/test/Dialect/ATen/aten_resA.mlir +++ /dev/null @@ -1,61 +0,0 @@ -// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s -// CHECK-LABEL: "L0-native_batch_norm-0": { -// CHECK-LABEL: "L1-relu-0": { -// CHECK-LABEL: "L2-_convolution-0": { -// CHECK-LABEL: "L3-native_batch_norm-1": { -// CHECK-LABEL: "L4-relu-1": { -// CHECK-LABEL: "L5-_convolution-1": { -// CHECK-LABEL: "L6-native_batch_norm-2": { -// CHECK-LABEL: "L7-relu-2": { -// CHECK-LABEL: "L8-_convolution-2": { -// CHECK-LABEL: "L9-add-0": { - -module { - func @graph(%arg0: tensor<1x16x128x128xf32>, %arg1: tensor<1x16x128x128xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>, %arg5: tensor<16xf32>, %arg6: tensor<8x16x1x1xf32>, %arg7: tensor<8xf32>, %arg8: tensor<8xf32>, %arg9: tensor<8xf32>, %arg10: tensor<8xf32>, %arg11: tensor<8xf32>, %arg12: tensor<8x8x3x3xf32>, %arg13: tensor<8xf32>, %arg14: tensor<8xf32>, %arg15: tensor<8xf32>, %arg16: tensor<8xf32>, %arg17: tensor<8xf32>, %arg18: tensor<16x8x1x1xf32>, %arg19: tensor<16xf32>) -> tensor<1x16x128x128xf32> { - %0 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1 - %1 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32 - %2 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32 - %3:3 = "aten.native_batch_norm"(%arg1, %arg2, %arg3, %arg4, %arg5, %0, %1, %2) : (tensor<1x16x128x128xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i1, f32, f32) -> (tensor<1x16x128x128xf32>, tensor<16xf32>, tensor<16xf32>) - %4 = "aten.relu"(%3#0) : (tensor<1x16x128x128xf32>) -> tensor<1x16x128x128xf32> - %5 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list - %6 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list - %7 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list - %8 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1 - %9 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list - %10 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32 - %11 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1 - %12 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1 - %13 = "aten._convolution"(%4, %arg6, %arg7, %5, %6, %7) : (tensor<1x16x128x128xf32>, tensor<8x16x1x1xf32>, tensor<8xf32>, !aten.list, !aten.list, !aten.list) -> tensor<1x8x128x128xf32> - %14 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1 - %15 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32 - %16 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32 - %17:3 = "aten.native_batch_norm"(%13, %arg8, %arg9, %arg10, %arg11, %14, %15, %16) : (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, i1, f32, f32) -> (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>) - %18 = "aten.relu"(%17#0) : (tensor<1x8x128x128xf32>) -> tensor<1x8x128x128xf32> - %19 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list - %20 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list - %21 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list - %22 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1 - %23 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list - %24 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32 - %25 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1 - %26 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1 - %27 = "aten._convolution"(%18, %arg12, %arg13, %19, %20, %21) : (tensor<1x8x128x128xf32>, tensor<8x8x3x3xf32>, tensor<8xf32>, !aten.list, !aten.list, !aten.list) -> tensor<1x8x128x128xf32> - %28 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1 - %29 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32 - %30 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32 - %31:3 = "aten.native_batch_norm"(%27, %arg14, %arg15, %arg16, %arg17, %28, %29, %30) : (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, i1, f32, f32) -> (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>) - %32 = "aten.relu"(%31#0) : (tensor<1x8x128x128xf32>) -> tensor<1x8x128x128xf32> - %33 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list - %34 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list - %35 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list - %36 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1 - %37 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list - %38 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32 - %39 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1 - %40 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1 - %41 = "aten._convolution"(%32, %arg18, %arg19, %33, %34, %35) : (tensor<1x8x128x128xf32>, tensor<16x8x1x1xf32>, tensor<16xf32>, !aten.list, !aten.list, !aten.list) -> tensor<1x16x128x128xf32> - %42 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32 - %43 = "aten.add"(%arg0, %41, %42) : (tensor<1x16x128x128xf32>, tensor<1x16x128x128xf32>, i32) -> tensor<1x16x128x128xf32> - return %43 : tensor<1x16x128x128xf32> - } -}