mirror of https://github.com/llvm/torch-mlir
Rewrite ATen ODS code generator to be based on new op registry and new signature recognition system.
* Deletes prior code generator from previous attempt (moved some of it into this one). * Renames old generated tablegen source to "Legacy". * Generates ODS and import rules for most binary and unary arithmetic ops. * Removes old generated ops and integration tests that were testing details of the prior setup.pull/98/head
parent
94ea6f7c92
commit
c08935a418
|
@ -0,0 +1,14 @@
|
|||
#!/bin/bash
|
||||
# Updates the ATen dialect generated code from the PyTorch op registry.
|
||||
# Requires that the project has been built and that PyTorch support is enabled.
|
||||
set -e
|
||||
|
||||
src_dir="$(realpath $(dirname $0)/..)"
|
||||
build_dir="$(realpath "${NPCOMP_BUILD_DIR:-$src_dir/build}")"
|
||||
aten_dir="${src_dir}/include/npcomp/Dialect/ATen/IR"
|
||||
|
||||
export PYTHONPATH="${build_dir}/python"
|
||||
|
||||
python -m torch_mlir_utils.codegen.torch_signature_ods_gen \
|
||||
--ods_td_file="${aten_dir}/GeneratedATenOps.td" \
|
||||
--ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc"
|
|
@ -1,202 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Populates an op registry for ATen ops.
|
||||
|
||||
Typically callers will import and use the 'populate' function to add known
|
||||
ops to the OpRegistry. When run interactively as a main module, it simply
|
||||
prints all registered ops.
|
||||
"""
|
||||
|
||||
from .registry import *
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def populate(r: OpRegistry):
|
||||
# Unary pointwise ops (ordinary that take out refs).
|
||||
for f in [
|
||||
torch.abs, torch.acos, torch.angle, torch.asin, torch.atan, torch.ceil,
|
||||
torch.conj, torch.cos, torch.cosh, torch.digamma, torch.erf, torch.erfc,
|
||||
torch.erfinv, torch.exp, torch.expm1, torch.floor, torch.frac,
|
||||
torch.lgamma, torch.log, torch.log10, torch.log1p, torch.log2, torch.neg,
|
||||
torch.reciprocal, torch.round, torch.rsqrt, torch.sigmoid, torch.sign,
|
||||
torch.sin, torch.sinh, torch.sqrt, torch.tan, torch.tanh, torch.trunc
|
||||
]:
|
||||
r.op(f, TensorValue("input")).with_outref_variant()
|
||||
|
||||
# Unary pointwise ops (that do not take out refs).
|
||||
for f in [torch.relu]:
|
||||
r.op(f, TensorValue("input"))
|
||||
|
||||
# Binary pointwise ops.
|
||||
r.op(torch.add,
|
||||
TensorValue("input"),
|
||||
TensorValue("other"),
|
||||
alpha=LiteralValue()).with_outref_variant()
|
||||
r.op(torch.atan2, TensorValue("input"),
|
||||
TensorValue("other")).with_outref_variant()
|
||||
r.op(torch.div, TensorValue("input"),
|
||||
TensorValue("other")).with_outref_variant()
|
||||
r.op(torch.floor_divide, TensorValue("input"),
|
||||
TensorValue("other")).with_outref_variant()
|
||||
r.op(torch.mul, TensorValue("input"),
|
||||
TensorValue("other")).with_outref_variant()
|
||||
r.op(torch.remainder, TensorValue("input"),
|
||||
TensorValue("other")).with_outref_variant()
|
||||
r.op(torch.true_divide, TensorValue("dividend"),
|
||||
TensorValue("divisor")).with_outref_variant()
|
||||
|
||||
# Aggregate operations.
|
||||
# TODO: Support the optional dtype= parameter.
|
||||
r.op(torch.cumsum, TensorValue("input", example_size=(10, 3)),
|
||||
LiteralValue("dim", value=1)).with_outref_variant()
|
||||
r.op(
|
||||
torch.mean,
|
||||
TensorValue("input", example_size=(10, 3)), #
|
||||
LiteralValue("dim", value=[0], mlir_ods_predicate="ATen_IntList"), #
|
||||
LiteralValue("keep_dim",
|
||||
value=False,
|
||||
mlir_ods_predicate="ATen_BoolScalar") #
|
||||
).with_outref_variant()
|
||||
r.op(
|
||||
torch.sum,
|
||||
TensorValue("input", example_size=(10, 3)), #
|
||||
LiteralValue("dim", value=[0], mlir_ods_predicate="ATen_IntList"), #
|
||||
LiteralValue("keep_dim",
|
||||
value=False,
|
||||
mlir_ods_predicate="ATen_BoolScalar") #
|
||||
).with_outref_variant()
|
||||
|
||||
# Gather.
|
||||
(r.op(
|
||||
torch.gather, #
|
||||
TensorValue("input", example_size=(2, 2)), #
|
||||
dim=LiteralValue(value=1, mlir_ods_predicate="ATen_IntScalar"), #
|
||||
index=LiteralValue(value=torch.tensor([[0, 0], [1, 0]]),
|
||||
mlir_ods_predicate="ATen_AnyTensor"), #
|
||||
sparse_grad=LiteralValue(value=False,
|
||||
mlir_ods_predicate="ATen_BoolScalar") #
|
||||
).with_outref_variant())
|
||||
|
||||
# Non-view methods on Tensor.
|
||||
(r.op((TensorValue(name="input", example_size=(3, 1)), "@T")))
|
||||
|
||||
# BLAS and LAPACK ops.
|
||||
r.op(torch.addmm,
|
||||
TensorValue("input", example_size=(2, 3)),
|
||||
TensorValue("mat1", example_size=(2, 3)),
|
||||
TensorValue("mat2", example_size=(3, 3)),
|
||||
beta=LiteralValue(),
|
||||
alpha=LiteralValue()).with_outref_variant()
|
||||
r.op(torch.dot, TensorValue("input", example_size=(10,)),
|
||||
TensorValue("tensor", example_size=(10,)))
|
||||
r.op(torch.matmul, TensorValue("input", example_size=(10, 3, 4)),
|
||||
TensorValue("other", example_size=(4, 5))).with_outref_variant()
|
||||
r.op(torch.mm, TensorValue("input", example_size=(3, 4)),
|
||||
TensorValue("mat2", example_size=(4, 6))).with_outref_variant()
|
||||
|
||||
# NN Functional.
|
||||
# Note that _convolution is a special case and is manually coded.
|
||||
r.op(
|
||||
F.hardtanh,
|
||||
TensorValue("input"), #
|
||||
min_val=LiteralValue(value=-1.0,
|
||||
mlir_ods_predicate="ATen_FloatScalar"), #
|
||||
max_val=LiteralValue(value=1.0, mlir_ods_predicate="ATen_FloatScalar") #
|
||||
)
|
||||
r.op(F.avg_pool1d,
|
||||
TensorValue("input", example_size=(1, 1, 7)),
|
||||
kernel_size=LiteralValue(value=[3], mlir_ods_predicate="ATen_IntList"),
|
||||
stride=LiteralValue(value=[5], mlir_ods_predicate="ATen_IntList"),
|
||||
padding=LiteralValue(value=[1], mlir_ods_predicate="ATen_IntList"),
|
||||
ceil_mode=LiteralValue(value=True, mlir_ods_predicate="ATen_BoolScalar"),
|
||||
count_include_pad=LiteralValue(value=False,
|
||||
mlir_ods_predicate="ATen_BoolScalar"))
|
||||
|
||||
# MaxPool1D is split into two ops based on whether return_indices is True:
|
||||
# aten::max_pool1d -> tensor<float>
|
||||
# aten::max_pool1d_with_indices -> (tensor<float>, tensor<long>)
|
||||
# Both have odd signatures and are hand-mapped.
|
||||
# TODO: Implement max_pool1d(..., with_indices=True)
|
||||
(r.op(F.max_pool1d,
|
||||
TensorValue("input", example_size=(1, 1, 7)),
|
||||
kernel_size=LiteralValue(value=[3], mlir_ods_predicate="ATen_IntList"),
|
||||
stride=LiteralValue(value=[5], mlir_ods_predicate="ATen_IntList"),
|
||||
padding=LiteralValue(value=[1], mlir_ods_predicate="ATen_IntList"),
|
||||
dilation=LiteralValue(value=[3], mlir_ods_predicate="ATen_IntList"),
|
||||
ceil_mode=LiteralValue(value=True,
|
||||
mlir_ods_predicate="ATen_BoolScalar"),
|
||||
return_indices=LiteralValue(value=False)) #
|
||||
.with_torch_op_kind("aten::max_pool1d") #
|
||||
.with_operand_map("input", "kernel_size", "stride", "padding", "dilation",
|
||||
"ceil_mode") #
|
||||
)
|
||||
|
||||
# View ops.
|
||||
# TODO: All of these need special analysis and should be parameterized
|
||||
# on mutable tensors and have a proper design thought through. For now,
|
||||
# even having them in the inventory (badly) increases visibility.
|
||||
(r.op(torch.as_strided,
|
||||
TensorValue("input"),
|
||||
size=LiteralValue(value=[2, 2], mlir_ods_predicate="ATen_IntList"),
|
||||
stride=LiteralValue(value=[1, 2], mlir_ods_predicate="ATen_IntList"),
|
||||
storage_offset=LiteralValue(value=4,
|
||||
mlir_ods_predicate="ATen_IntScalar")) #
|
||||
.with_append_description(r"""
|
||||
|
||||
MLIR Specific Notes
|
||||
-------------------
|
||||
In PyTorch proper, this op creates a view that may internally alias. And
|
||||
have explicit warnings about avoiding inplace updates on such a
|
||||
view (without first cloning). For the moment, this op is formulated with
|
||||
value semantics that imply a copy instead of a view, and it is expected
|
||||
that any sharing can be recovered later by the compiler. The warning
|
||||
about not in-place updating of such a result should be treated as UB
|
||||
when compiled.
|
||||
"""))
|
||||
|
||||
(r.op((TensorValue(name="input", example_size=(3, 1)), "expand"),
|
||||
LiteralValue("sizes", value=torch.Size([3, 4]))) #
|
||||
.with_operand_map("input", "sizes",
|
||||
LiteralValue("implicit",
|
||||
mlir_ods_predicate="ATen_BoolScalar")) #
|
||||
.with_append_description(r"""
|
||||
|
||||
MLIR Specific Notes
|
||||
-------------------
|
||||
See notes for the 'as_strided' op.
|
||||
"""))
|
||||
|
||||
(r.op(
|
||||
torch.squeeze,
|
||||
TensorValue("input"), #
|
||||
LiteralValue("dim", value=1, mlir_ods_predicate="ATen_IntScalar") #
|
||||
).with_append_description(r"""
|
||||
|
||||
MLIR Specific Notes
|
||||
-------------------
|
||||
See notes for the 'as_strided' op.
|
||||
"""))
|
||||
|
||||
(r.op((TensorValue(name="input", example_size=(3, 1)), "view"),
|
||||
LiteralValue(name="size",
|
||||
value=[3, 1],
|
||||
mlir_ods_predicate="ATen_IntList"))
|
||||
.with_append_description(r"""
|
||||
|
||||
MLIR Specific Notes
|
||||
-------------------
|
||||
See notes for the 'as_strided' op.
|
||||
"""))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
registry = OpRegistry()
|
||||
populate(registry)
|
||||
print("Registered operations:")
|
||||
for m in registry.mappings:
|
||||
print(" ", m)
|
|
@ -1,205 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Generates ODS for a registry of ops."""
|
||||
|
||||
from typing import TextIO
|
||||
|
||||
import argparse
|
||||
from contextlib import contextmanager
|
||||
import importlib
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
from .registry import *
|
||||
|
||||
_INDENT = " "
|
||||
|
||||
|
||||
class OdsEmitter:
|
||||
ods_prefix = "ATen_"
|
||||
ods_suffix = "Op"
|
||||
ods_value_template = "ATen_ImmutableTensorOp"
|
||||
ods_ref_template = "ATen_RefTensorOp"
|
||||
op_prefix = ""
|
||||
|
||||
def __init__(self, r: OpRegistry, out: TextIO):
|
||||
super().__init__()
|
||||
self.r = r
|
||||
self.out = out
|
||||
self.indent_level = 0
|
||||
|
||||
def emit_ods(self):
|
||||
for op_m in self.r.mappings:
|
||||
if isinstance(op_m, SimpleOpMapping):
|
||||
self._emit_simple_op_mapping(op_m)
|
||||
else:
|
||||
logging.warn(f"Unrecognized op mapping type: {op_m!r}")
|
||||
|
||||
def _emit_simple_op_mapping(self, op_m: SimpleOpMapping):
|
||||
identifier = (f"{self.ods_prefix}"
|
||||
f"{_snakecase_to_camelcase(op_m.mlir_operation_name)}"
|
||||
f"{self.ods_suffix}")
|
||||
traits = []
|
||||
|
||||
if op_m.is_outref_form:
|
||||
template_name = self.ods_ref_template
|
||||
summary = "See non-inplace op variant."
|
||||
description = ""
|
||||
else:
|
||||
template_name = self.ods_value_template
|
||||
summary, description = _split_docstring(op_m.op_f.__doc__)
|
||||
|
||||
if not op_m.is_outref_form:
|
||||
traits.append("NoSideEffect")
|
||||
self.print(f"def {identifier}: {template_name}"
|
||||
f"<{_quote(op_m.mlir_operation_name)}, ["
|
||||
f"{', '.join(traits)}"
|
||||
f"]> {{")
|
||||
|
||||
# Summary.
|
||||
with self.indent():
|
||||
self.print(f"let summary = {_quote(summary)};")
|
||||
|
||||
# Arguments.
|
||||
with self.indent():
|
||||
self.print("let arguments = (ins")
|
||||
with self.indent():
|
||||
operand_len = len(op_m.operand_map)
|
||||
for index, (_, value_spec) in enumerate(op_m.operand_map):
|
||||
is_last = index == operand_len - 1
|
||||
self.print(f"{value_spec.mlir_ods_predicate}:${value_spec.name}",
|
||||
end="\n" if is_last else ",\n")
|
||||
self.print(");")
|
||||
|
||||
# Results (omitted if an outref/inplace form).
|
||||
with self.indent():
|
||||
if op_m.is_outref_form:
|
||||
self.print("let results = (outs);")
|
||||
else:
|
||||
self.print("let results = (outs")
|
||||
with self.indent():
|
||||
result_len = len(op_m.result_map)
|
||||
for index, (_, value_spec) in enumerate(op_m.result_map):
|
||||
is_last = index == result_len - 1
|
||||
self.print(f"{value_spec.mlir_ods_predicate}:${value_spec.name}",
|
||||
end="\n" if is_last else ",\n")
|
||||
self.print(");")
|
||||
|
||||
# Description and extra class declarations.
|
||||
with self.indent():
|
||||
if description:
|
||||
quoted_description = _quote_multiline_docstring(
|
||||
description + op_m.append_description,
|
||||
indent_level=self.indent_level)
|
||||
self.print(f"let description = {quoted_description};")
|
||||
|
||||
self.print("}\n")
|
||||
|
||||
@contextmanager
|
||||
def indent(self, level=1):
|
||||
self.indent_level += level
|
||||
yield
|
||||
self.indent_level -= level
|
||||
assert self.indent_level >= 0, "Unbalanced indentation"
|
||||
|
||||
def print(self, s, *, end="\n", indent=True):
|
||||
if indent and self.indent_level:
|
||||
self.out.write(_INDENT * self.indent_level)
|
||||
self.out.write(s)
|
||||
self.out.write(end)
|
||||
|
||||
|
||||
def _snakecase_to_camelcase(ident: str):
|
||||
return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
|
||||
|
||||
|
||||
def _quote(s: str):
|
||||
s = s.replace(r'"', r'\\"')
|
||||
return f'"{s}"'
|
||||
|
||||
|
||||
def _quote_multiline_docstring(s: str, indent_level: int = 0):
|
||||
# TODO: Possibly find a python module to markdown the docstring for better
|
||||
# document generation.
|
||||
# Unlikely to contain the delimitter and since just a docstring, be safe.
|
||||
s = s.replace("}]", "")
|
||||
# Strip each line.
|
||||
s = "\n".join([l.rstrip() for l in s.splitlines()])
|
||||
indent = _INDENT * indent_level
|
||||
s = textwrap.indent(s, indent + _INDENT)
|
||||
return "[{\n" + s + "\n" + indent + "}]"
|
||||
|
||||
|
||||
def _split_docstring(docstring: str):
|
||||
"""Splits the docstring into a summary and description."""
|
||||
if not docstring:
|
||||
docstring = ""
|
||||
lines = docstring.splitlines()
|
||||
if not lines:
|
||||
return "", ""
|
||||
|
||||
# Skip leading blank lines.
|
||||
while lines and not lines[0]:
|
||||
lines = lines[1:]
|
||||
if len(lines) > 2:
|
||||
return lines[0], "\n".join(lines[2:])
|
||||
else:
|
||||
return lines[0], ""
|
||||
|
||||
|
||||
def main(args):
|
||||
r = OpRegistry()
|
||||
# Populate from modules that provide a populate() function.
|
||||
op_modules = [args.op_module]
|
||||
for m_name in op_modules:
|
||||
logging.info(f"Populating from module: {m_name}")
|
||||
m = importlib.import_module(m_name, package=__package__)
|
||||
f = getattr(m, "populate")
|
||||
f(r)
|
||||
|
||||
out = sys.stdout
|
||||
|
||||
# Write file header.
|
||||
module_name = sys.modules["__main__"].__loader__.name
|
||||
banner_lines = [
|
||||
"//===-------------------------------------------------------*- tablegen -*-===//",
|
||||
"//",
|
||||
"// This file is licensed under the Apache License v2.0 with LLVM Exceptions.",
|
||||
"// See https://llvm.org/LICENSE.txt for license information.",
|
||||
"// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception",
|
||||
"//",
|
||||
"// Operation summaries and descriptions were systematically derived from public",
|
||||
"// API docstrings and are licensed accordingly:",
|
||||
"// https://github.com/pytorch/pytorch/blob/master/LICENSE",
|
||||
"//===----------------------------------------------------------------------===//",
|
||||
"// This file is automatically generated. Please do not edit.",
|
||||
"// Generated via:",
|
||||
f"// python -m {module_name} {' '.join(sys.argv[1:])}",
|
||||
"//===----------------------------------------------------------------------===//",
|
||||
"",
|
||||
"",
|
||||
]
|
||||
banner_lines = [l.strip() for l in banner_lines]
|
||||
out.write("\n".join(banner_lines))
|
||||
|
||||
emitter = OdsEmitter(r, out=out)
|
||||
emitter.emit_ods()
|
||||
|
||||
|
||||
def _create_argparse():
|
||||
parser = argparse.ArgumentParser(prog="generate_ods")
|
||||
parser.add_argument(
|
||||
"--op_module",
|
||||
default=".aten_ops",
|
||||
help="Name of a python module for populating the registry")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _create_argparse()
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -1,495 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Base classes and interfaces for mapping ops to MLIR.
|
||||
|
||||
The goal of this facility is to define the majority of op mappings by example,
|
||||
from the actual Python invocation, using the PyTorch tracer to extract its
|
||||
Graph IR, and introspecting that to determine mappings, metadata and types
|
||||
for corresponding MLIR definitions. This is not meant to cover everything,
|
||||
and it is expected that a number of hard ops should be mapped by hand.
|
||||
|
||||
The result of building such an OpRegistry should be a data structure that
|
||||
can be used to generate ODS and tables for doing systematic imports from
|
||||
a PyTorch Graph to corresponding MLIR module.
|
||||
|
||||
Example usage (fully automatic discovery):
|
||||
|
||||
>>> r = OpRegistry()
|
||||
>>> r.op(torch.add,
|
||||
TensorValue("input"),
|
||||
TensorValue("other"),
|
||||
alpha=ScalarValue()).with_outref_variant()
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import logging
|
||||
import random
|
||||
import textwrap
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"SimpleOpMapping",
|
||||
"OpRegistry",
|
||||
"LiteralValue",
|
||||
"TensorOutRef",
|
||||
"TensorValue",
|
||||
]
|
||||
|
||||
|
||||
def _is_same_value(value1, value2):
|
||||
# Tensors are considered via reference equality.
|
||||
value1_tensor = isinstance(value1, torch.Tensor)
|
||||
value2_tensor = isinstance(value2, torch.Tensor)
|
||||
if value1_tensor or value2_tensor:
|
||||
return value1 is value2
|
||||
|
||||
# Everything else is value equality.
|
||||
return value1 == value2
|
||||
|
||||
|
||||
def _extract_immediate(node):
|
||||
"""Extracts an immediate value from a node.
|
||||
|
||||
Supported node types:
|
||||
prim::Constant
|
||||
prim::ListConstruct
|
||||
"""
|
||||
# Constant
|
||||
if node.kind() == "prim::Constant":
|
||||
# Try to extract as different types.
|
||||
try:
|
||||
return node.t("value")
|
||||
except RuntimeError:
|
||||
pass
|
||||
try:
|
||||
return node.f("value")
|
||||
except RuntimeError:
|
||||
pass
|
||||
try:
|
||||
return node.i("value")
|
||||
except RuntimeError:
|
||||
pass
|
||||
try:
|
||||
return node.s("value")
|
||||
except RuntimeError:
|
||||
pass
|
||||
return None
|
||||
# List
|
||||
elif node.kind() == "prim::ListConstruct":
|
||||
return [_extract_immediate(i.node()) for i in node.inputs()]
|
||||
else:
|
||||
raise ValueError("Unrecognized immediate input node: {!r}".format(node))
|
||||
|
||||
|
||||
def _read_attr_callback(instance, attr_name):
|
||||
|
||||
def f():
|
||||
return getattr(instance, attr_name)
|
||||
|
||||
f.__doc__ = getattr(type(instance), attr_name).__doc__
|
||||
return f
|
||||
|
||||
|
||||
class ValueSpec:
|
||||
"""Base class for inputs to operations.
|
||||
|
||||
This binds information about how the input is mapped to the MLIR operation.
|
||||
"""
|
||||
|
||||
def __init__(self, name=None, mlir_ods_predicate: str = "AnyType"):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.mlir_ods_predicate = mlir_ods_predicate
|
||||
|
||||
def generate_example(self, index=0):
|
||||
"""Generates an example value."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __repr__(self):
|
||||
return "{}({!r})".format(self.__class__.__name__, self.name)
|
||||
|
||||
|
||||
class TensorValue(ValueSpec):
|
||||
"""An input that is a tensor."""
|
||||
|
||||
def __init__(self,
|
||||
name=None,
|
||||
*,
|
||||
example_size=None,
|
||||
mlir_ods_predicate="ATen_AnyTensor",
|
||||
**kwargs):
|
||||
super().__init__(name=name, mlir_ods_predicate=mlir_ods_predicate, **kwargs)
|
||||
if example_size is None:
|
||||
example_size = (2, 3, 7) # No significance.
|
||||
self.example_size = example_size
|
||||
|
||||
def generate_example(self, index=0):
|
||||
return torch.rand(*self.example_size)
|
||||
|
||||
|
||||
class TensorOutRef(ValueSpec):
|
||||
"""A tensor that is passed by ref as an out parameter."""
|
||||
|
||||
def __init__(self,
|
||||
name=None,
|
||||
*,
|
||||
example_size=None,
|
||||
mlir_ods_predicate="ATen_AnyRefTensor",
|
||||
**kwargs):
|
||||
super().__init__(name=name, mlir_ods_predicate=mlir_ods_predicate, **kwargs)
|
||||
if example_size is None:
|
||||
example_size = (2, 3, 7) # No significance.
|
||||
self.example_size = example_size
|
||||
|
||||
def generate_example(self, index=0):
|
||||
return torch.rand(*self.example_size)
|
||||
|
||||
|
||||
class LiteralValue(ValueSpec):
|
||||
"""An input that is a literal value."""
|
||||
|
||||
def __init__(self,
|
||||
name=None,
|
||||
value=None,
|
||||
mlir_ods_predicate="ATen_AnyScalar",
|
||||
**kwargs):
|
||||
super().__init__(name=name, mlir_ods_predicate=mlir_ods_predicate, **kwargs)
|
||||
self.value = value
|
||||
|
||||
def generate_example(self, index=0):
|
||||
if self.value is not None:
|
||||
return self.value
|
||||
return 1.0 + index # Generates a stable value.
|
||||
|
||||
|
||||
class OpMapping:
|
||||
"""Base class for things purporting to map an operation."""
|
||||
__slots__ = []
|
||||
|
||||
|
||||
class SimpleOpMapping(OpMapping):
|
||||
"""Maps a PyTorch invocation to its MLIR representation."""
|
||||
__slots__ = [
|
||||
"append_description",
|
||||
"mlir_operation_name",
|
||||
"operand_map",
|
||||
"op_args",
|
||||
"op_arity",
|
||||
"op_f",
|
||||
"op_kind",
|
||||
"op_kwargs",
|
||||
"op_self",
|
||||
"op_self_value_spec",
|
||||
"outref_variant_value",
|
||||
"result_map",
|
||||
]
|
||||
|
||||
def __init__(self, op_f, *op_args, **op_kwargs):
|
||||
super().__init__()
|
||||
self.op_f = op_f
|
||||
self.op_self = None
|
||||
self.op_self_value_spec = None
|
||||
self.op_args = op_args
|
||||
self.op_kwargs = op_kwargs
|
||||
self.outref_variant_value = None # type: Optional[TensorOutRef]
|
||||
self.append_description = ""
|
||||
|
||||
# Set after finalize.
|
||||
self.op_kind = None # type: Optional[str]
|
||||
self.op_arity = -1 # type: int
|
||||
self.operand_map = None # type: Optional[List[Tuple[int, ValueSpec]]]
|
||||
self.result_map = None # type: Optional[List[Tuple[int, ValueSpec]]]
|
||||
self.mlir_operation_name = None # type: Optional[str]
|
||||
|
||||
# Fixup self-calls.
|
||||
# These are specified as tuples of (ValueSpec, "method_name")
|
||||
if isinstance(self.op_f, tuple):
|
||||
assert len(self.op_f) == 2, "Self-calls must be a 2-tuple"
|
||||
self.op_self_value_spec, method_name = self.op_f
|
||||
self.op_self = self.op_self_value_spec.generate_example(index=-1)
|
||||
if method_name.startswith("@"):
|
||||
# Create a pseudo-function to resolve the attribute.
|
||||
self.op_f = _read_attr_callback(self.op_self, method_name[1:])
|
||||
else:
|
||||
self.op_f = getattr(self.op_self, method_name)
|
||||
|
||||
def __repr__(self):
|
||||
return ("SimpleOp({kind!r}[{arity}] -> {name!s}, operands={operands!r}, "
|
||||
"results={results!r})".format(kind=self.op_kind,
|
||||
arity=self.op_arity,
|
||||
name=self.mlir_operation_name,
|
||||
operands=self.operand_map,
|
||||
results=self.result_map))
|
||||
|
||||
def clone(self) -> "SimpleOpMapping":
|
||||
copy = SimpleOpMapping(self.op_f, *self.op_args, **self.op_kwargs)
|
||||
for name in [
|
||||
"outref_variant_value", "op_kind", "op_arity", "operand_map",
|
||||
"result_map", "mlir_operation_name"
|
||||
]:
|
||||
setattr(copy, name, getattr(self, name))
|
||||
return copy
|
||||
|
||||
def with_outref_variant(self, value=None) -> "SimpleOpMapping":
|
||||
"""Instructs the registry to also generate an outref variant.
|
||||
|
||||
This is done by cloning the op prior to finalizing and adding an out=
|
||||
paramer.
|
||||
"""
|
||||
self.outref_variant_value = TensorOutRef() if value is None else value
|
||||
return self
|
||||
|
||||
def with_torch_op_kind(self, op_kind: str) -> "SimpleOpMapping":
|
||||
"""Uses a manually specified op kind (i.e. "aten::some_op")."""
|
||||
self.op_kind = op_kind
|
||||
return self
|
||||
|
||||
def with_operand_map(self, *mlir_operand_names):
|
||||
"""Manually maps torch IR operands to mlir operand names."""
|
||||
operand_map = []
|
||||
for i, name_or_value_spec in enumerate(mlir_operand_names):
|
||||
if isinstance(name_or_value_spec, ValueSpec):
|
||||
value_spec = name_or_value_spec
|
||||
else:
|
||||
value_spec = self.find_value_spec_by_name(name_or_value_spec)
|
||||
operand_map.append((i, value_spec))
|
||||
self.operand_map = operand_map
|
||||
return self
|
||||
|
||||
def with_append_description(self, description) -> "SimpleOpMapping":
|
||||
"""Appends a description to the ODS."""
|
||||
self.append_description += textwrap.dedent(description)
|
||||
return self
|
||||
|
||||
@property
|
||||
def all_arg_values(self) -> List[ValueSpec]:
|
||||
"""Returns all arg values (either positional or kw)."""
|
||||
args = list(self.op_args) + list(self.op_kwargs.values())
|
||||
if self.op_self_value_spec:
|
||||
args = [self.op_self_value_spec] + args
|
||||
return args
|
||||
|
||||
@property
|
||||
def is_outref_form(self) -> bool:
|
||||
"""Whether the op contains an out parameter that aliases to the result."""
|
||||
return any(isinstance(a, TensorOutRef) for a in self.all_arg_values)
|
||||
|
||||
def find_value_spec_by_name(self, name) -> ValueSpec:
|
||||
"""Finds an argument ValueSpec by name.
|
||||
|
||||
Raises:
|
||||
ValueError if not found.
|
||||
"""
|
||||
if self.op_self_value_spec and self.op_self_value_spec.name == name:
|
||||
return self.op_self_value_spec
|
||||
value_spec = self.op_kwargs.get(name)
|
||||
if value_spec:
|
||||
return value_spec
|
||||
for value_spec in self.op_args:
|
||||
if value_spec.name == name:
|
||||
return value_spec
|
||||
raise ValueError(f"Unknown value spec: {name}")
|
||||
|
||||
def generate_example(self) -> Tuple[Tuple, Dict]:
|
||||
"""Generates an example signature for invoking the op.
|
||||
|
||||
Returns:
|
||||
(tuple, dict) of positional and keyword args.
|
||||
"""
|
||||
index = 0
|
||||
positional = list()
|
||||
kw = dict()
|
||||
for op_arg in self.op_args:
|
||||
positional.append(op_arg.generate_example(index))
|
||||
index += 1
|
||||
for kw_name, kw_value in self.op_kwargs.items():
|
||||
kw[kw_name] = kw_value.generate_example(index)
|
||||
index += 1
|
||||
return positional, kw
|
||||
|
||||
def finalize(self):
|
||||
"""Finalizes the mapping once all hints have been applied."""
|
||||
# Update the name on all args if undefined.
|
||||
for index, op_arg in enumerate(self.op_args):
|
||||
if op_arg.name is None:
|
||||
op_arg.name = "arg%d".format(index)
|
||||
for key, op_arg in self.op_kwargs.items():
|
||||
if op_arg.name is None:
|
||||
op_arg.name = key
|
||||
|
||||
# Create an example graph and configure from it.
|
||||
self._configure_from_example()
|
||||
|
||||
# Determine default operation name.
|
||||
if self.mlir_operation_name is None:
|
||||
self._set_default_mlir_operation_name()
|
||||
|
||||
def _set_default_mlir_operation_name(self):
|
||||
op_ns, op_name = self.op_kind.split("::", maxsplit=1)
|
||||
# Since these are emitted into the "aten" dialect namespace, alias them
|
||||
# to omit the prefix to distinguish from custom ops and others (which will
|
||||
# have a prefix).
|
||||
default_name = op_name if op_ns == "aten" else op_ns + "." + op_name
|
||||
if op_ns == "aten":
|
||||
default_name = op_name
|
||||
else:
|
||||
default_name = op_ns + "." + op_name
|
||||
|
||||
if self.is_outref_form:
|
||||
default_name += ".inplace"
|
||||
self.mlir_operation_name = default_name
|
||||
|
||||
def _configure_from_example(self):
|
||||
# Trace the op so that we get an example graph like this:
|
||||
# %0 : Float(2, 3, 7) = prim::Constant[value=<Tensor>]()
|
||||
# %1 : Float(2, 3, 7) = prim::Constant[value=<Tensor>]()
|
||||
# %2 : float = prim::Constant[value=3.]()
|
||||
# %3 : Float(2, 3, 7) = aten::add(%0, %1, %2)
|
||||
# return (%3)
|
||||
# The def of the return value is expected to be the modeled op. The
|
||||
# inputs to that op are expected to be captured constants that can be
|
||||
# re-associated to the example inputs.
|
||||
example_args, example_kwargs = self.generate_example()
|
||||
|
||||
def forward():
|
||||
# logging.debug(
|
||||
# f"Invoke {self.op_f} with *{example_args}, **{example_kwargs}")
|
||||
return self.op_f(*example_args, **example_kwargs)
|
||||
|
||||
trace = torch.jit.trace(forward, tuple())
|
||||
graph = trace.graph
|
||||
logging.debug("Graph for op %r: %s", self.op_f, graph)
|
||||
|
||||
# Track up from the return node and assume this is our op.
|
||||
return_node = graph.return_node()
|
||||
return_inputs = list(return_node.inputs())
|
||||
assert len(return_inputs) == 1, "Expected one input return"
|
||||
op_node = return_inputs[0].node()
|
||||
op_inputs = list(op_node.inputs())
|
||||
logging.debug("Found op node: %r", op_node)
|
||||
|
||||
# Meta-data about the source op.
|
||||
self.op_kind = op_node.kind()
|
||||
self.op_arity = len(op_inputs)
|
||||
if self.operand_map is None:
|
||||
self.operand_map = self._associate_inputs(op_inputs, example_args,
|
||||
example_kwargs)
|
||||
|
||||
# Results.
|
||||
op_outputs = list(op_node.outputs())
|
||||
if self.result_map is None:
|
||||
if self.is_outref_form:
|
||||
# Only support single outref results.
|
||||
assert len(op_outputs) == 1, (
|
||||
"For outref ops, only a single output is supported")
|
||||
self.result_map = [(0, TensorOutRef("result"))]
|
||||
else:
|
||||
# Map results in order.
|
||||
self.result_map = []
|
||||
|
||||
def result_name(i):
|
||||
if len(op_outputs) == 1:
|
||||
return "result"
|
||||
else:
|
||||
return "result%d" % i
|
||||
|
||||
for i, op_output in enumerate(op_outputs):
|
||||
op_output_type = op_output.type()
|
||||
if issubclass(type(op_output_type), torch.TensorType):
|
||||
self.result_map.append((i, TensorValue(result_name(i))))
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported op output type: {!r}".format(op_output_type))
|
||||
return self
|
||||
|
||||
def _associate_inputs(self, op_inputs, example_args, example_kwargs):
|
||||
"""Given inputs to a graph op node, associates to the input args.
|
||||
|
||||
This will match up example arguments with what was produced in the graph,
|
||||
setting the operand_map.
|
||||
|
||||
Returns:
|
||||
List of (input_index, ValueSpec) mapping inputs to the graph node to
|
||||
provided values in the op definition.
|
||||
"""
|
||||
assert len(example_args) == len(self.op_args)
|
||||
assert example_kwargs.keys() == self.op_kwargs.keys()
|
||||
|
||||
def find_arg(value):
|
||||
if self.op_self is not None and _is_same_value(self.op_self, value):
|
||||
return self.op_self_value_spec
|
||||
for i, arg in enumerate(example_args):
|
||||
if _is_same_value(arg, value):
|
||||
return self.op_args[i]
|
||||
for key, arg in example_kwargs.items():
|
||||
if _is_same_value(arg, value):
|
||||
return self.op_kwargs[key]
|
||||
|
||||
raise KeyError("Op input not in arguments: {!r} -> {!r}".format(
|
||||
value, op_inputs))
|
||||
|
||||
operand_map = []
|
||||
for i, op_input in enumerate(op_inputs):
|
||||
input_node = op_input.node()
|
||||
immediate_value = _extract_immediate(input_node)
|
||||
if immediate_value is not None:
|
||||
operand_map.append((i, find_arg(immediate_value)))
|
||||
return operand_map
|
||||
|
||||
|
||||
class OpRegistry:
|
||||
"""Maintains a registry of op mappings."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._mappings = []
|
||||
self._pending_mapping = None
|
||||
|
||||
def op(self, op_f, *op_args, **op_kwargs):
|
||||
"""Forwards to the SimpleOpMapping constructor and adds it.
|
||||
|
||||
The mapping is not finalized until either the registry is finalized or the
|
||||
next op mapping is added. This allows tweaks to the mapping to be done
|
||||
inline prior to performing detailed introspection.
|
||||
|
||||
Returns:
|
||||
The SimpleOpMapping instance.
|
||||
"""
|
||||
self._finalize_pending()
|
||||
m = SimpleOpMapping(op_f, *op_args, **op_kwargs)
|
||||
self._pending_mapping = m
|
||||
return m
|
||||
|
||||
@property
|
||||
def mappings(self) -> Sequence[OpMapping]:
|
||||
"""Returns the list of OpMapping.
|
||||
|
||||
Returns:
|
||||
Sequence of OpMapping concrete classes (most commonly SimpleOpMapping).
|
||||
"""
|
||||
self._finalize_pending()
|
||||
return self._mappings
|
||||
|
||||
def _finalize_pending(self):
|
||||
if not self._pending_mapping:
|
||||
return
|
||||
|
||||
outref_mapping = None
|
||||
pending_mapping = self._pending_mapping
|
||||
self._pending_mapping = None
|
||||
if pending_mapping.outref_variant_value:
|
||||
# Generate a variant (with an out= form).
|
||||
outref_mapping = pending_mapping.clone()
|
||||
outref_mapping.op_kwargs["out"] = outref_mapping.outref_variant_value
|
||||
outref_mapping.outref_variant_value = None
|
||||
|
||||
# Finalize the original.
|
||||
pending_mapping.finalize()
|
||||
self._mappings.append(pending_mapping)
|
||||
|
||||
# Finalize the outref form if generated.
|
||||
if outref_mapping:
|
||||
outref_mapping.finalize()
|
||||
self._mappings.append(outref_mapping)
|
|
@ -0,0 +1,477 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Queries the pytorch op registry and generates ODS and CC sources for the ops.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, TextIO, Sequence, Tuple
|
||||
|
||||
import argparse
|
||||
from contextlib import contextmanager
|
||||
import importlib
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
# Note that this utility exists only in the c-extension.
|
||||
from _torch_mlir import get_registered_ops
|
||||
|
||||
|
||||
def _create_argparse():
|
||||
parser = argparse.ArgumentParser(prog="generate_ods")
|
||||
parser.add_argument("--ods_td_file",
|
||||
required=True,
|
||||
help="File to write the generated ODS to")
|
||||
parser.add_argument("--ods_impl_file",
|
||||
required=True,
|
||||
help="CC include file to include in ops implementation")
|
||||
parser.add_argument("--debug_op_reg_file",
|
||||
help="Write out a file of op registrations")
|
||||
return parser
|
||||
|
||||
|
||||
def main(args):
|
||||
reg_ops = _load_ops_as_dict()
|
||||
if args.debug_op_reg_file:
|
||||
with open(args.debug_op_reg_file, "w") as debug_ops_file:
|
||||
dump_registered_ops(debug_ops_file, reg_ops)
|
||||
|
||||
with open(args.ods_td_file, "w") as ods_file, open(args.ods_impl_file,
|
||||
"w") as impl_file:
|
||||
ods_emitter = OdsEmitter(ods_file)
|
||||
ods_emitter.print(ODS_BANNER)
|
||||
impl_emitter = CCImplEmitter(impl_file)
|
||||
impl_emitter.print(CC_IMPL_BANNER)
|
||||
generator = OpGenerator(reg_ops, ods_emitter, impl_emitter)
|
||||
generate_ops(generator)
|
||||
|
||||
|
||||
def generate_ops(g: "OpGenerator"):
|
||||
# "Binary"-ops.
|
||||
# There are some variation in these, so we spell them out in case if they
|
||||
# need individual customization.
|
||||
g.print_banner("Binary arithmetic ops")
|
||||
g.ordinary_binary_op("aten::add(Tensor,Tensor,Scalar)", "AddOp", "add")
|
||||
g.ordinary_binary_op("aten::atan2(Tensor,Tensor)", "Atan2Op", "atan2")
|
||||
g.ordinary_binary_op("aten::div(Tensor,Tensor)", "DivOp", "div")
|
||||
g.ordinary_binary_op("aten::floor_divide(Tensor,Tensor)", "FloorDivideOp",
|
||||
"floor_divide")
|
||||
g.ordinary_binary_op("aten::mul(Tensor,Tensor)", "MulOp", "mul")
|
||||
g.ordinary_binary_op("aten::remainder(Tensor,Tensor)", "RemainderOp",
|
||||
"remainder")
|
||||
g.ordinary_binary_op("aten::true_divide(Tensor,Tensor)", "TrueDivideOp",
|
||||
"true_divide")
|
||||
|
||||
# Unary-ops. These are all the same so just name munge them.
|
||||
g.print_banner("Unary arithmetic ops")
|
||||
for uname in [
|
||||
"abs", "acos", "angle", "asin", "atan", "ceil", "conj", "cos", "cosh",
|
||||
"digamma", "erf", "erfc", "erfinv", "exp", "expm1", "floor", "frac",
|
||||
"lgamma", "log", "log10", "log1p", "log2", "neg", "relu", "reciprocal",
|
||||
"round", "rsqrt", "sigmoid", "sign", "sin", "sinh", "sqrt", "tan", "tanh",
|
||||
"trunc"
|
||||
]:
|
||||
g.ordinary_unary_op(f"aten::{uname}(Tensor)",
|
||||
f"{snakecase_to_camelcase(uname)}Op", uname)
|
||||
|
||||
|
||||
def dump_registered_ops(outfile, reg_ops_dict):
|
||||
for k in sorted(reg_ops_dict.keys()):
|
||||
attr_dict = reg_ops_dict[k]
|
||||
outfile.write(f"OP '{k}':\n")
|
||||
for attr_name, attr_value in attr_dict.items():
|
||||
outfile.write(f" {attr_name} = {attr_value!r}\n")
|
||||
outfile.write("\n")
|
||||
|
||||
|
||||
class OpGenerator:
|
||||
|
||||
def __init__(self, reg_ops, ods_emitter, impl_emitter):
|
||||
super().__init__()
|
||||
self.reg_ops = reg_ops
|
||||
self.ods_emitter = ods_emitter
|
||||
self.impl_emitter = impl_emitter
|
||||
|
||||
def print_banner(self, text):
|
||||
for em in (self.ods_emitter, self.impl_emitter):
|
||||
em.print(
|
||||
"// -----------------------------------------------------------------------------"
|
||||
)
|
||||
em.print(f"// {text}")
|
||||
em.print(
|
||||
"// -----------------------------------------------------------------------------"
|
||||
)
|
||||
em.print("")
|
||||
|
||||
def ordinary_binary_op(self, kernel_sig, ods_name, op_name):
|
||||
""""Binary"-ops. These ops typically have:
|
||||
- '.Tensor' variant where the second arg is a Tensor
|
||||
- '.Scalar' variant where the second arg is a Scalar
|
||||
- An '.out' variant which contains a final "outref" argument
|
||||
Actual suffixes vary and the rules are set up to match anything
|
||||
that comes in in one of these forms.
|
||||
In addition, most of these have in-place versions, possibly of the
|
||||
'.Tensor' and '.Scalar' variants above.
|
||||
Note that many of these have more than two arguments (i.e. alpha/beta
|
||||
scalars trailing and such), but they are not relevant for
|
||||
matching/conversions (if they are more than pass-through, then have a
|
||||
dedicated rule).
|
||||
We generally canonicalize all of these forms to a single recognized op
|
||||
by:
|
||||
- Enabling the flag to promotTrailingOutTensor
|
||||
- Enabling the flag matchInplaceVariant
|
||||
- Setting all arguments and returns to kImmutableTensor
|
||||
- Enabling kPromoteScalarToTensor on the second argument.
|
||||
"""
|
||||
reg_record = self._get_reg_record(kernel_sig)
|
||||
ods_ins, arg_type_flags = self._map_sigtypes(
|
||||
reg_record["arguments"],
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
"Tensor:1": "AnyTorchImmutableTensor",
|
||||
"Scalar:1": "AnyTorchImmutableTensor",
|
||||
"Scalar": "AnyTorchScalarType",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
":1": ["kImmutableTensor", "kPromoteScalar"],
|
||||
})
|
||||
ods_outs, return_type_flags = self._map_sigtypes(
|
||||
reg_record["returns"],
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
self.ods_emitter.emit_opdef(ods_name,
|
||||
op_name,
|
||||
reg_record,
|
||||
ods_ins=ods_ins,
|
||||
ods_outs=ods_outs,
|
||||
traits=["NoSideEffect"])
|
||||
self.impl_emitter.emit_kernel_methods(ods_name,
|
||||
reg_record,
|
||||
arg_type_flags=arg_type_flags,
|
||||
return_type_flags=return_type_flags,
|
||||
promote_trailing_out_tensor=True)
|
||||
|
||||
def ordinary_unary_op(self, kernel_sig, ods_name, op_name):
|
||||
"""Unary ops.
|
||||
|
||||
These take and return a tensor and typically have an out and inplace
|
||||
variant (they may not but we generate patterns to match anyway).
|
||||
"""
|
||||
reg_record = self._get_reg_record(kernel_sig)
|
||||
ods_ins, arg_type_flags = self._map_sigtypes(
|
||||
reg_record["arguments"],
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
ods_outs, return_type_flags = self._map_sigtypes(
|
||||
reg_record["returns"],
|
||||
type_transforms={
|
||||
"Tensor:0": "AnyTorchImmutableTensor",
|
||||
},
|
||||
flag_transforms={
|
||||
":0": ["kImmutableTensor"],
|
||||
})
|
||||
self.ods_emitter.emit_opdef(ods_name,
|
||||
op_name,
|
||||
reg_record,
|
||||
ods_ins=ods_ins,
|
||||
ods_outs=ods_outs,
|
||||
traits=["NoSideEffect"])
|
||||
self.impl_emitter.emit_kernel_methods(ods_name,
|
||||
reg_record,
|
||||
arg_type_flags=arg_type_flags,
|
||||
return_type_flags=return_type_flags,
|
||||
promote_trailing_out_tensor=True)
|
||||
|
||||
def _get_reg_record(self, kernel_sig):
|
||||
"""Gets the op-dict for a given registered op name.
|
||||
|
||||
Args:
|
||||
kernel_sig: Signature of the kernel to find.
|
||||
Returns:
|
||||
Dict of the registration record.
|
||||
"""
|
||||
record = self.reg_ops.get(kernel_sig)
|
||||
if record:
|
||||
return record
|
||||
|
||||
# Try to give a nice "did you mean" style error, since this happens
|
||||
# so much.
|
||||
kernel_name, *rest = kernel_sig.split("(", maxsplit=1)
|
||||
dym_list = [k for k in self.reg_ops.keys() if k.startswith(kernel_name)]
|
||||
dym_message = '\n '.join(dym_list)
|
||||
raise ValueError(f"Could not find registry op matching '{kernel_sig}'. "
|
||||
f"Possible matches:\n {dym_message}")
|
||||
|
||||
def _map_sigtypes(self, siglist: List[Dict], type_transforms: Dict[str, str],
|
||||
flag_transforms: Dict[str, List[str]]) -> List[Tuple[str]]:
|
||||
"""Maps a list of signature entries to ods dags and flag lists.
|
||||
|
||||
The torch signature list contains dicts that minimally have keys 'name' and
|
||||
'type', representing torch names and types. Returns a corresponding
|
||||
list of 2-tuples of (ods_name, ods_type).
|
||||
|
||||
The provided type_transforms is a dict of type substitutions, one of which
|
||||
must match for each entry in the list. The keys can be either a verbatim
|
||||
torch type (i.e. "Tensor") an index in the list (i.e. ":0") or a type and
|
||||
index (i.e. "Tensor:0").
|
||||
|
||||
Similarly, flag_transforms matches its keys in the same way and maps to
|
||||
a list of KernelValueConversion constants that make up arg/return specific
|
||||
conversion flags.
|
||||
|
||||
Returns:
|
||||
- An ods dag list of (ods_name, ods_type) tuples
|
||||
- List of (torch_type, [conversion_flag]) for specifying conversions.
|
||||
"""
|
||||
# Generate to ods dag list.
|
||||
ods_dag_list = []
|
||||
for i, sigitem in enumerate(siglist):
|
||||
torch_name = sigitem["name"]
|
||||
torch_type = sigitem["type"]
|
||||
# Look up the type transform.
|
||||
ods_type = (type_transforms.get(f"{torch_type}:{i}") or
|
||||
type_transforms.get(f":{i}") or
|
||||
type_transforms.get(torch_type))
|
||||
if not ods_type:
|
||||
raise ValueError(f"Signature item {i}, type {torch_type} did not match "
|
||||
f"a type transform {type_transforms}")
|
||||
ods_dag_list.append((torch_name, ods_type))
|
||||
|
||||
# Generate the type conversion flags.
|
||||
type_flag_list = []
|
||||
for i, sigitem in enumerate(siglist):
|
||||
torch_type = sigitem["type"]
|
||||
# Look up the type transform.
|
||||
flags = (flag_transforms.get(f"{torch_type}:{i}") or
|
||||
flag_transforms.get(f":{i}") or flag_transforms.get(torch_type))
|
||||
if not flags:
|
||||
flags = []
|
||||
type_flag_list.append((torch_type, flags))
|
||||
return ods_dag_list, type_flag_list
|
||||
|
||||
|
||||
class EmitterBase:
|
||||
_INDENT = " "
|
||||
|
||||
def __init__(self, out: TextIO):
|
||||
super().__init__()
|
||||
self.out = out
|
||||
self.indent_level = 0
|
||||
|
||||
@contextmanager
|
||||
def indent(self, level=1):
|
||||
self.indent_level += level
|
||||
yield
|
||||
self.indent_level -= level
|
||||
assert self.indent_level >= 0, "Unbalanced indentation"
|
||||
|
||||
def print(self, s, *, end="\n", indent=True):
|
||||
if indent and self.indent_level:
|
||||
self.out.write(self._INDENT * self.indent_level)
|
||||
self.out.write(s)
|
||||
self.out.write(end)
|
||||
|
||||
def quote(self, s: str):
|
||||
s = s.replace(r'"', r'\\"')
|
||||
return f'"{s}"'
|
||||
|
||||
def quote_multiline_docstring(self, s: str, indent_level: int = 0):
|
||||
# TODO: Possibly find a python module to markdown the docstring for better
|
||||
# document generation.
|
||||
# Unlikely to contain the delimitter and since just a docstring, be safe.
|
||||
s = s.replace("}]", "")
|
||||
# Strip each line.
|
||||
s = "\n".join([l.rstrip() for l in s.splitlines()])
|
||||
indent = self._INDENT * indent_level
|
||||
s = textwrap.indent(s, indent + self._INDENT)
|
||||
return "[{\n" + s + "\n" + indent + "}]"
|
||||
|
||||
|
||||
class OdsEmitter(EmitterBase):
|
||||
ods_def_prefix = "aten_"
|
||||
ods_def_suffix = ""
|
||||
ods_template_name = "aten_Op"
|
||||
|
||||
def emit_opdef(self,
|
||||
ods_def_name: str,
|
||||
mnemonic: str,
|
||||
reg_record: Dict,
|
||||
ods_ins: List[Tuple[str, str]],
|
||||
ods_outs: List[Tuple[str, str]],
|
||||
traits: Sequence[str] = (),
|
||||
summary: Optional[str] = None):
|
||||
# Def first-line.
|
||||
full_traits = list(traits)
|
||||
full_traits.append(
|
||||
"DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>")
|
||||
full_traits.append("DeclareOpInterfaceMethods<TorchKernelOpInterface>")
|
||||
identifier = f"{self.ods_def_prefix}{ods_def_name}{self.ods_def_suffix}"
|
||||
self.print(f"def {identifier}: {self.ods_template_name}"
|
||||
f"<{self.quote(mnemonic)}, ["
|
||||
f"{', '.join(full_traits)}"
|
||||
f"]> {{")
|
||||
with self.indent():
|
||||
# Summary.
|
||||
if not summary:
|
||||
summary = f"Recognized op for kernel {reg_record['name'][0]}"
|
||||
self.print(f"let summary = {self.quote(summary)};")
|
||||
# Arguments.
|
||||
self.print("let arguments = (ins")
|
||||
with self.indent():
|
||||
self._emit_dag_list_body(ods_ins)
|
||||
self.print(");")
|
||||
|
||||
# Results.
|
||||
self.print("let results = (outs")
|
||||
with self.indent():
|
||||
self._emit_dag_list_body(ods_outs)
|
||||
self.print(");")
|
||||
|
||||
# Def last-line.
|
||||
self.print("}\n")
|
||||
|
||||
def _emit_dag_list_body(self, items):
|
||||
"""Emits a dag of (name, type) pairs."""
|
||||
for index, (ods_name, ods_type) in enumerate(items):
|
||||
is_last = index == len(items) - 1
|
||||
ods_namespec = f":${ods_name}" if ods_name else ""
|
||||
self.print(f"{ods_type}{ods_namespec}", end="\n" if is_last else ",\n")
|
||||
|
||||
|
||||
class CCImplEmitter(EmitterBase):
|
||||
|
||||
def emit_kernel_methods(self,
|
||||
ods_def_name: str,
|
||||
reg_record,
|
||||
arg_type_flags: List[Tuple[str, List[Tuple[str]]]],
|
||||
return_type_flags: List[Tuple[str, List[Tuple[str]]]],
|
||||
promote_trailing_out_tensor=False):
|
||||
# getTorchKernelMetadata() method.
|
||||
self.print(
|
||||
f"Torch::KernelMetadata {ods_def_name}::getTorchKernelMetadata() {{")
|
||||
with self.indent():
|
||||
self.print("return getTorchBuildKernelMetadata();")
|
||||
self.print("}\n")
|
||||
|
||||
# getTorchBuildKernelMetadata() method.
|
||||
kernel_name = reg_record["name"][0]
|
||||
self.print(
|
||||
f"const Torch::BuildKernelMetadata &{ods_def_name}::getTorchBuildKernelMetadata() {{"
|
||||
)
|
||||
with self.indent():
|
||||
self.print("using KVC = Torch::KernelValueConversion::BitMask;")
|
||||
self.print("static Torch::BuildKernelMetadata metadata = ([]() {")
|
||||
with self.indent():
|
||||
self.print("Torch::BuildKernelMetadata m;")
|
||||
self.print(f"m.kernelName = {self.quote(kernel_name)};")
|
||||
if promote_trailing_out_tensor:
|
||||
self.print("m.promoteTrailingOutTensor = true;")
|
||||
# Arg types/flags.
|
||||
arg_types = self._format_cpp_str_initlist(
|
||||
[t[0] for t in arg_type_flags])
|
||||
self.print(f"m.addArgTypes({arg_types});")
|
||||
arg_flags = self._format_cpp_kvc_initlist(
|
||||
[t[1] for t in arg_type_flags])
|
||||
self.print(f"m.addArgConversions({arg_flags});")
|
||||
# Returns types/flags.
|
||||
ret_types = self._format_cpp_str_initlist(
|
||||
[t[0] for t in return_type_flags])
|
||||
self.print(f"m.addReturnTypes({ret_types});")
|
||||
ret_flags = self._format_cpp_kvc_initlist(
|
||||
[t[1] for t in return_type_flags])
|
||||
self.print(f"m.addReturnConversions({ret_flags});")
|
||||
self.print("return m;")
|
||||
self.print("})();")
|
||||
self.print("return metadata;")
|
||||
self.print("}")
|
||||
|
||||
def _format_cpp_str_initlist(self, strings):
|
||||
quoted = [self.quote(s) for s in strings]
|
||||
joined = ", ".join(quoted)
|
||||
return "{" + joined + "}"
|
||||
|
||||
def _format_cpp_kvc_initlist(self, const_name_lists):
|
||||
|
||||
def or_flags(flag_names):
|
||||
if not flag_names:
|
||||
return "KVC::kNone"
|
||||
return "|".join([f"KVC::{n}" for n in flag_names])
|
||||
|
||||
or_d = [or_flags(l) for l in const_name_lists]
|
||||
joined = ", ".join(or_d)
|
||||
return "{" + joined + "}"
|
||||
|
||||
|
||||
def snakecase_to_camelcase(ident: str):
|
||||
return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
|
||||
|
||||
|
||||
def _load_ops_as_dict():
|
||||
# Returns a list of dicts, each with a name that is a tuple of the form:
|
||||
# (kernel_signature, variant)
|
||||
# The kernel signature is a reified form of the argument type signature
|
||||
# used throughout PyTorch:
|
||||
# namespace::kernel_name(type1,type2)
|
||||
def reify_signature(reg_op):
|
||||
kernel_name, unused_variant = reg_op["name"]
|
||||
arg_types = [arg["type"] for arg in reg_op["arguments"]]
|
||||
return f"{kernel_name}({','.join(arg_types)})"
|
||||
|
||||
reg_ops_list = get_registered_ops()
|
||||
return {reify_signature(reg_op): reg_op for reg_op in reg_ops_list}
|
||||
|
||||
|
||||
def _get_main_module_name():
|
||||
return sys.modules["__main__"].__loader__.name
|
||||
|
||||
|
||||
ODS_BANNER = "\n".join([
|
||||
"//===-------------------------------------------------------*- tablegen -*-===//",
|
||||
"//",
|
||||
"// This file is licensed under the Apache License v2.0 with LLVM Exceptions.",
|
||||
"// See https://llvm.org/LICENSE.txt for license information.",
|
||||
"// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception",
|
||||
"//",
|
||||
"// Operation summaries and descriptions were systematically derived from public",
|
||||
"// API docstrings and are licensed accordingly:",
|
||||
"// https://github.com/pytorch/pytorch/blob/master/LICENSE",
|
||||
"//===----------------------------------------------------------------------===//",
|
||||
"// This file is automatically generated. Please do not edit.",
|
||||
"// Generated via:",
|
||||
f"// python -m {_get_main_module_name()}",
|
||||
"//===----------------------------------------------------------------------===//",
|
||||
"",
|
||||
"",
|
||||
])
|
||||
|
||||
CC_IMPL_BANNER = "\n".join([
|
||||
"//===-------------------------------------------------------------*- cc -*-===//",
|
||||
"//",
|
||||
"// This file is licensed under the Apache License v2.0 with LLVM Exceptions.",
|
||||
"// See https://llvm.org/LICENSE.txt for license information.",
|
||||
"// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception", "//",
|
||||
"// Operation summaries and descriptions were systematically derived from public",
|
||||
"// API docstrings and are licensed accordingly:",
|
||||
"// https://github.com/pytorch/pytorch/blob/master/LICENSE",
|
||||
"//===----------------------------------------------------------------------===//",
|
||||
"// This file is automatically generated. Please do not edit.",
|
||||
"// Generated via:", f"// python -m {_get_main_module_name()}",
|
||||
"//===----------------------------------------------------------------------===//",
|
||||
"", "", "// clang-format off"
|
||||
])
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _create_argparse()
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -25,44 +25,7 @@ class aten_Op<string mnemonic, list<OpTrait> traits = [StatisticsOpInterface]> :
|
|||
|
||||
// Most ops are automatically generated from pytorch specs.
|
||||
include "npcomp/Dialect/ATen/IR/GeneratedATenOps.td"
|
||||
|
||||
def aten_AddOp: aten_Op<"add", [
|
||||
NoSideEffect, TorchBuildableKernelOpInterface, TorchKernelOpInterface,
|
||||
StatisticsOpInterface]> {
|
||||
let arguments = (
|
||||
ins AnyTorchImmutableTensor:$self,
|
||||
AnyTorchImmutableTensor:$other,
|
||||
AnyTorchScalarType:$alpha
|
||||
);
|
||||
let results = (outs AnyTorchImmutableTensor);
|
||||
let summary = "aten add operator";
|
||||
let description = [{
|
||||
AddOp
|
||||
aten add operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
|
||||
Torch::KernelMetadata getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
static const Torch::BuildKernelMetadata &getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::add";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor", "Scalar"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
}];
|
||||
}
|
||||
include "npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td"
|
||||
|
||||
def aten_BatchNormOp: aten_Op<"batch_norm", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor:$output, AnyTensor:$save_mean, AnyTensor:$save_invstd)> {
|
||||
|
|
|
@ -0,0 +1,781 @@
|
|||
//===-------------------------------------------------------------*- cc -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
// Operation summaries and descriptions were systematically derived from public
|
||||
// API docstrings and are licensed accordingly:
|
||||
// https://github.com/pytorch/pytorch/blob/master/LICENSE
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file is automatically generated. Please do not edit.
|
||||
// Generated via:
|
||||
// python -m torch_mlir_utils.codegen.torch_signature_ods_gen
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
// clang-format off
|
||||
// -----------------------------------------------------------------------------
|
||||
// Binary arithmetic ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
Torch::KernelMetadata AddOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &AddOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::add";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor", "Scalar"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar, KVC::kNone});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata Atan2Op::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &Atan2Op::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::atan2";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata DivOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &DivOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::div";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata FloorDivideOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &FloorDivideOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::floor_divide";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata MulOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &MulOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::mul";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata RemainderOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &RemainderOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::remainder";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata TrueDivideOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::true_divide";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
// -----------------------------------------------------------------------------
|
||||
// Unary arithmetic ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
Torch::KernelMetadata AbsOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &AbsOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::abs";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata AcosOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &AcosOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::acos";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata AngleOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &AngleOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::angle";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata AsinOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &AsinOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::asin";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata AtanOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &AtanOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::atan";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata CeilOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &CeilOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::ceil";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata ConjOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &ConjOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::conj";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata CosOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &CosOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::cos";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata CoshOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &CoshOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::cosh";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata DigammaOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &DigammaOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::digamma";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata ErfOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &ErfOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::erf";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata ErfcOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &ErfcOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::erfc";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata ErfinvOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &ErfinvOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::erfinv";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata ExpOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &ExpOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::exp";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata Expm1Op::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &Expm1Op::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::expm1";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata FloorOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &FloorOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::floor";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata FracOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &FracOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::frac";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata LgammaOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &LgammaOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::lgamma";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata LogOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &LogOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::log";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata Log10Op::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &Log10Op::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::log10";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata Log1pOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &Log1pOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::log1p";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata Log2Op::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &Log2Op::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::log2";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata NegOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &NegOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::neg";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata ReluOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &ReluOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::relu";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata ReciprocalOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &ReciprocalOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::reciprocal";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata RoundOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &RoundOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::round";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata RsqrtOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &RsqrtOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::rsqrt";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata SigmoidOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &SigmoidOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::sigmoid";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata SignOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &SignOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::sign";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata SinOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &SinOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::sin";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata SinhOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &SinhOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::sinh";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata SqrtOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &SqrtOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::sqrt";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata TanOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &TanOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::tan";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata TanhOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &TanhOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::tanh";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
Torch::KernelMetadata TruncOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::trunc";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,654 @@
|
|||
// This file is (mostly) automatically generated. Please do not edit.
|
||||
//===- ATenOps.td ------------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS
|
||||
#define NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS
|
||||
|
||||
def aten_AddUnderOp: aten_Op<"add_", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyTensor:$other,
|
||||
AnyScalar:$alpha
|
||||
);
|
||||
let summary = "aten add_ operator";
|
||||
let description = [{
|
||||
AddUnderOp
|
||||
aten add_ operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_AsStridedOp: aten_Op<"as_strided", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyType:$size,
|
||||
AnyType:$stride
|
||||
);
|
||||
let summary = "aten as_strided operator";
|
||||
let description = [{
|
||||
AsStridedOp
|
||||
aten as_strided operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_ConvolutionOverrideableOp: aten_Op<"convolution_overrideable", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$input,
|
||||
AnyTensor:$weight,
|
||||
AnyTensor:$bias,
|
||||
AnyType:$stride,
|
||||
AnyType:$padding,
|
||||
AnyType:$dilation,
|
||||
AnyScalar:$transposed,
|
||||
AnyType:$output_padding,
|
||||
AnyScalar:$groups
|
||||
);
|
||||
let summary = "aten convolution_overrideable operator";
|
||||
let description = [{
|
||||
ConvolutionOverrideableOp
|
||||
aten convolution_overrideable operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_ConvolutionBackwardOverrideableOp: aten_Op<"convolution_backward_overrideable", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor, AnyTensor, AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$input,
|
||||
AnyTensor:$weight,
|
||||
AnyType:$stride,
|
||||
AnyType:$padding,
|
||||
AnyType:$dilation,
|
||||
AnyScalar:$transposed,
|
||||
AnyType:$output_padding,
|
||||
AnyScalar:$groups
|
||||
);
|
||||
let summary = "aten convolution_backward_overrideable operator";
|
||||
let description = [{
|
||||
ConvolutionBackwardOverrideableOp
|
||||
aten convolution_backward_overrideable operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_DivUnderOp: aten_Op<"div_", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyTensor:$other
|
||||
);
|
||||
let summary = "aten div_ operator";
|
||||
let description = [{
|
||||
DivUnderOp
|
||||
aten div_ operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_ExpandOp: aten_Op<"expand", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyType:$size,
|
||||
AnyScalar:$implicit
|
||||
);
|
||||
let summary = "aten expand operator";
|
||||
let description = [{
|
||||
ExpandOp
|
||||
aten expand operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_LogSoftmaxOp: aten_Op<"_log_softmax", [NoSideEffect]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyScalar:$dim,
|
||||
AnyScalar:$half_to_float
|
||||
);
|
||||
let summary = "aten _log_softmax operator";
|
||||
let description = [{
|
||||
LogSoftmaxOp
|
||||
aten _log_softmax operator
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_LogSoftmaxBackwardDataOp: aten_Op<"_log_softmax_backward_data", [NoSideEffect]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$output,
|
||||
AnyScalar:$dim,
|
||||
AnyTensor:$self
|
||||
);
|
||||
let summary = "aten _log_softmax_backward_data operator";
|
||||
let description = [{
|
||||
LogSoftmaxBackwardDataOp
|
||||
aten _log_softmax_backward_data operator
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self
|
||||
);
|
||||
let summary = "aten mean operator";
|
||||
let description = [{
|
||||
MeanOp
|
||||
aten mean operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_MmOp: aten_Op<"mm", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyTensor:$mat2
|
||||
);
|
||||
let summary = "aten mm operator";
|
||||
let description = [{
|
||||
MmOp
|
||||
aten mm operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_MulUnderOp: aten_Op<"mul_", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyTensor:$other
|
||||
);
|
||||
let summary = "aten mul_ operator";
|
||||
let description = [{
|
||||
MulUnderOp
|
||||
aten mul_ operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_NativeBatchNormOp: aten_Op<"native_batch_norm", [NoSideEffect, StatisticsOpInterface]>,
|
||||
// FIXME: not quite automatically generated names
|
||||
Results<(outs AnyTensor:$output, AnyTensor:$save_mean, AnyTensor:$save_invstd)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$input,
|
||||
AnyTensor:$weight,
|
||||
AnyTensor:$bias,
|
||||
AnyTensor:$running_mean,
|
||||
AnyTensor:$running_var,
|
||||
AnyScalar:$training,
|
||||
AnyScalar:$momentum,
|
||||
AnyScalar:$eps
|
||||
);
|
||||
let summary = "aten native_batch_norm operator";
|
||||
let description = [{
|
||||
NativeBatchNormOp
|
||||
aten native_batch_norm operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_NativeBatchNormBackwardOp: aten_Op<"native_batch_norm_backward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
// FIXME: not quite automatically generated
|
||||
Results<(outs AnyTensor:$dx, AnyTensor:$dm, AnyTensor:$dv)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_out,
|
||||
AnyTensor:$input,
|
||||
AnyTensor:$weight,
|
||||
AnyTensor:$running_mean,
|
||||
AnyTensor:$running_var,
|
||||
AnyTensor:$save_mean,
|
||||
AnyTensor:$save_invstd,
|
||||
AnyScalar:$train,
|
||||
AnyScalar:$eps
|
||||
);
|
||||
let summary = "aten native_batch_norm_backward operator";
|
||||
let description = [{
|
||||
NativeBatchNormBackwardOp
|
||||
aten native_batch_norm_backward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_ReluUnderOp: aten_Op<"relu_", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self
|
||||
);
|
||||
let summary = "aten relu_ operator";
|
||||
let description = [{
|
||||
ReluUnderOp
|
||||
aten relu_ operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_SizeOp: aten_Op<"size", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyScalar:$dim
|
||||
);
|
||||
let summary = "aten size operator";
|
||||
let description = [{
|
||||
SizeOp
|
||||
aten size operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_SqueezeOp: aten_Op<"squeeze", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyScalar:$dim
|
||||
);
|
||||
let summary = "aten squeeze operator";
|
||||
let description = [{
|
||||
SqueezeOp
|
||||
aten squeeze operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_SumOp: aten_Op<"sum", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyType:$dim,
|
||||
AnyScalar:$keepdim
|
||||
);
|
||||
let summary = "aten sum operator";
|
||||
let description = [{
|
||||
SumOp
|
||||
aten sum operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_TOp: aten_Op<"t", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self
|
||||
);
|
||||
let summary = "aten t operator";
|
||||
let description = [{
|
||||
TOp
|
||||
aten t operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_ThresholdBackwardOp: aten_Op<"threshold_backward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$self,
|
||||
AnyScalar:$threshold
|
||||
);
|
||||
let summary = "aten threshold_backward operator";
|
||||
let description = [{
|
||||
ThresholdBackwardOp
|
||||
aten threshold_backward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_UnsqueezeOp: aten_Op<"unsqueeze", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyScalar:$dim
|
||||
);
|
||||
let summary = "aten unsqueeze operator";
|
||||
let description = [{
|
||||
UnsqueezeOp
|
||||
aten unsqueeze operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_SubOp: aten_Op<"sub", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyTensor:$other,
|
||||
AnyScalar:$alpha
|
||||
);
|
||||
let summary = "aten sub operator";
|
||||
let description = [{
|
||||
SubOp
|
||||
aten sub operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_SubUnderOp: aten_Op<"sub_", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyTensor:$other,
|
||||
AnyScalar:$alpha
|
||||
);
|
||||
let summary = "aten sub_ operator";
|
||||
let description = [{
|
||||
SubUnderOp
|
||||
aten sub_ operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_AddmmOp: aten_Op<"addmm", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyTensor:$mat1,
|
||||
AnyTensor:$mat2,
|
||||
AnyScalar:$beta,
|
||||
AnyScalar:$alpha
|
||||
);
|
||||
let summary = "aten addmm operator";
|
||||
let description = [{
|
||||
AddmmOp
|
||||
aten addmm operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_ViewOp: aten_Op<"view", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyType:$size
|
||||
);
|
||||
let summary = "aten view operator";
|
||||
let description = [{
|
||||
ViewOp
|
||||
aten view operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_GatherOp: aten_Op<"gather", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyScalar:$dim,
|
||||
AnyTensor:$index,
|
||||
AnyScalar:$sparse_grad
|
||||
);
|
||||
let summary = "aten gather operator";
|
||||
let description = [{
|
||||
GatherOp
|
||||
aten gather operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor, AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyTensor:$target,
|
||||
AnyTensor:$weight,
|
||||
AnyScalar:$reduction,
|
||||
AnyScalar:$ignore_index
|
||||
);
|
||||
let summary = "aten nll_loss_forward operator";
|
||||
let description = [{
|
||||
NllLossForwardOp
|
||||
aten nll_loss_forward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$self,
|
||||
AnyTensor:$target,
|
||||
AnyTensor:$weight,
|
||||
AnyScalar:$reduction,
|
||||
AnyScalar:$ignore_index,
|
||||
AnyTensor:$total_weight
|
||||
);
|
||||
let summary = "aten nll_loss_backward operator";
|
||||
let description = [{
|
||||
NllLossBackwardOp
|
||||
aten nll_loss_backward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor, AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyTensor:$target,
|
||||
AnyTensor:$weight,
|
||||
AnyScalar:$reduction,
|
||||
AnyScalar:$ignore_index
|
||||
);
|
||||
let summary = "aten nll_loss2d_forward operator";
|
||||
let description = [{
|
||||
NllLoss2dForwardOp
|
||||
aten nll_loss2d_forward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$self,
|
||||
AnyTensor:$target,
|
||||
AnyTensor:$weight,
|
||||
AnyScalar:$reduction,
|
||||
AnyScalar:$ignore_index,
|
||||
AnyTensor:$total_weight
|
||||
);
|
||||
let summary = "aten nll_loss2d_backward operator";
|
||||
let description = [{
|
||||
NllLoss2dBackwardOp
|
||||
aten nll_loss2d_backward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_HardtanhOp: aten_Op<"hardtanh", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyScalar:$min_val,
|
||||
AnyScalar:$max_val
|
||||
);
|
||||
let summary = "aten hardtanh operator";
|
||||
let description = [{
|
||||
HardtanhOp
|
||||
aten hardtanh operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_HardtanhBackwardOp: aten_Op<"hardtanh_backward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$self,
|
||||
AnyScalar:$min_val,
|
||||
AnyScalar:$max_val
|
||||
);
|
||||
let summary = "aten hardtanh_backward operator";
|
||||
let description = [{
|
||||
HardtanhBackwardOp
|
||||
aten hardtanh_backward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_HardtanhUnderOp: aten_Op<"hardtanh_", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyScalar:$min_val,
|
||||
AnyScalar:$max_val
|
||||
);
|
||||
let summary = "aten hardtanh_ operator";
|
||||
let description = [{
|
||||
HardtanhUnderOp
|
||||
aten hardtanh_ operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_AdaptiveAvgPool2dOp: aten_Op<"_adaptive_avg_pool2d", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyType:$output_size
|
||||
);
|
||||
let summary = "aten _adaptive_avg_pool2d operator";
|
||||
let description = [{
|
||||
AdaptiveAvgPool2dOp
|
||||
aten _adaptive_avg_pool2d operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_AdaptiveAvgPool2dBackwardOp: aten_Op<"_adaptive_avg_pool2d_backward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$self
|
||||
);
|
||||
let summary = "aten _adaptive_avg_pool2d_backward operator";
|
||||
let description = [{
|
||||
AdaptiveAvgPool2dBackwardOp
|
||||
aten _adaptive_avg_pool2d_backward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_MaxPool2dWithIndicesOp: aten_Op<"max_pool2d_with_indices", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor, AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$self,
|
||||
AnyType:$kernel_size,
|
||||
AnyType:$stride,
|
||||
AnyType:$padding,
|
||||
AnyType:$dilation,
|
||||
AnyScalar:$ceil_mode
|
||||
);
|
||||
let summary = "aten max_pool2d_with_indices operator";
|
||||
let description = [{
|
||||
MaxPool2dWithIndicesOp
|
||||
aten max_pool2d_with_indices operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
def aten_MaxPool2dWithIndicesBackwardOp: aten_Op<"max_pool2d_with_indices_backward", [NoSideEffect, StatisticsOpInterface]>,
|
||||
Results<(outs AnyTensor)> {
|
||||
let arguments = (
|
||||
ins AnyTensor:$grad_output,
|
||||
AnyTensor:$self,
|
||||
AnyType:$kernel_size,
|
||||
AnyType:$stride,
|
||||
AnyType:$padding,
|
||||
AnyType:$dilation,
|
||||
AnyScalar:$ceil_mode,
|
||||
AnyTensor:$indices
|
||||
);
|
||||
let summary = "aten max_pool2d_with_indices_backward operator";
|
||||
let description = [{
|
||||
MaxPool2dWithIndicesBackwardOp
|
||||
aten max_pool2d_with_indices_backward operator
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
std::map<std::string, uint64_t> getStatistics();
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS
|
|
@ -28,7 +28,11 @@ enum BitMask {
|
|||
// Coerce/require a mutable tensor value.
|
||||
kMutableTensor = 4,
|
||||
|
||||
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kMutableTensor)
|
||||
// If the source is a Scalar and the target is a Tensor, promotes
|
||||
// to a 0d tensor.
|
||||
kPromoteScalar = 8,
|
||||
|
||||
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kPromoteScalar)
|
||||
};
|
||||
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
|
||||
} // namespace KernelValueConversion
|
||||
|
|
|
@ -108,3 +108,5 @@ void ATenDialect::initialize() {
|
|||
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
||||
|
||||
#include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc"
|
||||
|
||||
#include "npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc"
|
||||
|
|
|
@ -55,33 +55,6 @@ std::map<std::string, uint64_t> AdaptiveAvgPool2dBackwardOp::getStatistics() {
|
|||
return toReturn;
|
||||
}
|
||||
|
||||
// add
|
||||
std::map<std::string, uint64_t> AddOp::getStatistics() {
|
||||
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
|
||||
TensorType resultTy = getResult().getType().cast<TensorType>();
|
||||
TensorType aType = getOperand(0).getType().cast<TensorType>();
|
||||
Type bType = getOperand(1).getType();
|
||||
|
||||
uint64_t ofm_volume = getTensorVolume(resultTy);
|
||||
|
||||
toReturn["ops:+"] = ofm_volume;
|
||||
toReturn["result:0:activation_out"] = ofm_volume;
|
||||
|
||||
// Find the size of the A and B operands
|
||||
uint64_t a_volume = getTensorVolume(aType);
|
||||
uint64_t b_volume = getTensorVolume(bType);
|
||||
|
||||
toReturn["operand:0:activation_in"] = a_volume;
|
||||
toReturn["operand:1:activation_in"] = b_volume;
|
||||
|
||||
toReturn["reads"] = a_volume + b_volume;
|
||||
toReturn["writes"] = ofm_volume;
|
||||
|
||||
return toReturn;
|
||||
}
|
||||
|
||||
// add_
|
||||
std::map<std::string, uint64_t> AddUnderOp::getStatistics() {
|
||||
|
||||
|
@ -231,33 +204,6 @@ ConvolutionBackwardOverrideableOp::getStatistics() {
|
|||
return getConv2dBackwardStatistics(*this, groups);
|
||||
}
|
||||
|
||||
// div
|
||||
std::map<std::string, uint64_t> DivOp::getStatistics() {
|
||||
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
|
||||
TensorType resultTy = getResult().getType().cast<TensorType>();
|
||||
TensorType aType = getOperand(0).getType().cast<TensorType>();
|
||||
Type bType = getOperand(1).getType();
|
||||
|
||||
uint64_t ofm_volume = getTensorVolume(resultTy);
|
||||
toReturn["ops:/"] = ofm_volume;
|
||||
|
||||
toReturn["result:0:activation_out"] = ofm_volume;
|
||||
|
||||
// Find the size of the A and B operands
|
||||
uint64_t a_volume = getTensorVolume(aType);
|
||||
uint64_t b_volume = getTensorVolume(bType);
|
||||
|
||||
toReturn["operand:0:activation_in"] = a_volume;
|
||||
toReturn["operand:1:activation_in"] = b_volume;
|
||||
|
||||
toReturn["reads"] = a_volume + b_volume;
|
||||
toReturn["writes"] = ofm_volume;
|
||||
|
||||
return toReturn;
|
||||
}
|
||||
|
||||
// div_
|
||||
std::map<std::string, uint64_t> DivUnderOp::getStatistics() {
|
||||
|
||||
|
@ -467,32 +413,6 @@ std::map<std::string, uint64_t> MmOp::getStatistics() {
|
|||
return getMMOpStatistics(*this);
|
||||
}
|
||||
|
||||
// mul
|
||||
std::map<std::string, uint64_t> MulOp::getStatistics() {
|
||||
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
|
||||
TensorType resultTy = getResult().getType().cast<TensorType>();
|
||||
TensorType aType = getOperand(0).getType().cast<TensorType>();
|
||||
Type bType = getOperand(1).getType();
|
||||
|
||||
uint64_t ofm_volume = getTensorVolume(resultTy);
|
||||
toReturn["ops:*"] = ofm_volume;
|
||||
toReturn["result:0:activation_out"] = ofm_volume;
|
||||
|
||||
// Find the size of the A and B operands
|
||||
uint64_t a_volume = getTensorVolume(aType);
|
||||
uint64_t b_volume = getTensorVolume(bType);
|
||||
|
||||
toReturn["operand:0:activation_in"] = a_volume;
|
||||
toReturn["operand:1:activation_in"] = b_volume;
|
||||
|
||||
toReturn["reads"] = a_volume + b_volume;
|
||||
toReturn["writes"] = ofm_volume;
|
||||
|
||||
return toReturn;
|
||||
}
|
||||
|
||||
// mul_
|
||||
std::map<std::string, uint64_t> MulUnderOp::getStatistics() {
|
||||
|
||||
|
@ -668,23 +588,6 @@ std::map<std::string, uint64_t> NllLoss2dBackwardOp::getStatistics() {
|
|||
return toReturn;
|
||||
}
|
||||
|
||||
// neg op
|
||||
std::map<std::string, uint64_t> NegOp::getStatistics() {
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
auto insize = getTensorVolume(getOperand().getType());
|
||||
auto outsize = getTensorVolume(getResult().getType());
|
||||
toReturn["reads"] = toReturn["operand:0:activation_in"] = insize;
|
||||
toReturn["writes"] = toReturn["result:0:activation_out"] = outsize;
|
||||
return toReturn;
|
||||
}
|
||||
|
||||
// relu
|
||||
// std::map<std::string, uint64_t> ReLUOp::getStatistics() {
|
||||
// return getReLUOpStatistics(*this);
|
||||
// }
|
||||
std::map<std::string, uint64_t> ReluOp::getStatistics() {
|
||||
return getReLUOpStatistics(*this);
|
||||
}
|
||||
// std::map<std::string, uint64_t> ReLUUnderOp::getStatistics() {
|
||||
// return getReLUOpStatistics(*this);
|
||||
// }
|
||||
|
|
|
@ -47,6 +47,7 @@ convertTorchArgType(StringRef sourceTorchType, StringRef targetTorchType,
|
|||
|
||||
// Immutable tensor conversion.
|
||||
if (flag & KVC::kImmutableTensor) {
|
||||
// TODO: Support the kPromoteScalar flag.
|
||||
if (sourceTorchType != "Tensor" || targetTorchType != "Tensor")
|
||||
return None;
|
||||
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
|
||||
// CHECK-LABEL: "L0-add-0": {
|
||||
// CHECK-NEXT: "activation_in": 12,
|
||||
// CHECK-NEXT: "activation_out": 6,
|
||||
// CHECK-NEXT: "ops:+": 6,
|
||||
// CHECK-NEXT: "reads": 12,
|
||||
// CHECK-NEXT: "writes": 6
|
||||
|
||||
// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION
|
||||
// CHECK-CONVERSION-LABEL: @graph
|
||||
func @graph(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
|
||||
%1 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
|
||||
%2 = "aten.add"(%arg0, %arg1, %1) : (tensor<1x2x3xf32>, tensor<1x2x3xf32>, i32) -> tensor<1x2x3xf32>
|
||||
"std.return"(%2) : (tensor<1x2x3xf32>) -> ()
|
||||
}
|
|
@ -1,15 +0,0 @@
|
|||
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
|
||||
// CHECK-LABEL: "L0-relu-0": {
|
||||
// CHECK-NEXT: "activation_in": 6,
|
||||
// CHECK-NEXT: "activation_out": 6,
|
||||
// CHECK-NEXT: "ops:>": 6,
|
||||
// CHECK-NEXT: "reads": 6,
|
||||
// CHECK-NEXT: "writes": 6
|
||||
|
||||
module {
|
||||
func @graph(%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
|
||||
%0 = "aten.relu"(%arg0) : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
"std.return"(%0) : (tensor<1x2x3xf32>) -> ()
|
||||
}
|
||||
}
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
|
||||
// CHECK-LABEL: "L0-native_batch_norm-0": {
|
||||
// CHECK-LABEL: "L1-relu-0": {
|
||||
// CHECK-LABEL: "L2-_convolution-0": {
|
||||
// CHECK-LABEL: "L3-native_batch_norm-1": {
|
||||
// CHECK-LABEL: "L4-relu-1": {
|
||||
// CHECK-LABEL: "L5-_convolution-1": {
|
||||
// CHECK-LABEL: "L6-native_batch_norm-2": {
|
||||
// CHECK-LABEL: "L7-relu-2": {
|
||||
// CHECK-LABEL: "L8-_convolution-2": {
|
||||
// CHECK-LABEL: "L9-add-0": {
|
||||
|
||||
module {
|
||||
func @graph(%arg0: tensor<1x16x128x128xf32>, %arg1: tensor<1x16x128x128xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>, %arg5: tensor<16xf32>, %arg6: tensor<8x16x1x1xf32>, %arg7: tensor<8xf32>, %arg8: tensor<8xf32>, %arg9: tensor<8xf32>, %arg10: tensor<8xf32>, %arg11: tensor<8xf32>, %arg12: tensor<8x8x3x3xf32>, %arg13: tensor<8xf32>, %arg14: tensor<8xf32>, %arg15: tensor<8xf32>, %arg16: tensor<8xf32>, %arg17: tensor<8xf32>, %arg18: tensor<16x8x1x1xf32>, %arg19: tensor<16xf32>) -> tensor<1x16x128x128xf32> {
|
||||
%0 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
|
||||
%1 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32
|
||||
%2 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32
|
||||
%3:3 = "aten.native_batch_norm"(%arg1, %arg2, %arg3, %arg4, %arg5, %0, %1, %2) : (tensor<1x16x128x128xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i1, f32, f32) -> (tensor<1x16x128x128xf32>, tensor<16xf32>, tensor<16xf32>)
|
||||
%4 = "aten.relu"(%3#0) : (tensor<1x16x128x128xf32>) -> tensor<1x16x128x128xf32>
|
||||
%5 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%6 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%7 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%8 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
|
||||
%9 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%10 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
|
||||
%11 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
|
||||
%12 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
|
||||
%13 = "aten._convolution"(%4, %arg6, %arg7, %5, %6, %7) : (tensor<1x16x128x128xf32>, tensor<8x16x1x1xf32>, tensor<8xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<1x8x128x128xf32>
|
||||
%14 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
|
||||
%15 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32
|
||||
%16 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32
|
||||
%17:3 = "aten.native_batch_norm"(%13, %arg8, %arg9, %arg10, %arg11, %14, %15, %16) : (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, i1, f32, f32) -> (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
%18 = "aten.relu"(%17#0) : (tensor<1x8x128x128xf32>) -> tensor<1x8x128x128xf32>
|
||||
%19 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%20 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%21 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%22 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
|
||||
%23 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%24 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
|
||||
%25 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
|
||||
%26 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
|
||||
%27 = "aten._convolution"(%18, %arg12, %arg13, %19, %20, %21) : (tensor<1x8x128x128xf32>, tensor<8x8x3x3xf32>, tensor<8xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<1x8x128x128xf32>
|
||||
%28 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
|
||||
%29 = "aten.constant"() {type = "f32", value = 1.000000e-01 : f32} : () -> f32
|
||||
%30 = "aten.constant"() {type = "f32", value = 9.99999974E-6 : f32} : () -> f32
|
||||
%31:3 = "aten.native_batch_norm"(%27, %arg14, %arg15, %arg16, %arg17, %28, %29, %30) : (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, i1, f32, f32) -> (tensor<1x8x128x128xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
%32 = "aten.relu"(%31#0) : (tensor<1x8x128x128xf32>) -> tensor<1x8x128x128xf32>
|
||||
%33 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%34 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%35 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%36 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
|
||||
%37 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
|
||||
%38 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
|
||||
%39 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
|
||||
%40 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
|
||||
%41 = "aten._convolution"(%32, %arg18, %arg19, %33, %34, %35) : (tensor<1x8x128x128xf32>, tensor<16x8x1x1xf32>, tensor<16xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<1x16x128x128xf32>
|
||||
%42 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
|
||||
%43 = "aten.add"(%arg0, %41, %42) : (tensor<1x16x128x128xf32>, tensor<1x16x128x128xf32>, i32) -> tensor<1x16x128x128xf32>
|
||||
return %43 : tensor<1x16x128x128xf32>
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue