mirror of https://github.com/llvm/torch-mlir
Reduce compilation time for TorchOps.cpp.inc
The `assemblyFormat` stuff (which generates unrolled, per-op C++ code) was taking up a lot of compile time, and all the ops are essentially printed with the same logic. So this PR makes them all call the same helper function. This is done by using `let hasCustomAssemblyFormat = 1` and then implementing `FooOp::parse` and `FooOp::print`. Additionally, the `Generated*Ops.td` files are all collapsed into just `GeneratedTorchOps.td` (there is no reason to have the files separate, since the files are very large anyway so one is always having to search within them -- editors don't care that the file to search is now a bit bigger :) ). This reduces TorchOpsODSGenerated.cpp compile time (which is now GeneratedTorchOps.cpp) from 39 to 31 seconds on my machine. This is actually less than I expected, but this PR is an overall cleanup to the code anyway. The next step will be to introduce (better) functionality upstream for sharding the TorchOps.cpp.inc file, so that we can truly parallelize the O(#ops) costs. This is also necessary, because after this PR, TorchDialect.cpp is now the slowest file to compile, due to the `addOperations<... all the ops ...>` call, which needs to be shareded too.pull/605/head snapshot-20220321.338
parent
5b9bdfaf3f
commit
729402c3f4
|
@ -1,14 +1,14 @@
|
|||
#!/bin/bash
|
||||
# Updates auto-generated ODS files for the `torch` dialect.
|
||||
set -e
|
||||
set -euo pipefail
|
||||
|
||||
src_dir="$(realpath $(dirname $0)/..)"
|
||||
build_dir="$(realpath "${TORCH_MLIR_BUILD_DIR:-$src_dir/build}")"
|
||||
torch_ir_dir="${src_dir}/include/torch-mlir/Dialect/Torch/IR"
|
||||
torch_ir_include_dir="${src_dir}/include/torch-mlir/Dialect/Torch/IR"
|
||||
python_packages_dir="${build_dir}/tools/torch-mlir/python_packages"
|
||||
|
||||
#ninja -C "${build_dir}"
|
||||
PYTHONPATH="${python_packages_dir}/torch_mlir" python \
|
||||
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \
|
||||
--torch_ir_dir="${torch_ir_dir}" \
|
||||
--debug_registry_dump="${torch_ir_dir}/JITOperatorRegistryDump.txt"
|
||||
--torch_ir_include_dir="${torch_ir_include_dir}" \
|
||||
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt"
|
||||
|
|
|
@ -1,249 +0,0 @@
|
|||
//===-------------------------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
// Operation summaries and descriptions were systematically derived from public
|
||||
// API docstrings and are licensed accordingly:
|
||||
// https://github.com/pytorch/pytorch/blob/master/LICENSE
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file is automatically generated. Please do not edit.
|
||||
// Generated via:
|
||||
// python -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::layout : (Tensor) -> (int)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$a
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::TupleIndex : (Any, int) -> (Any)`";
|
||||
let arguments = (ins
|
||||
AnyTorchType:$tup,
|
||||
Torch_IntType:$i
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchType:$result
|
||||
);
|
||||
let assemblyFormat = "$tup `,` $i attr-dict `:` qualified(type($tup)) `,` qualified(type($i)) `->` qualified(type($result))";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimDeviceOp : Torch_Op<"prim.device", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::device : (Tensor) -> (Device)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$a
|
||||
);
|
||||
let results = (outs
|
||||
Torch_DeviceType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimDtypeOp : Torch_Op<"prim.dtype", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::dtype : (Tensor) -> (int)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$a
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimTupleUnpackOp : Torch_Op<"prim.TupleUnpack", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::TupleUnpack : (Any) -> (...)`";
|
||||
let arguments = (ins
|
||||
AnyTorchType:$tup
|
||||
);
|
||||
let results = (outs
|
||||
Variadic<AnyTorchType>:$results
|
||||
);
|
||||
let assemblyFormat = "$tup attr-dict `:` qualified(type($tup)) `->` qualified(type($results))";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimNumToTensorScalarOp : Torch_Op<"prim.NumToTensor.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::NumToTensor.Scalar : (Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchScalarType:$a
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::min.self_int : (int[]) -> (int)`";
|
||||
let arguments = (ins
|
||||
TorchIntListType:$self
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::min.int : (int, int) -> (int)`";
|
||||
let arguments = (ins
|
||||
Torch_IntType:$a,
|
||||
Torch_IntType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::max.self_int : (int[]) -> (int)`";
|
||||
let arguments = (ins
|
||||
TorchIntListType:$self
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimMaxIntOp : Torch_Op<"prim.max.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::max.int : (int, int) -> (int)`";
|
||||
let arguments = (ins
|
||||
Torch_IntType:$a,
|
||||
Torch_IntType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::RaiseException : (str, str?) -> ()`";
|
||||
let arguments = (ins
|
||||
Torch_StringType:$msg,
|
||||
TorchOptionalStringType:$cls
|
||||
);
|
||||
let results = (outs
|
||||
);
|
||||
let assemblyFormat = "$msg `,` $cls attr-dict `:` qualified(type($msg)) `,` qualified(type($cls))";
|
||||
}
|
||||
|
||||
def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
|
||||
NoSideEffect,
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::Uninitialized : () -> (Any)`";
|
||||
let arguments = (ins
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchType:$result
|
||||
);
|
||||
let assemblyFormat = " attr-dict `:` qualified(type($result))";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::unchecked_cast : (t) -> (t)`";
|
||||
let arguments = (ins
|
||||
AnyTorchType:$x
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchType:$result
|
||||
);
|
||||
let assemblyFormat = "$x attr-dict `:` qualified(type($x)) `->` qualified(type($result))";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimPrintOp : Torch_Op<"prim.Print", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::Print : (...) -> ()`";
|
||||
let arguments = (ins
|
||||
Variadic<AnyTorchType>:$operands
|
||||
);
|
||||
let results = (outs
|
||||
);
|
||||
let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands))";
|
||||
}
|
||||
|
||||
def Torch_PrimTolistOp : Torch_Op<"prim.tolist", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prim::tolist : (...) -> (...)`";
|
||||
let arguments = (ins
|
||||
Variadic<AnyTorchType>:$operands
|
||||
);
|
||||
let results = (outs
|
||||
Variadic<AnyTorchType>:$results
|
||||
);
|
||||
let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands)) `->` qualified(type($results))";
|
||||
}
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
//===-------------------------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
// Operation summaries and descriptions were systematically derived from public
|
||||
// API docstrings and are licensed accordingly:
|
||||
// https://github.com/pytorch/pytorch/blob/master/LICENSE
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file is automatically generated. Please do not edit.
|
||||
// Generated via:
|
||||
// python -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
||||
HasValueSemantics,
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$X,
|
||||
Torch_LinearParamsType:$W_prepack,
|
||||
Torch_FloatType:$Y_scale_i,
|
||||
Torch_IntType:$Y_zero_point_i
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$Y
|
||||
);
|
||||
let assemblyFormat = "$X `,` $W_prepack `,` $Y_scale_i `,` $Y_zero_point_i attr-dict `:` qualified(type($X)) `,` qualified(type($W_prepack)) `,` qualified(type($Y_scale_i)) `,` qualified(type($Y_zero_point_i)) `->` qualified(type($Y))";
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
|
@ -22,9 +22,7 @@ class Torch_Op<string mnemonic, list<Trait> traits = []>
|
|||
: Op<Torch_Dialect, mnemonic, traits> {
|
||||
}
|
||||
|
||||
include "torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td"
|
||||
include "torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td"
|
||||
include "torch-mlir/Dialect/Torch/IR/GeneratedQuantizedOps.td"
|
||||
include "torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TorchScript `torch.nn.Module` object instantiation ops.
|
||||
|
|
|
@ -3,6 +3,7 @@ add_mlir_library(TorchMLIRTorchDialect
|
|||
TorchOps.cpp
|
||||
TorchOpsODSGenerated.cpp
|
||||
TorchTypes.cpp
|
||||
UtilsForODSGenerated.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
||||
#include "UtilsForODSGenerated.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains utilities referenced by ODS generated code.
|
||||
//
|
||||
// The utilities defined here are only meant for use by ODS generated code.
|
||||
// If something is of wider use, then it should be moved elsewhere.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "UtilsForODSGenerated.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
ParseResult Torch::parseDefaultTorchOp(OpAsmParser &parser,
|
||||
OperationState &result, int numOperands,
|
||||
int numResults) {
|
||||
llvm::SMLoc loc = parser.getCurrentLocation();
|
||||
SmallVector<OpAsmParser::OperandType> operands;
|
||||
if (parser.parseOperandList(operands, /*requiredOperandCount=*/numOperands))
|
||||
return failure();
|
||||
if (parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
if (parser.parseColon())
|
||||
return failure();
|
||||
if (numOperands > 0) {
|
||||
SmallVector<Type> operandTypes;
|
||||
if (parser.parseTypeList(operandTypes))
|
||||
return failure();
|
||||
if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
|
||||
return failure();
|
||||
}
|
||||
if (numOperands > 0 && numResults > 0) {
|
||||
if (parser.parseArrow())
|
||||
return failure();
|
||||
}
|
||||
if (numResults > 0) {
|
||||
if (parser.parseTypeList(result.types))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void Torch::printDefaultTorchOp(OpAsmPrinter &p, Operation *op, int numOperands,
|
||||
int numResults) {
|
||||
p << ' ';
|
||||
llvm::interleaveComma(op->getOperands(), p);
|
||||
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});
|
||||
p << " : ";
|
||||
if (numOperands > 0) {
|
||||
p << ' ';
|
||||
llvm::interleaveComma(op->getOperandTypes(), p);
|
||||
}
|
||||
if (numOperands > 0 && numResults > 0) {
|
||||
p << " -> ";
|
||||
}
|
||||
if (numResults > 0) {
|
||||
p << ' ';
|
||||
llvm::interleaveComma(op->getResultTypes(), p);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains utilities referenced by GeneratedTorchOps.cpp.
|
||||
//
|
||||
// The utilities defined here are only meant for use by GeneratedTorchOps.cpp.
|
||||
// If something is of wider use, then it should be moved elsewhere.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
namespace Torch {
|
||||
|
||||
// Parse a generated Torch op in the default format.
|
||||
ParseResult parseDefaultTorchOp(OpAsmParser &parser, OperationState &result,
|
||||
int numOperands, int numResults);
|
||||
|
||||
// Print a generated Torch op in the default format.
|
||||
void printDefaultTorchOp(OpAsmPrinter &p, Operation *op, int numOperands,
|
||||
int numResults);
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
|
@ -136,9 +136,9 @@ class JitOperator:
|
|||
for op_name_atom in op_name_atoms:
|
||||
for s in op_name_atom.split("_"):
|
||||
op_class_name_atoms.append(s if s else "_")
|
||||
td_def_name = "Torch_" + "".join(
|
||||
cpp_class_name = "".join(
|
||||
uppercase_first_letter(s) for s in op_class_name_atoms) + "Op"
|
||||
return op_name, td_def_name
|
||||
return op_name, cpp_class_name
|
||||
|
||||
def get_shape_function_signature(self):
|
||||
"""Gets the Python function signature for this op's shape function.
|
||||
|
@ -197,9 +197,9 @@ class JitOperator:
|
|||
|
||||
# Emit the MLIR names to allow easy reverse lookup if starting
|
||||
# from an unregistered op.
|
||||
op_name, td_def_name = self.get_mlir_names()
|
||||
op_name, cpp_class_name = self.get_mlir_names()
|
||||
p(f"MLIR op name = torch.{op_name}")
|
||||
p(f"MLIR td def name = {td_def_name}")
|
||||
p(f"MLIR cpp class name = {cpp_class_name}")
|
||||
|
||||
p(f"namespace = {self.namespace}")
|
||||
p(f"unqualified_name = {self.unqualified_name}")
|
||||
|
|
|
@ -65,30 +65,33 @@ def _get_main_module_name() -> str:
|
|||
# pytype: enable=attribute-error
|
||||
|
||||
|
||||
ODS_BANNER = "\n".join([
|
||||
"//===-------------------------------------------------------*- tablegen -*-===//",
|
||||
"//",
|
||||
"// This file is licensed under the Apache License v2.0 with LLVM Exceptions.",
|
||||
"// See https://llvm.org/LICENSE.txt for license information.",
|
||||
"// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception",
|
||||
"// Also available under a BSD-style license. See LICENSE.",
|
||||
"//",
|
||||
"// Operation summaries and descriptions were systematically derived from public",
|
||||
"// API docstrings and are licensed accordingly:",
|
||||
"// https://github.com/pytorch/pytorch/blob/master/LICENSE",
|
||||
"//===----------------------------------------------------------------------===//",
|
||||
"//",
|
||||
"// This file is automatically generated. Please do not edit.",
|
||||
"// Generated via:",
|
||||
f"// python -m {_get_main_module_name()}",
|
||||
"//",
|
||||
"//===----------------------------------------------------------------------===//",
|
||||
"",
|
||||
"",
|
||||
])
|
||||
ODS_BANNER = f"""//===-------------------------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
// Operation summaries and descriptions were systematically derived from public
|
||||
// API docstrings and are licensed accordingly:
|
||||
// https://github.com/pytorch/pytorch/blob/master/LICENSE
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file is automatically generated. Please do not edit.
|
||||
// Generated via:
|
||||
// ```
|
||||
// python -m {_get_main_module_name()}
|
||||
// ```
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
|
||||
"""
|
||||
|
||||
|
||||
def raw_emit_op(operator: JitOperator,
|
||||
emitter_td: TextEmitter,
|
||||
*, traits: List[str],
|
||||
has_folder: bool, has_canonicalizer: bool):
|
||||
"""Emit the ODS for a JitOperator to a textual file.
|
||||
|
||||
|
@ -98,43 +101,46 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
|
|||
|
||||
You probably don't want to call this directly.
|
||||
"""
|
||||
emitter = TextEmitter(f)
|
||||
p = lambda *args: emitter.print(*args)
|
||||
op_name, td_def_name = operator.get_mlir_names()
|
||||
p_td = lambda *args: emitter_td.print(*args)
|
||||
op_name, cpp_class_name = operator.get_mlir_names()
|
||||
|
||||
# Generate unique result names for ops with nameless results
|
||||
multiple_results = len(operator.returns) > 1
|
||||
generic_result_name = lambda i: "result" + (str(i) if multiple_results else "")
|
||||
|
||||
p(f"def {td_def_name} : Torch_Op<{emitter.quote(op_name)}, [")
|
||||
with emitter.indent():
|
||||
with emitter.indent():
|
||||
p(",\n".join(traits))
|
||||
p("]> {")
|
||||
with emitter.indent():
|
||||
def generic_result_name(i):
|
||||
return "result" + (str(i) if multiple_results else "")
|
||||
|
||||
p_td(
|
||||
f"def Torch_{cpp_class_name} : Torch_Op<{emitter_td.quote(op_name)}, [")
|
||||
with emitter_td.indent():
|
||||
with emitter_td.indent():
|
||||
p_td(",\n".join(traits))
|
||||
p_td("]> {")
|
||||
with emitter_td.indent():
|
||||
summary = f"Generated op for `{operator.unique_key}`"
|
||||
p(f"let summary = {emitter.quote(summary)};")
|
||||
p(f"let arguments = (ins")
|
||||
with emitter.indent():
|
||||
p_td(f"let summary = {emitter_td.quote(summary)};")
|
||||
p_td(f"let arguments = (ins")
|
||||
with emitter_td.indent():
|
||||
if operator.is_vararg:
|
||||
p("Variadic<AnyTorchType>:$operands")
|
||||
p_td("Variadic<AnyTorchType>:$operands")
|
||||
else:
|
||||
p(",\n".join([
|
||||
p_td(",\n".join([
|
||||
f"""{get_ods_type(arg["type"])}:${arg["name"]}"""
|
||||
for arg in operator.arguments
|
||||
]))
|
||||
p(");")
|
||||
p(f"let results = (outs")
|
||||
with emitter.indent():
|
||||
p_td(");")
|
||||
p_td(f"let results = (outs")
|
||||
with emitter_td.indent():
|
||||
if operator.is_varret:
|
||||
p("Variadic<AnyTorchType>:$results")
|
||||
p_td("Variadic<AnyTorchType>:$results")
|
||||
else:
|
||||
p(",\n".join([
|
||||
p_td(",\n".join([
|
||||
f"""{get_ods_type(ret["type"])}:${ret["name"] or generic_result_name(e)}"""
|
||||
for e, ret in enumerate(operator.returns)
|
||||
]))
|
||||
p(");")
|
||||
p_td(");")
|
||||
|
||||
if operator.is_vararg or operator.is_varret:
|
||||
if operator.is_vararg:
|
||||
assembly_operands = "`(` $operands `)`"
|
||||
assembly_operand_types = "qualified(type($operands))"
|
||||
|
@ -154,17 +160,28 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
|
|||
else:
|
||||
maybe_arrow = ""
|
||||
assembly_format = f"{assembly_operands} attr-dict `:` {assembly_operand_types}{maybe_arrow}{assembly_result_types}"
|
||||
p(f"let assemblyFormat = {emitter.quote(assembly_format)};")
|
||||
p_td(f"let assemblyFormat = {emitter_td.quote(assembly_format)};")
|
||||
else:
|
||||
p_td(f"let hasCustomAssemblyFormat = 1;")
|
||||
p_td(f"""let extraClassDefinition = [{{
|
||||
ParseResult {cpp_class_name}::parse(OpAsmParser &parser, OperationState &result) {{
|
||||
return parseDefaultTorchOp(parser, result, {len(operator.arguments)}, {len(operator.returns)});
|
||||
}}
|
||||
void {cpp_class_name}::print(OpAsmPrinter &printer) {{
|
||||
printDefaultTorchOp(printer, *this, {len(operator.arguments)}, {len(operator.returns)});
|
||||
}}
|
||||
}}];
|
||||
""")
|
||||
if has_folder:
|
||||
p("let hasFolder = 1;")
|
||||
p_td("let hasFolder = 1;")
|
||||
if has_canonicalizer:
|
||||
p("let hasCanonicalizer = 1;")
|
||||
p("}")
|
||||
p("\n")
|
||||
p_td("let hasCanonicalizer = 1;")
|
||||
p_td("}")
|
||||
p_td("\n")
|
||||
|
||||
|
||||
def emit_op(operator: JitOperator,
|
||||
f: TextIO,
|
||||
emitter_td: TextEmitter,
|
||||
*,
|
||||
traits: Optional[List[str]] = None,
|
||||
has_folder: bool = False,
|
||||
|
@ -185,58 +202,28 @@ def emit_op(operator: JitOperator,
|
|||
traits += ["ReadOnly"]
|
||||
|
||||
raw_emit_op(operator,
|
||||
f,
|
||||
emitter_td,
|
||||
traits=traits,
|
||||
has_folder=has_folder,
|
||||
has_canonicalizer=has_canonicalizer)
|
||||
|
||||
|
||||
def emit_prim_ops(torch_ir_dir: str, registry: Registry):
|
||||
td_file = os.path.join(torch_ir_dir, "GeneratedPrimOps.td")
|
||||
with open(td_file, "w") as f:
|
||||
f.write(ODS_BANNER)
|
||||
|
||||
def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||
def emit(key, **kwargs):
|
||||
emit_op(registry[key], f, **kwargs)
|
||||
|
||||
emit("prim::layout : (Tensor) -> (int)")
|
||||
emit("prim::TupleIndex : (Any, int) -> (Any)", has_canonicalizer=True)
|
||||
emit("prim::device : (Tensor) -> (Device)")
|
||||
emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
|
||||
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
|
||||
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
|
||||
emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
|
||||
emit("prim::min.int : (int, int) -> (int)")
|
||||
emit("prim::max.self_int : (int[]) -> (int)")
|
||||
emit("prim::max.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("prim::RaiseException : (str, str?) -> ()")
|
||||
emit("prim::Uninitialized : () -> (Any)",
|
||||
has_canonicalizer=True, traits=["NoSideEffect"])
|
||||
emit("prim::unchecked_cast : (t) -> (t)", has_folder=True,
|
||||
traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
|
||||
emit("prim::Print : (...) -> ()")
|
||||
emit("prim::tolist : (...) -> (...)")
|
||||
|
||||
|
||||
def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||
# Note the deliberate lowercasing of the "t" for consistency with all
|
||||
# the name munging. This is not load bearing, but is convenient for
|
||||
# consistency.
|
||||
td_file = os.path.join(torch_ir_dir, "GeneratedAtenOps.td")
|
||||
with open(td_file, "w") as f:
|
||||
f.write(ODS_BANNER)
|
||||
|
||||
def emit(key, **kwargs):
|
||||
emit_op(registry[key], f, **kwargs)
|
||||
emit_op(registry[key], emitter_td, **kwargs)
|
||||
|
||||
def emit_with_mutating_variants(key, **kwargs):
|
||||
operator = registry[key]
|
||||
emit_op(operator, f, **kwargs)
|
||||
emit_op(operator, emitter_td, **kwargs)
|
||||
ns, unqual, overload = operator.triple
|
||||
emit_op(registry.get_by_triple((ns, unqual + "_", overload)),
|
||||
f,
|
||||
emitter_td,
|
||||
traits=["IsTrailingUnderscoreInplaceVariant"])
|
||||
|
||||
# ==========================================================================
|
||||
# `aten::` namespace.
|
||||
# ==========================================================================
|
||||
|
||||
# Elementwise tensor compute ops
|
||||
for key in [
|
||||
"aten::tanh : (Tensor) -> (Tensor)",
|
||||
|
@ -309,7 +296,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
|
||||
|
||||
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
|
||||
emit_with_mutating_variants(
|
||||
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
|
||||
|
||||
# Non-elementwise tensor compute ops
|
||||
emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
|
||||
|
@ -328,7 +316,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit(
|
||||
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
|
||||
)
|
||||
emit (
|
||||
emit(
|
||||
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
|
@ -360,7 +348,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::var : (Tensor, bool) -> (Tensor)")
|
||||
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
||||
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
|
||||
emit ("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
|
||||
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
|
||||
|
||||
# Misc tensor ops.
|
||||
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
||||
|
@ -508,15 +496,31 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)")
|
||||
emit("aten::_log_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||
|
||||
# ==========================================================================
|
||||
# `prim::` namespace.
|
||||
# ==========================================================================
|
||||
|
||||
emit("prim::layout : (Tensor) -> (int)")
|
||||
emit("prim::TupleIndex : (Any, int) -> (Any)", has_canonicalizer=True)
|
||||
emit("prim::device : (Tensor) -> (Device)")
|
||||
emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
|
||||
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
|
||||
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
|
||||
emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
|
||||
emit("prim::min.int : (int, int) -> (int)")
|
||||
emit("prim::max.self_int : (int[]) -> (int)")
|
||||
emit("prim::max.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("prim::RaiseException : (str, str?) -> ()")
|
||||
emit("prim::Uninitialized : () -> (Any)",
|
||||
has_canonicalizer=True, traits=["NoSideEffect"])
|
||||
emit("prim::unchecked_cast : (t) -> (t)", has_folder=True,
|
||||
traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
|
||||
emit("prim::Print : (...) -> ()")
|
||||
emit("prim::tolist : (...) -> (...)")
|
||||
|
||||
def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
|
||||
td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")
|
||||
with open(td_file, "w") as f:
|
||||
f.write(ODS_BANNER)
|
||||
|
||||
def emit(key, **kwargs):
|
||||
emit_op(registry[key], f, **kwargs)
|
||||
# ==========================================================================
|
||||
# `quantized::` namespace.
|
||||
# ==========================================================================
|
||||
|
||||
emit(
|
||||
"quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)",
|
||||
|
@ -527,22 +531,25 @@ def dump_registered_ops(outfile: TextIO, registry: Registry):
|
|||
for _, v in sorted(registry.by_unique_key.items()):
|
||||
outfile.write(repr(v))
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
registry = Registry.load()
|
||||
if args.debug_registry_dump:
|
||||
with open(args.debug_registry_dump, "w") as debug_registry_dump:
|
||||
dump_registered_ops(debug_registry_dump, registry)
|
||||
emit_prim_ops(args.torch_ir_dir, registry)
|
||||
emit_aten_ops(args.torch_ir_dir, registry)
|
||||
emit_quantized_ops(args.torch_ir_dir, registry)
|
||||
td_path = os.path.join(args.torch_ir_include_dir, "GeneratedTorchOps.td")
|
||||
with open(td_path, "w") as f_td:
|
||||
emitter_td = TextEmitter(f_td)
|
||||
emitter_td.print(ODS_BANNER)
|
||||
emit_ops(emitter_td, registry)
|
||||
|
||||
|
||||
def _create_argparse() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="generate_ods")
|
||||
parser.add_argument(
|
||||
"--torch_ir_dir",
|
||||
"--torch_ir_include_dir",
|
||||
required=True,
|
||||
help="Directory containing the Torch dialect definition")
|
||||
help="Directory in include/ containing the Torch dialect")
|
||||
parser.add_argument(
|
||||
"--debug_registry_dump",
|
||||
help="File to dump the the PyTorch JIT operator registry into")
|
||||
|
|
Loading…
Reference in New Issue