Reduce compilation time for TorchOps.cpp.inc

The `assemblyFormat` machinery (which generates unrolled, per-op C++ code)
was taking up a lot of compile time, and all the ops are essentially
parsed and printed with the same logic. So this PR makes them all call
shared helper functions. This is done by using
`let hasCustomAssemblyFormat = 1` and then implementing `FooOp::parse`
and `FooOp::print`.
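
For example, the generated definition for `prim.layout` now looks roughly
like this (a sketch based on the emitter changes below; exact whitespace is
whatever the emitter produces):

```
def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `prim::layout : (Tensor) -> (int)`";
  let arguments = (ins
    AnyTorchTensorType:$a
  );
  let results = (outs
    Torch_IntType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult PrimLayoutOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void PrimLayoutOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}
```

The per-op C++ is now just a thin stub that forwards to
`parseDefaultTorchOp`/`printDefaultTorchOp` in the new
`UtilsForODSGenerated.cpp`.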

Additionally, the `Generated*Ops.td` files are all collapsed into just
`GeneratedTorchOps.td` (there is no reason to keep the files separate:
they are so large that one always ends up searching within them anyway,
and editors don't care that the file being searched is now a bit
bigger :) ).

This reduces the compile time of TorchOpsODSGenerated.cpp (which is now
GeneratedTorchOps.cpp) from 39 to 31 seconds on my machine. This is
actually less than I expected, but this PR is an overall cleanup of the
code anyway. The next step will be to introduce (better) functionality
upstream for sharding the TorchOps.cpp.inc file, so that we can truly
parallelize the O(#ops) costs. This is also necessary because, after
this PR, TorchDialect.cpp is now the slowest file to compile, due to the
`addOperations<... all the ops ...>` call, which needs to be sharded
too.
pull/605/head snapshot-20220321.338
Sean Silva 2022-03-18 21:04:47 +00:00
parent 5b9bdfaf3f
commit 729402c3f4
11 changed files with 3314 additions and 953 deletions


@@ -1,14 +1,14 @@
#!/bin/bash
# Updates auto-generated ODS files for the `torch` dialect.
set -e
set -euo pipefail
src_dir="$(realpath $(dirname $0)/..)"
build_dir="$(realpath "${TORCH_MLIR_BUILD_DIR:-$src_dir/build}")"
torch_ir_dir="${src_dir}/include/torch-mlir/Dialect/Torch/IR"
torch_ir_include_dir="${src_dir}/include/torch-mlir/Dialect/Torch/IR"
python_packages_dir="${build_dir}/tools/torch-mlir/python_packages"
#ninja -C "${build_dir}"
PYTHONPATH="${python_packages_dir}/torch_mlir" python \
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \
--torch_ir_dir="${torch_ir_dir}" \
--debug_registry_dump="${torch_ir_dir}/JITOperatorRegistryDump.txt"
--torch_ir_include_dir="${torch_ir_include_dir}" \
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt"


@@ -1,249 +0,0 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
// Operation summaries and descriptions were systematically derived from public
// API docstrings and are licensed accordingly:
// https://github.com/pytorch/pytorch/blob/master/LICENSE
//===----------------------------------------------------------------------===//
//
// This file is automatically generated. Please do not edit.
// Generated via:
// python -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen
//
//===----------------------------------------------------------------------===//
def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::layout : (Tensor) -> (int)`";
let arguments = (ins
AnyTorchTensorType:$a
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
}
def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::TupleIndex : (Any, int) -> (Any)`";
let arguments = (ins
AnyTorchType:$tup,
Torch_IntType:$i
);
let results = (outs
AnyTorchType:$result
);
let assemblyFormat = "$tup `,` $i attr-dict `:` qualified(type($tup)) `,` qualified(type($i)) `->` qualified(type($result))";
let hasCanonicalizer = 1;
}
def Torch_PrimDeviceOp : Torch_Op<"prim.device", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::device : (Tensor) -> (Device)`";
let arguments = (ins
AnyTorchTensorType:$a
);
let results = (outs
Torch_DeviceType:$result
);
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
}
def Torch_PrimDtypeOp : Torch_Op<"prim.dtype", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::dtype : (Tensor) -> (int)`";
let arguments = (ins
AnyTorchTensorType:$a
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
let hasFolder = 1;
}
def Torch_PrimTupleUnpackOp : Torch_Op<"prim.TupleUnpack", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `prim::TupleUnpack : (Any) -> (...)`";
let arguments = (ins
AnyTorchType:$tup
);
let results = (outs
Variadic<AnyTorchType>:$results
);
let assemblyFormat = "$tup attr-dict `:` qualified(type($tup)) `->` qualified(type($results))";
let hasCanonicalizer = 1;
}
def Torch_PrimNumToTensorScalarOp : Torch_Op<"prim.NumToTensor.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::NumToTensor.Scalar : (Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchScalarType:$a
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
}
def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::min.self_int : (int[]) -> (int)`";
let arguments = (ins
TorchIntListType:$self
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
let hasFolder = 1;
}
def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::min.int : (int, int) -> (int)`";
let arguments = (ins
Torch_IntType:$a,
Torch_IntType:$b
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
}
def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::max.self_int : (int[]) -> (int)`";
let arguments = (ins
TorchIntListType:$self
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}
def Torch_PrimMaxIntOp : Torch_Op<"prim.max.int", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::max.int : (int, int) -> (int)`";
let arguments = (ins
Torch_IntType:$a,
Torch_IntType:$b
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
let hasFolder = 1;
}
def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::RaiseException : (str, str?) -> ()`";
let arguments = (ins
Torch_StringType:$msg,
TorchOptionalStringType:$cls
);
let results = (outs
);
let assemblyFormat = "$msg `,` $cls attr-dict `:` qualified(type($msg)) `,` qualified(type($cls))";
}
def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
NoSideEffect,
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prim::Uninitialized : () -> (Any)`";
let arguments = (ins
);
let results = (outs
AnyTorchType:$result
);
let assemblyFormat = " attr-dict `:` qualified(type($result))";
let hasCanonicalizer = 1;
}
def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `prim::unchecked_cast : (t) -> (t)`";
let arguments = (ins
AnyTorchType:$x
);
let results = (outs
AnyTorchType:$result
);
let assemblyFormat = "$x attr-dict `:` qualified(type($x)) `->` qualified(type($result))";
let hasFolder = 1;
}
def Torch_PrimPrintOp : Torch_Op<"prim.Print", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `prim::Print : (...) -> ()`";
let arguments = (ins
Variadic<AnyTorchType>:$operands
);
let results = (outs
);
let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands))";
}
def Torch_PrimTolistOp : Torch_Op<"prim.tolist", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `prim::tolist : (...) -> (...)`";
let arguments = (ins
Variadic<AnyTorchType>:$operands
);
let results = (outs
Variadic<AnyTorchType>:$results
);
let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands)) `->` qualified(type($results))";
}


@@ -1,37 +0,0 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
// Operation summaries and descriptions were systematically derived from public
// API docstrings and are licensed accordingly:
// https://github.com/pytorch/pytorch/blob/master/LICENSE
//===----------------------------------------------------------------------===//
//
// This file is automatically generated. Please do not edit.
// Generated via:
// python -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen
//
//===----------------------------------------------------------------------===//
def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
HasValueSemantics,
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$X,
Torch_LinearParamsType:$W_prepack,
Torch_FloatType:$Y_scale_i,
Torch_IntType:$Y_zero_point_i
);
let results = (outs
AnyTorchTensorType:$Y
);
let assemblyFormat = "$X `,` $W_prepack `,` $Y_scale_i `,` $Y_zero_point_i attr-dict `:` qualified(type($X)) `,` qualified(type($W_prepack)) `,` qualified(type($Y_scale_i)) `,` qualified(type($Y_zero_point_i)) `->` qualified(type($Y))";
}


@@ -22,9 +22,7 @@ class Torch_Op<string mnemonic, list<Trait> traits = []>
: Op<Torch_Dialect, mnemonic, traits> {
}
include "torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td"
include "torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td"
include "torch-mlir/Dialect/Torch/IR/GeneratedQuantizedOps.td"
include "torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td"
//===----------------------------------------------------------------------===//
// TorchScript `torch.nn.Module` object instantiation ops.


@@ -3,6 +3,7 @@ add_mlir_library(TorchMLIRTorchDialect
TorchOps.cpp
TorchOpsODSGenerated.cpp
TorchTypes.cpp
UtilsForODSGenerated.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch


@@ -17,6 +17,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "UtilsForODSGenerated.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"


@@ -0,0 +1,80 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
//
// This file contains utilities referenced by ODS generated code.
//
// The utilities defined here are only meant for use by ODS generated code.
// If something is of wider use, then it should be moved elsewhere.
//
//===----------------------------------------------------------------------===//
#include "UtilsForODSGenerated.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
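// These helpers implement the default textual format shared by all the
// generated Torch ops: operands, attr-dict, `:`, operand types, `->`,
// result types. For example (illustrative IR):
//   %0 = torch.prim.min.int %a, %b : !torch.int, !torch.int -> !torch.int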
ParseResult Torch::parseDefaultTorchOp(OpAsmParser &parser,
                                       OperationState &result, int numOperands,
                                       int numResults) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  SmallVector<OpAsmParser::OperandType> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/numOperands))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  if (numOperands > 0) {
    SmallVector<Type> operandTypes;
    if (parser.parseTypeList(operandTypes))
      return failure();
    if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
      return failure();
  }
  if (numOperands > 0 && numResults > 0) {
    if (parser.parseArrow())
      return failure();
  }
  if (numResults > 0) {
    if (parser.parseTypeList(result.types))
      return failure();
  }
  return success();
}

void Torch::printDefaultTorchOp(OpAsmPrinter &p, Operation *op, int numOperands,
                                int numResults) {
  p << ' ';
  llvm::interleaveComma(op->getOperands(), p);
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});
  p << " : ";
  if (numOperands > 0) {
    p << ' ';
    llvm::interleaveComma(op->getOperandTypes(), p);
  }
  if (numOperands > 0 && numResults > 0) {
    p << " -> ";
  }
  if (numResults > 0) {
    p << ' ';
    llvm::interleaveComma(op->getResultTypes(), p);
  }
}


@@ -0,0 +1,43 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
//
// This file contains utilities referenced by GeneratedTorchOps.cpp.
//
// The utilities defined here are only meant for use by GeneratedTorchOps.cpp.
// If something is of wider use, then it should be moved elsewhere.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
namespace mlir {
namespace torch {
namespace Torch {
// Parse a generated Torch op in the default format.
ParseResult parseDefaultTorchOp(OpAsmParser &parser, OperationState &result,
int numOperands, int numResults);
// Print a generated Torch op in the default format.
void printDefaultTorchOp(OpAsmPrinter &p, Operation *op, int numOperands,
int numResults);
} // namespace Torch
} // namespace torch
} // namespace mlir


@@ -136,9 +136,9 @@ class JitOperator:
for op_name_atom in op_name_atoms:
for s in op_name_atom.split("_"):
op_class_name_atoms.append(s if s else "_")
td_def_name = "Torch_" + "".join(
cpp_class_name = "".join(
uppercase_first_letter(s) for s in op_class_name_atoms) + "Op"
return op_name, td_def_name
return op_name, cpp_class_name
def get_shape_function_signature(self):
"""Gets the Python function signature for this op's shape function.
@@ -197,9 +197,9 @@
# Emit the MLIR names to allow easy reverse lookup if starting
# from an unregistered op.
op_name, td_def_name = self.get_mlir_names()
op_name, cpp_class_name = self.get_mlir_names()
p(f"MLIR op name = torch.{op_name}")
p(f"MLIR td def name = {td_def_name}")
p(f"MLIR cpp class name = {cpp_class_name}")
p(f"namespace = {self.namespace}")
p(f"unqualified_name = {self.unqualified_name}")


@@ -65,30 +65,33 @@ def _get_main_module_name() -> str:
# pytype: enable=attribute-error
ODS_BANNER = "\n".join([
"//===-------------------------------------------------------*- tablegen -*-===//",
"//",
"// This file is licensed under the Apache License v2.0 with LLVM Exceptions.",
"// See https://llvm.org/LICENSE.txt for license information.",
"// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception",
"// Also available under a BSD-style license. See LICENSE.",
"//",
"// Operation summaries and descriptions were systematically derived from public",
"// API docstrings and are licensed accordingly:",
"// https://github.com/pytorch/pytorch/blob/master/LICENSE",
"//===----------------------------------------------------------------------===//",
"//",
"// This file is automatically generated. Please do not edit.",
"// Generated via:",
f"// python -m {_get_main_module_name()}",
"//",
"//===----------------------------------------------------------------------===//",
"",
"",
])
ODS_BANNER = f"""//===-------------------------------------------------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
// Operation summaries and descriptions were systematically derived from public
// API docstrings and are licensed accordingly:
// https://github.com/pytorch/pytorch/blob/master/LICENSE
//===----------------------------------------------------------------------===//
//
// This file is automatically generated. Please do not edit.
// Generated via:
// ```
// python -m {_get_main_module_name()}
// ```
//
//===----------------------------------------------------------------------===//
def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
"""
def raw_emit_op(operator: JitOperator,
emitter_td: TextEmitter,
*, traits: List[str],
has_folder: bool, has_canonicalizer: bool):
"""Emit the ODS for a JitOperator to a textual file.
@@ -98,73 +101,87 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
You probably don't want to call this directly.
"""
emitter = TextEmitter(f)
p = lambda *args: emitter.print(*args)
op_name, td_def_name = operator.get_mlir_names()
p_td = lambda *args: emitter_td.print(*args)
op_name, cpp_class_name = operator.get_mlir_names()
# Generate unique result names for ops with nameless results
multiple_results = len(operator.returns) > 1
generic_result_name = lambda i: "result" + (str(i) if multiple_results else "")
p(f"def {td_def_name} : Torch_Op<{emitter.quote(op_name)}, [")
with emitter.indent():
with emitter.indent():
p(",\n".join(traits))
p("]> {")
with emitter.indent():
def generic_result_name(i):
return "result" + (str(i) if multiple_results else "")
p_td(
f"def Torch_{cpp_class_name} : Torch_Op<{emitter_td.quote(op_name)}, [")
with emitter_td.indent():
with emitter_td.indent():
p_td(",\n".join(traits))
p_td("]> {")
with emitter_td.indent():
summary = f"Generated op for `{operator.unique_key}`"
p(f"let summary = {emitter.quote(summary)};")
p(f"let arguments = (ins")
with emitter.indent():
p_td(f"let summary = {emitter_td.quote(summary)};")
p_td(f"let arguments = (ins")
with emitter_td.indent():
if operator.is_vararg:
p("Variadic<AnyTorchType>:$operands")
p_td("Variadic<AnyTorchType>:$operands")
else:
p(",\n".join([
p_td(",\n".join([
f"""{get_ods_type(arg["type"])}:${arg["name"]}"""
for arg in operator.arguments
]))
p(");")
p(f"let results = (outs")
with emitter.indent():
p_td(");")
p_td(f"let results = (outs")
with emitter_td.indent():
if operator.is_varret:
p("Variadic<AnyTorchType>:$results")
p_td("Variadic<AnyTorchType>:$results")
else:
p(",\n".join([
p_td(",\n".join([
f"""{get_ods_type(ret["type"])}:${ret["name"] or generic_result_name(e)}"""
for e, ret in enumerate(operator.returns)
]))
p(");")
p_td(");")
if operator.is_vararg:
assembly_operands = "`(` $operands `)`"
assembly_operand_types = "qualified(type($operands))"
if operator.is_vararg or operator.is_varret:
if operator.is_vararg:
assembly_operands = "`(` $operands `)`"
assembly_operand_types = "qualified(type($operands))"
else:
assembly_operands = " `,` ".join("$" + arg["name"]
for arg in operator.arguments)
assembly_operand_types = " `,` ".join(
f"""qualified(type(${arg["name"]}))""" for arg in operator.arguments)
if operator.is_varret:
assembly_result_types = "qualified(type($results))"
else:
assembly_result_types = " `,` ".join(
f"""qualified(type(${ret["name"] or generic_result_name(e)}))"""
for e, ret in enumerate(operator.returns))
if assembly_operand_types and assembly_result_types:
maybe_arrow = " `->` "
else:
maybe_arrow = ""
assembly_format = f"{assembly_operands} attr-dict `:` {assembly_operand_types}{maybe_arrow}{assembly_result_types}"
p_td(f"let assemblyFormat = {emitter_td.quote(assembly_format)};")
else:
assembly_operands = " `,` ".join("$" + arg["name"]
for arg in operator.arguments)
assembly_operand_types = " `,` ".join(
f"""qualified(type(${arg["name"]}))""" for arg in operator.arguments)
if operator.is_varret:
assembly_result_types = "qualified(type($results))"
else:
assembly_result_types = " `,` ".join(
f"""qualified(type(${ret["name"] or generic_result_name(e)}))"""
for e, ret in enumerate(operator.returns))
if assembly_operand_types and assembly_result_types:
maybe_arrow = " `->` "
else:
maybe_arrow = ""
assembly_format = f"{assembly_operands} attr-dict `:` {assembly_operand_types}{maybe_arrow}{assembly_result_types}"
p(f"let assemblyFormat = {emitter.quote(assembly_format)};")
p_td(f"let hasCustomAssemblyFormat = 1;")
p_td(f"""let extraClassDefinition = [{{
ParseResult {cpp_class_name}::parse(OpAsmParser &parser, OperationState &result) {{
return parseDefaultTorchOp(parser, result, {len(operator.arguments)}, {len(operator.returns)});
}}
void {cpp_class_name}::print(OpAsmPrinter &printer) {{
printDefaultTorchOp(printer, *this, {len(operator.arguments)}, {len(operator.returns)});
}}
}}];
""")
if has_folder:
p("let hasFolder = 1;")
p_td("let hasFolder = 1;")
if has_canonicalizer:
p("let hasCanonicalizer = 1;")
p("}")
p("\n")
p_td("let hasCanonicalizer = 1;")
p_td("}")
p_td("\n")
def emit_op(operator: JitOperator,
f: TextIO,
emitter_td: TextEmitter,
*,
traits: Optional[List[str]] = None,
has_folder: bool = False,
@@ -185,364 +202,354 @@ def emit_op(operator: JitOperator,
traits += ["ReadOnly"]
raw_emit_op(operator,
f,
emitter_td,
traits=traits,
has_folder=has_folder,
has_canonicalizer=has_canonicalizer)
def emit_prim_ops(torch_ir_dir: str, registry: Registry):
td_file = os.path.join(torch_ir_dir, "GeneratedPrimOps.td")
with open(td_file, "w") as f:
f.write(ODS_BANNER)
def emit_ops(emitter_td: TextEmitter, registry: Registry):
def emit(key, **kwargs):
emit_op(registry[key], emitter_td, **kwargs)
def emit(key, **kwargs):
emit_op(registry[key], f, **kwargs)
def emit_with_mutating_variants(key, **kwargs):
operator = registry[key]
emit_op(operator, emitter_td, **kwargs)
ns, unqual, overload = operator.triple
emit_op(registry.get_by_triple((ns, unqual + "_", overload)),
emitter_td,
traits=["IsTrailingUnderscoreInplaceVariant"])
emit("prim::layout : (Tensor) -> (int)")
emit("prim::TupleIndex : (Any, int) -> (Any)", has_canonicalizer=True)
emit("prim::device : (Tensor) -> (Device)")
emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
emit("prim::min.int : (int, int) -> (int)")
emit("prim::max.self_int : (int[]) -> (int)")
emit("prim::max.int : (int, int) -> (int)", has_folder=True)
emit("prim::RaiseException : (str, str?) -> ()")
emit("prim::Uninitialized : () -> (Any)",
has_canonicalizer=True, traits=["NoSideEffect"])
emit("prim::unchecked_cast : (t) -> (t)", has_folder=True,
traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
emit("prim::Print : (...) -> ()")
emit("prim::tolist : (...) -> (...)")
# ==========================================================================
# `aten::` namespace.
# ==========================================================================
# Elementwise tensor compute ops
for key in [
"aten::tanh : (Tensor) -> (Tensor)",
"aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::relu : (Tensor) -> (Tensor)",
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
"aten::log : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
"aten::hardsigmoid : (Tensor) -> (Tensor)",
"aten::hardswish : (Tensor) -> (Tensor)",
"aten::erf : (Tensor) -> (Tensor)",
"aten::silu : (Tensor) -> (Tensor)",
"aten::sin : (Tensor) -> (Tensor)",
"aten::exp : (Tensor) -> (Tensor)",
"aten::cos : (Tensor) -> (Tensor)",
"aten::neg : (Tensor) -> (Tensor)",
"aten::floor : (Tensor) -> (Tensor)",
"aten::ceil : (Tensor) -> (Tensor)",
"aten::bitwise_not : (Tensor) -> (Tensor)",
"aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::div.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::le.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
"aten::log2 : (Tensor) -> (Tensor)",
"aten::rsqrt : (Tensor) -> (Tensor)",
"aten::abs : (Tensor) -> (Tensor)",
"aten::reciprocal : (Tensor) -> (Tensor)",
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::square : (Tensor) -> (Tensor)",
def emit_aten_ops(torch_ir_dir: str, registry: Registry):
# Note the deliberate lowercasing of the "t" for consistency with all
# the name munging. This is not load bearing, but is convenient for
# consistency.
td_file = os.path.join(torch_ir_dir, "GeneratedAtenOps.td")
with open(td_file, "w") as f:
f.write(ODS_BANNER)
]:
emit_with_mutating_variants(key)
# Elementwise tensor compute ops that don't have the standard mutating
# variants.
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::gelu : (Tensor, str) -> (Tensor)")
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
def emit(key, **kwargs):
emit_op(registry[key], f, **kwargs)
# Ops without value semantics but the corresponding without trailing
# underscore variant doesn't exist.
emit("aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::uniform_ : (Tensor, float, float, Generator?) -> (Tensor)")
emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
def emit_with_mutating_variants(key, **kwargs):
operator = registry[key]
emit_op(operator, f, **kwargs)
ns, unqual, overload = operator.triple
emit_op(registry.get_by_triple((ns, unqual + "_", overload)),
f,
traits=["IsTrailingUnderscoreInplaceVariant"])
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants(
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
# Elementwise tensor compute ops
for key in [
"aten::tanh : (Tensor) -> (Tensor)",
"aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::relu : (Tensor) -> (Tensor)",
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
"aten::log : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
"aten::hardsigmoid : (Tensor) -> (Tensor)",
"aten::hardswish : (Tensor) -> (Tensor)",
"aten::erf : (Tensor) -> (Tensor)",
"aten::silu : (Tensor) -> (Tensor)",
"aten::sin : (Tensor) -> (Tensor)",
"aten::exp : (Tensor) -> (Tensor)",
"aten::cos : (Tensor) -> (Tensor)",
"aten::neg : (Tensor) -> (Tensor)",
"aten::floor : (Tensor) -> (Tensor)",
"aten::ceil : (Tensor) -> (Tensor)",
"aten::bitwise_not : (Tensor) -> (Tensor)",
"aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::div.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::le.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
"aten::log2 : (Tensor) -> (Tensor)",
"aten::rsqrt : (Tensor) -> (Tensor)",
"aten::abs : (Tensor) -> (Tensor)",
"aten::reciprocal : (Tensor) -> (Tensor)",
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::square : (Tensor) -> (Tensor)",
# Non-elementwise tensor compute ops
emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
emit(
"aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
)
emit(
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
)
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit(
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
)
emit(
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
)
emit(
"aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)"
)
emit(
"aten::_log_softmax : (Tensor, int, bool) -> (Tensor)"
)
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
emit("aten::permute : (Tensor, int[]) -> (Tensor)")
emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)")
emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::sqrt : (Tensor) -> (Tensor)")
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
emit("aten::mean : (Tensor, int?) -> (Tensor)")
emit("aten::std : (Tensor, bool) -> (Tensor)")
emit("aten::var : (Tensor, bool) -> (Tensor)")
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
]:
emit_with_mutating_variants(key)
# Elementwise tensor compute ops that don't have the standard mutating
# variants.
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::gelu : (Tensor, str) -> (Tensor)")
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
emit("aten::dim : (Tensor) -> (int)", has_folder=True)
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)")
emit("aten::all : (Tensor) -> (Tensor)")
emit("aten::any : (Tensor) -> (Tensor)")
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
emit("aten::clone : (Tensor, int?) -> (Tensor)")
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)")
emit("aten::detach : (Tensor) -> (Tensor)")
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
emit("aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
emit("aten::item : (Tensor) -> (Scalar)")
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
emit("aten::numel : (Tensor) -> (int)")
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
emit("aten::stack : (Tensor[], int) -> (Tensor)")
emit("aten::sum : (Tensor, int?) -> (Tensor)")
emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")
emit("aten::max : (Tensor) -> (Tensor)")
emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::len.Tensor : (Tensor) -> (int)")
emit("aten::cpu : (Tensor) -> (Tensor)")
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
emit("aten::IntImplicit : (Tensor) -> (int)")
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
emit("aten::t : (Tensor) -> (Tensor)")
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
# Ops without value semantics but the corresponding without trailing
# underscore variant doesn't exist.
emit("aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::uniform_ : (Tensor, float, float, Generator?) -> (Tensor)")
emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
# Dict ops.
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)
emit("aten::__getitem__.Dict_str : (Dict(str, t), str) -> (t)", has_folder=True)
emit("aten::_set_item.str : (Dict(str, t), str, t) -> ()")
emit("aten::keys.str : (Dict(str, t)) -> (str[])")
emit("aten::get.default_str : (Dict(str, t), str, t) -> (t)")
emit("aten::Delete.Dict_str : (Dict(str, t), str) -> ()")
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
# List ops.
emit("aten::cat : (Tensor[], int) -> (Tensor)")
emit("aten::append.t : (t[], t) -> (t[])")
emit("aten::add.t : (t[], t[]) -> (t[])")
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
emit("aten::list.t : (t[]) -> (t[])")
emit("aten::slice.t : (t[], int?, int?, int) -> (t[])")
emit("aten::insert.t : (t[], int, t) -> ()")
emit("aten::ne.int_list : (int[], int[]) -> (bool)")
# Non-elementwise tensor compute ops
emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
emit(
"aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
)
emit(
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
)
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit (
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
)
emit(
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
)
emit(
"aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)"
)
emit(
"aten::_log_softmax : (Tensor, int, bool) -> (Tensor)"
)
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
emit("aten::permute : (Tensor, int[]) -> (Tensor)")
emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)")
emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::sqrt : (Tensor) -> (Tensor)")
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
emit("aten::mean : (Tensor, int?) -> (Tensor)")
emit("aten::std : (Tensor, bool) -> (Tensor)")
emit("aten::var : (Tensor, bool) -> (Tensor)")
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
emit ("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
# Str ops.
emit("aten::add.str : (str, str) -> (str)")
emit("aten::eq.str : (str, str) -> (bool)", has_folder=True)
emit("aten::str : (t) -> (str)")
emit("aten::format : (...) -> (str)")
emit("aten::join : (str, str[]) -> (str)")
# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
emit("aten::dim : (Tensor) -> (int)", has_folder=True)
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)")
emit("aten::all : (Tensor) -> (Tensor)")
emit("aten::any : (Tensor) -> (Tensor)")
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
emit("aten::clone : (Tensor, int?) -> (Tensor)")
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)")
emit("aten::detach : (Tensor) -> (Tensor)")
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
emit("aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
emit("aten::item : (Tensor) -> (Scalar)")
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
emit("aten::numel : (Tensor) -> (int)")
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
emit("aten::stack : (Tensor[], int) -> (Tensor)")
emit("aten::sum : (Tensor, int?) -> (Tensor)")
emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")
emit("aten::max : (Tensor) -> (Tensor)")
emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::len.Tensor : (Tensor) -> (int)")
emit("aten::cpu : (Tensor) -> (Tensor)")
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
emit("aten::IntImplicit : (Tensor) -> (int)")
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
emit("aten::t : (Tensor) -> (Tensor)")
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
# Type conversion ops.
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
emit("aten::Float.str : (str) -> (float)")
emit("aten::Int.float : (float) -> (int)")
# Dict ops.
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)
emit("aten::__getitem__.Dict_str : (Dict(str, t), str) -> (t)", has_folder=True)
emit("aten::_set_item.str : (Dict(str, t), str, t) -> ()")
emit("aten::keys.str : (Dict(str, t)) -> (str[])")
emit("aten::get.default_str : (Dict(str, t), str, t) -> (t)")
emit("aten::Delete.Dict_str : (Dict(str, t), str) -> ()")
# Primitive ops
emit("aten::__range_length : (int, int, int) -> (int)", has_folder=True)
emit("aten::__derive_index : (int, int, int) -> (int)", has_folder=True)
emit("aten::gt.int : (int, int) -> (bool)", has_folder=True)
emit("aten::ge.int : (int, int) -> (bool)", has_folder=True)
emit("aten::lt.int : (int, int) -> (bool)", has_folder=True)
emit("aten::le.int : (int, int) -> (bool)", has_folder=True)
emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
emit("aten::neg.int : (int) -> (int)", has_folder=True)
emit("aten::log.int : (int) -> (float)")
emit("aten::add.float_int : (float, int) -> (float)")
emit("aten::mul.float : (float, float) -> (float)")
emit("aten::neg.float : (float) -> (float)")
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
emit("aten::lt.float : (float, float) -> (bool)", has_folder=True)
emit("aten::lt.float_int : (float, int) -> (bool)")
emit("aten::__and__.bool : (bool, bool) -> (bool)")
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True)
emit("aten::__not__ : (bool) -> (bool)", has_folder=True)
emit("aten::len.t : (t[]) -> (int)",
has_folder=True,
has_canonicalizer=True)
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
emit("aten::div : (Scalar, Scalar) -> (float)")
emit("aten::eq.device : (Device, Device) -> (bool)")
# List ops.
emit("aten::cat : (Tensor[], int) -> (Tensor)")
emit("aten::append.t : (t[], t) -> (t[])")
emit("aten::add.t : (t[], t[]) -> (t[])")
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
emit("aten::list.t : (t[]) -> (t[])")
emit("aten::slice.t : (t[], int?, int?, int) -> (t[])")
emit("aten::insert.t : (t[], int, t) -> ()")
emit("aten::ne.int_list : (int[], int[]) -> (bool)")
# backprop ops
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)")
emit("aten::_log_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
# Str ops.
emit("aten::add.str : (str, str) -> (str)")
emit("aten::eq.str : (str, str) -> (bool)", has_folder=True)
emit("aten::str : (t) -> (str)")
emit("aten::format : (...) -> (str)")
emit("aten::join : (str, str[]) -> (str)")
# ==========================================================================
# `prim::` namespace.
# ==========================================================================
# Type conversion ops.
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
emit("aten::Float.str : (str) -> (float)")
emit("aten::Int.float : (float) -> (int)")
emit("prim::layout : (Tensor) -> (int)")
emit("prim::TupleIndex : (Any, int) -> (Any)", has_canonicalizer=True)
emit("prim::device : (Tensor) -> (Device)")
emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
emit("prim::min.int : (int, int) -> (int)")
emit("prim::max.self_int : (int[]) -> (int)")
emit("prim::max.int : (int, int) -> (int)", has_folder=True)
emit("prim::RaiseException : (str, str?) -> ()")
emit("prim::Uninitialized : () -> (Any)",
has_canonicalizer=True, traits=["NoSideEffect"])
emit("prim::unchecked_cast : (t) -> (t)", has_folder=True,
traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
emit("prim::Print : (...) -> ()")
emit("prim::tolist : (...) -> (...)")
# Primitive ops
emit("aten::__range_length : (int, int, int) -> (int)", has_folder=True)
emit("aten::__derive_index : (int, int, int) -> (int)", has_folder=True)
emit("aten::gt.int : (int, int) -> (bool)", has_folder=True)
emit("aten::ge.int : (int, int) -> (bool)", has_folder=True)
emit("aten::lt.int : (int, int) -> (bool)", has_folder=True)
emit("aten::le.int : (int, int) -> (bool)", has_folder=True)
emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
emit("aten::neg.int : (int) -> (int)", has_folder=True)
emit("aten::log.int : (int) -> (float)")
emit("aten::add.float_int : (float, int) -> (float)")
emit("aten::mul.float : (float, float) -> (float)")
emit("aten::neg.float : (float) -> (float)")
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
emit("aten::lt.float : (float, float) -> (bool)", has_folder=True)
emit("aten::lt.float_int : (float, int) -> (bool)")
emit("aten::__and__.bool : (bool, bool) -> (bool)")
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True)
emit("aten::__not__ : (bool) -> (bool)", has_folder=True)
emit("aten::len.t : (t[]) -> (int)",
has_folder=True,
has_canonicalizer=True)
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
emit("aten::div : (Scalar, Scalar) -> (float)")
emit("aten::eq.device : (Device, Device) -> (bool)")
# ==========================================================================
# `quantized::` namespace.
# ==========================================================================
# backprop ops
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)")
emit("aten::_log_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")
with open(td_file, "w") as f:
f.write(ODS_BANNER)
def emit(key, **kwargs):
emit_op(registry[key], f, **kwargs)
emit(
"quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)",
traits=["HasValueSemantics"])
emit(
"quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)",
traits=["HasValueSemantics"])
def dump_registered_ops(outfile: TextIO, registry: Registry):
for _, v in sorted(registry.by_unique_key.items()):
outfile.write(repr(v))
def main(args: argparse.Namespace):
registry = Registry.load()
if args.debug_registry_dump:
with open(args.debug_registry_dump, "w") as debug_registry_dump:
dump_registered_ops(debug_registry_dump, registry)
emit_prim_ops(args.torch_ir_dir, registry)
emit_aten_ops(args.torch_ir_dir, registry)
emit_quantized_ops(args.torch_ir_dir, registry)
td_path = os.path.join(args.torch_ir_include_dir, "GeneratedTorchOps.td")
with open(td_path, "w") as f_td:
emitter_td = TextEmitter(f_td)
emitter_td.print(ODS_BANNER)
emit_ops(emitter_td, registry)
def _create_argparse() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(prog="generate_ods")
parser.add_argument(
"--torch_ir_dir",
"--torch_ir_include_dir",
required=True,
help="Directory containing the Torch dialect definition")
help="Directory in include/ containing the Torch dialect")
parser.add_argument(
"--debug_registry_dump",
help="File to dump the the PyTorch JIT operator registry into")