torch-mlir/build_tools/autogen_ltc_backend.py
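
"""Autogenerate the Torch-MLIR Lazy Tensor Core (LTC) backend.

This script scans GeneratedTorchOps.td for the aten ops Torch-MLIR defines,
writes them to generated_native_functions.yaml, and then drives torchgen's
run_gen_lazy_tensor to emit the TorchMlir backend sources. A sha256 of the
inputs is stored in generated_backend.hash so codegen is only rerun when an
input changes (or when --force is passed).
"""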

import argparse
import hashlib
import os
import subprocess
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path
from shutil import which
from textwrap import dedent
import yaml
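
# Paths to the Torch-MLIR repository root and the bundled PyTorch submodule;
# torchgen is imported from the submodule, so it is appended to sys.path below.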
TORCH_MLIR_DIR = Path(__file__).parent.parent.resolve()
TORCH_DIR = TORCH_MLIR_DIR.joinpath("externals", "pytorch")
sys.path.append(str(TORCH_DIR))
# PyTorch's LTC backend autogen script
import torchgen.dest.lazy_ir
import torchgen.gen_lazy_tensor
from torchgen.api.lazy import LazyIrSchema
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
from torchgen.model import NativeFunctionsGroup
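
# torchgen's OptionalCType is detected by class name rather than isinstance,
# which avoids importing the type directly.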
def isOptionalCType(arg):
return str(type(arg)) == "<class 'torchgen.api.types.OptionalCType'>"
def generate_native_functions(
config_path: Path, torch_ops_file: Path, out_file: Path
):
print("Generating Native Functions Yaml")
native_path = TORCH_DIR.joinpath("aten", "src", "ATen", "native")
native_yaml_path = native_path.joinpath("native_functions.yaml")
tags_yaml_path = native_path.joinpath("tags.yaml")
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
native_functions = parsed_yaml.native_functions
grouped_native_functions = get_grouped_native_functions(native_functions)
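
    # Grouped entries are either a plain NativeFunction or a NativeFunctionsGroup;
    # for groups, take the name of the functional variant.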
def get_native_function_name(f):
func = f.func if hasattr(f, "func") else f.functional.func
return str(func.name)
aten_funcs = set(map(get_native_function_name, grouped_native_functions))
with config_path.open() as f:
config = yaml.load(f, yaml.CLoader)
    # Ops excluded from LTC autogen because codegen currently fails for them
blacklist = config.get("blacklist", [])
# List of supported ops that we don't want to do the full codegen for
# primarily view ops
supported = config.get("supported", [])
# List of non-native ops to do IR codegen for
non_native = config.get("non_native", [])
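
    # Scan GeneratedTorchOps.td for every aten:: op name the Torch dialect defines.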
if which("rg") is not None: # use ripgrep if available as its much faster
cmd = ["rg", "-o", "-N", r"aten::[0-9a-zA-Z_\.]+"]
else:
cmd = ["grep", "-o", r"aten::[0-9a-zA-Z_\.]\+"]
output = (
subprocess.check_output(
cmd + [str(torch_ops_file)],
encoding="utf-8",
)
.strip()
.split(os.linesep)
)
# process ops list
ops = []
supported_ops = []
skipped = []
for op in output:
op = op[6:]
opname = op.split(".")[0]
if opname in blacklist or op in blacklist:
continue
if opname in supported:
supported_ops.append(op)
continue
if op not in aten_funcs:
skipped.append(op)
continue
ops.append(op)
opnames = sorted(set(ops))
    # Additional ops to mark as supported even though Torch-MLIR does not define them explicitly
supported_ops.extend(config.get("additional_ops", []))
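
    # Emit the backend yaml consumed by torchgen's run_gen_lazy_tensor.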
with out_file.open("w") as f:
yaml.dump(
{
"backend": "Lazy",
"cpp_namespace": "torch::lazy",
"full_codegen": opnames,
"supported": sorted(supported_ops),
"non_native": non_native,
},
f,
default_flow_style=False,
)
f.write(
dedent(
"""
# Skipped ops (supported by Torch-MLIR but no equivalent native function)
"""
)
+ os.linesep.join(f"# - {op}" for op in sorted(skipped))
)
return parsed_yaml, grouped_native_functions
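
# Subclass of torchgen's lazy IR generator that emits a TorchMlir-specific
# Lower() method (declaration-only by default) instead of the default lowering.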
@dataclass(frozen=True)
class GenMlirLazyIr(torchgen.dest.GenLazyIR):
def lowering_function(self, schema, declaration_only=True):
signature = "TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override"
if declaration_only:
return f"{signature};"
emplace_arguments = []
for arg in schema.positional_args:
if arg.is_lazy_value:
if isOptionalCType(arg.lazy_type):
emplace_arguments.append(f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr")
continue
emplace_arguments.append('loctx->GetOutputOp(operand(i++))')
continue
emplace_arguments.append(f'"{arg.name}", {arg.name}')
emplace_arguments_str = "\n ".join(
[f"arguments.emplace_back({a});" for a in emplace_arguments])
emplace_kwarg_values = [f'"{t.name}", loctx->GetOutputOp(operand(i++))' for t in schema.keyword_values]
emplace_kwarg_scalars = [f'"{t.name}", {t.name}' for t in schema.keyword_scalars]
emplace_kwarguments = "\n ".join(
[f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars])
return f"""
{signature} {{
PRINT_FUNCTION();
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve({len(emplace_arguments)});
kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
size_t i = 0;
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
return {schema.aten_name}_out;
}}
""".strip()
def generate_backend(
source_yaml: Path,
backend_path: Path,
parsed_yaml: dict,
grouped_native_functions: list,
):
print("Running Lazy Tensor Autogen")
# No fallback code allowed
def gen_fallback_code(*args, **kwargs):
return ""
torchgen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
torchgen.gen_lazy_tensor.run_gen_lazy_tensor(
backend_name="TorchMlir",
aten_path=str(TORCH_DIR.joinpath("aten", "src", "ATen")),
source_yaml=str(source_yaml),
output_dir=str(backend_path.joinpath("generated")),
dry_run=False,
impl_path=str(backend_path.joinpath("mlir_native_functions.cpp")),
node_base="torch::lazy::TorchMlirNode",
node_base_hdr=str(backend_path.joinpath("mlir_node.h")),
tensor_class="torch::lazy::LazyTensor",
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
shape_inference_hdr=str(backend_path.joinpath("LazyShapeInference.h")),
lazy_ir_generator=GenMlirLazyIr,
)
# Remove lazy_tensor_core imports
subprocess.check_call(
[
"sed",
"-i",
"/lazy_tensor_core/d",
str(backend_path.joinpath("generated", "LazyNativeFunctions.cpp")),
]
)
    # Cross-check shape inference declarations against upstream and local definitions
import re
sig_re = re.compile(
r"std::vector<Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)"
)
global_signatures = {}
def extract_signatures(path):
signatures = set()
for name, args in sig_re.findall(path.read_text()):
signature = re.sub(r"\s+", "", f"{name}({args})")
global_signatures[signature] = (name, args)
signatures.add(signature)
return signatures
upstream_shape_inference_decls = extract_signatures(
TORCH_DIR.joinpath("torch", "csrc", "lazy", "core", "shape_inference.h")
)
assert len(upstream_shape_inference_decls) > 0
shape_inference_decls = extract_signatures(
backend_path.joinpath("LazyShapeInference.h")
)
assert len(shape_inference_decls) > 0
shape_inference_defs = extract_signatures(
backend_path.joinpath("LazyShapeInference.cpp")
)
assert len(shape_inference_decls) > len(shape_inference_defs)
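    # Declarations that have neither an upstream definition nor one in
    # LazyShapeInference.cpp get UNIMPLEMENTED placeholder stubs generated below.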
missing_defs = (
shape_inference_decls
- upstream_shape_inference_decls
- shape_inference_defs
)
if missing_defs:
backend_path.joinpath("generated", "GenLazyShapeInference.cpp").write_text(
dedent(
"""
// This file contains autogenerated Lazy Shape Inference placeholders
                // for ops that don't have a corresponding structured kernel or shape definition
#include "../LazyShapeInference.h"
#include "../../utils/exception.h"
namespace torch {{
namespace lazy {{
{}
}} // namespace lazy
}} // namespace torch
"""
).format(
"".join(
dedent(
f"""
std::vector<Shape> {name}({args}) {{
UNIMPLEMENTED_FUNCTION_ERROR();
}}
"""
)
for name, args in map(
global_signatures.get, sorted(missing_defs)
)
)
)
)
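    # Warn about definitions in LazyShapeInference.cpp that no longer have a
    # matching declaration.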
unnecessary_defs = shape_inference_defs - shape_inference_decls
if unnecessary_defs:
unnecessary_defs = "\n\t".join(
f"{name}({args})"
for name, args in map(global_signatures.get, unnecessary_defs)
)
warnings.warn(
f"Unnecessary shape inference definitions found for:\n\t{unnecessary_defs}"
)
def main(args):
script_path = Path(__file__).resolve()
    config_path = script_path.parent.joinpath("autogen_ltc_backend.yaml")
torch_ops_file = TORCH_MLIR_DIR.joinpath(
"include",
"torch-mlir",
"Dialect",
"Torch",
"IR",
"GeneratedTorchOps.td",
)
assert torch_ops_file.exists()
native_functions = TORCH_MLIR_DIR.joinpath(
"generated_native_functions.yaml"
)
backend_path = TORCH_MLIR_DIR.joinpath(
"python", "torch_mlir", "csrc", "base_lazy_backend"
)
assert backend_path.is_dir()
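
    # Hash the generator inputs; codegen only reruns when the hash changes
    # or --force is given.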
prev_hash = None
hash_file = TORCH_MLIR_DIR.joinpath("generated_backend.hash")
if hash_file.exists():
prev_hash = hash_file.read_text().strip()
m = hashlib.sha256()
m.update(script_path.read_bytes())
m.update(config_path.read_bytes())
m.update(torch_ops_file.read_bytes())
if native_functions.exists():
m.update(native_functions.read_bytes())
shape_inference_headers = backend_path.joinpath("LazyShapeInference.h")
if shape_inference_headers.exists():
m.update(shape_inference_headers.read_bytes())
shape_inference_defs = backend_path.joinpath("LazyShapeInference.cpp")
if shape_inference_defs.exists():
m.update(shape_inference_defs.read_bytes())
new_hash = m.hexdigest().strip()
if args.force or new_hash != prev_hash:
parsed_yaml, grouped_native_functions = generate_native_functions(
config_path, torch_ops_file, native_functions
)
generate_backend(
native_functions,
backend_path,
parsed_yaml,
grouped_native_functions,
)
hash_file.write_text(new_hash)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-f",
"--force",
action="store_true",
)
main(parser.parse_args())