torch-mlir/build_tools/autogen_ltc_backend.py

533 lines
19 KiB
Python
Raw Permalink Normal View History

2022-03-24 22:15:43 +08:00
import argparse
import hashlib
import importlib.util
import logging
2022-03-24 22:15:43 +08:00
import os
2022-06-10 22:25:10 +08:00
import re
2022-03-24 22:15:43 +08:00
import subprocess
import warnings
E2E HuggingFace Bert using LTC Backend (#912) * Update native function definitions * Add ops to support bert lowering - Add empty_strided and as_strided - Restore zeros_like to op blacklist (Without this, tensors will be unintentionally created with a CPU device rather than lazy) - Check for composite implicit ops and add device data IR - Also fix codegen for functionalization * Add autogen to CMakeList * Remove PyTorch submodule * Reduced BERT model size * Print Mark Step status in Torch MLIR LTC debug string * Apply fixes to work with latest upstream/main - Pass importOptions into getMlirTypeFromTorchType during NodeImporter::importNode Without this, the tensor type created may have a mismatched type as ImportOptions may cause vtensor to be used instead of tensor * Update shape inference functions - Fixed compute_shape_native_batch_norm when mean and var are uninitialized Previously, the number of shapes returned would be <3 if either mean or val was didn't exist. Instead, we now initialize them with a vector matching the number of channels. - Implemented compute_shape_mul - Fixed bug in reshape shape inference error message * Get MLIR backend more consistent with TS backend - Remove LazyNativeFunctions::_unsafe_view from autogen - Blacklist ops to make JIT graph more like output of TS backend - Print graph when SSA value has mismatch of types and results - Remove normalize_index from LazyShapeInference - Fix seeds for LTC example models * Update and clean up shape inference functions - Prune shape inference functions - Add shape inference function for GenerateSlice - Add shape inference function for GenerateCopy Co-authored-by: Henry Tu <henry.tu@cerebras.net>
2022-06-08 02:38:50 +08:00
from collections import defaultdict
2022-03-24 22:15:43 +08:00
from dataclasses import dataclass
from pathlib import Path
from shutil import which
2022-06-10 22:25:10 +08:00
from textwrap import dedent, indent
2022-03-24 22:15:43 +08:00
# PyTorch's LTC backend autogen script
E2E HuggingFace Bert using LTC Backend (#912) * Update native function definitions * Add ops to support bert lowering - Add empty_strided and as_strided - Restore zeros_like to op blacklist (Without this, tensors will be unintentionally created with a CPU device rather than lazy) - Check for composite implicit ops and add device data IR - Also fix codegen for functionalization * Add autogen to CMakeList * Remove PyTorch submodule * Reduced BERT model size * Print Mark Step status in Torch MLIR LTC debug string * Apply fixes to work with latest upstream/main - Pass importOptions into getMlirTypeFromTorchType during NodeImporter::importNode Without this, the tensor type created may have a mismatched type as ImportOptions may cause vtensor to be used instead of tensor * Update shape inference functions - Fixed compute_shape_native_batch_norm when mean and var are uninitialized Previously, the number of shapes returned would be <3 if either mean or val was didn't exist. Instead, we now initialize them with a vector matching the number of channels. - Implemented compute_shape_mul - Fixed bug in reshape shape inference error message * Get MLIR backend more consistent with TS backend - Remove LazyNativeFunctions::_unsafe_view from autogen - Blacklist ops to make JIT graph more like output of TS backend - Print graph when SSA value has mismatch of types and results - Remove normalize_index from LazyShapeInference - Fix seeds for LTC example models * Update and clean up shape inference functions - Prune shape inference functions - Add shape inference function for GenerateSlice - Add shape inference function for GenerateCopy Co-authored-by: Henry Tu <henry.tu@cerebras.net>
2022-06-08 02:38:50 +08:00
import torchgen
import torchgen.dest.lazy_ir
import torchgen.gen_lazy_tensor
2022-06-10 22:25:10 +08:00
import yaml
from torchgen.api.lazy import LazyIrSchema, setValueT
from torchgen.api.types import BaseCppType
from torchgen.dest import GenLazyShapeInferenceDefinition
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
2022-06-10 22:25:10 +08:00
from torchgen.gen_backend_stubs import parse_backend_yaml
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
TORCH_DIR = Path(importlib.util.find_spec("torch").origin).resolve().parent.parent
TORCH_INCLUDE_DIR = TORCH_DIR.joinpath("torch", "include")
if not TORCH_INCLUDE_DIR.is_dir():
TORCH_INCLUDE_DIR = TORCH_DIR
TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
E2E HuggingFace Bert using LTC Backend (#912) * Update native function definitions * Add ops to support bert lowering - Add empty_strided and as_strided - Restore zeros_like to op blacklist (Without this, tensors will be unintentionally created with a CPU device rather than lazy) - Check for composite implicit ops and add device data IR - Also fix codegen for functionalization * Add autogen to CMakeList * Remove PyTorch submodule * Reduced BERT model size * Print Mark Step status in Torch MLIR LTC debug string * Apply fixes to work with latest upstream/main - Pass importOptions into getMlirTypeFromTorchType during NodeImporter::importNode Without this, the tensor type created may have a mismatched type as ImportOptions may cause vtensor to be used instead of tensor * Update shape inference functions - Fixed compute_shape_native_batch_norm when mean and var are uninitialized Previously, the number of shapes returned would be <3 if either mean or val was didn't exist. Instead, we now initialize them with a vector matching the number of channels. - Implemented compute_shape_mul - Fixed bug in reshape shape inference error message * Get MLIR backend more consistent with TS backend - Remove LazyNativeFunctions::_unsafe_view from autogen - Blacklist ops to make JIT graph more like output of TS backend - Print graph when SSA value has mismatch of types and results - Remove normalize_index from LazyShapeInference - Fix seeds for LTC example models * Update and clean up shape inference functions - Prune shape inference functions - Add shape inference function for GenerateSlice - Add shape inference function for GenerateCopy Co-authored-by: Henry Tu <henry.tu@cerebras.net>
2022-06-08 02:38:50 +08:00
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
2022-06-10 22:25:10 +08:00
def reindent(text, prefix=""):
return indent(dedent(text), prefix)
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
@dataclass(frozen=True)
class GenMlirLazyIr(torchgen.dest.GenLazyIR):
def isOptionalCType(self, arg):
return str(type(arg)) == "<class 'torchgen.api.types.types.OptionalCType'>"
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
def lowering_function(self, schema: LazyIrSchema):
signature = "TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override"
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
if schema.properties.LowerDeclOnly:
return f"{signature};"
elif not schema.properties.Lower:
return ""
E2E HuggingFace Bert using LTC Backend (#912) * Update native function definitions * Add ops to support bert lowering - Add empty_strided and as_strided - Restore zeros_like to op blacklist (Without this, tensors will be unintentionally created with a CPU device rather than lazy) - Check for composite implicit ops and add device data IR - Also fix codegen for functionalization * Add autogen to CMakeList * Remove PyTorch submodule * Reduced BERT model size * Print Mark Step status in Torch MLIR LTC debug string * Apply fixes to work with latest upstream/main - Pass importOptions into getMlirTypeFromTorchType during NodeImporter::importNode Without this, the tensor type created may have a mismatched type as ImportOptions may cause vtensor to be used instead of tensor * Update shape inference functions - Fixed compute_shape_native_batch_norm when mean and var are uninitialized Previously, the number of shapes returned would be <3 if either mean or val was didn't exist. Instead, we now initialize them with a vector matching the number of channels. - Implemented compute_shape_mul - Fixed bug in reshape shape inference error message * Get MLIR backend more consistent with TS backend - Remove LazyNativeFunctions::_unsafe_view from autogen - Blacklist ops to make JIT graph more like output of TS backend - Print graph when SSA value has mismatch of types and results - Remove normalize_index from LazyShapeInference - Fix seeds for LTC example models * Update and clean up shape inference functions - Prune shape inference functions - Add shape inference function for GenerateSlice - Add shape inference function for GenerateCopy Co-authored-by: Henry Tu <henry.tu@cerebras.net>
2022-06-08 02:38:50 +08:00
2022-06-10 22:25:10 +08:00
emplace_arguments = []
for arg in schema.positional_args:
if arg.is_lazy_value:
if self.isOptionalCType(arg.lazy_type):
emplace_arguments.append(
f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
)
else:
emplace_arguments.append("loctx->GetOutputOp(operand(i++))")
else:
emplace_arguments.append(f'"{arg.name}", {arg.name}')
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
emplace_arguments_str = "\n ".join(
f"arguments.emplace_back({a});" for a in emplace_arguments
)
emplace_kwarg_values = [
f'"{t.name}", loctx->GetOutputOp(operand(i++))'
for t in schema.keyword_values
]
emplace_kwarg_scalars = [
f'"{t.name}", {t.name}' for t in schema.keyword_scalars
]
emplace_kwarguments = "\n ".join(
f"kwarguments.emplace_back({a});"
for a in emplace_kwarg_values + emplace_kwarg_scalars
)
2022-03-24 22:15:43 +08:00
# Only create this variable if it's used to avoid Wunused-variable
operand_idx_counter = "size_t i = 0;" if "i++" in (emplace_arguments_str + emplace_kwarguments) else ""
2022-06-10 22:25:10 +08:00
return reindent(
f"""
{signature} {{
PRINT_FUNCTION();
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve({len(emplace_arguments)});
kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
{operand_idx_counter}
2022-06-10 22:25:10 +08:00
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
2022-06-10 22:25:10 +08:00
return {schema.aten_name}_out;
}}
""",
" ",
)
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
class GenTorchMlirLTC:
def __init__(self, binary_dir):
2022-06-10 22:25:10 +08:00
self.script_path = Path(__file__).resolve()
self.config_path = (
Path(__file__).resolve().parent.joinpath("autogen_ltc_backend.yaml")
)
self.torch_ops_file = TORCH_MLIR_DIR.joinpath(
# fmt: off
"include", "torch-mlir", "Dialect", "Torch", "IR", "GeneratedTorchOps.td",
# fmt: on
2022-06-10 22:25:10 +08:00
)
assert self.torch_ops_file.exists()
self.binary_dir = Path(binary_dir)
assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}"
self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml")
self.backend_path = TORCH_MLIR_DIR.joinpath(
"projects", "ltc", "csrc", "base_lazy_backend"
2022-06-10 22:25:10 +08:00
)
Re-organize project structure to separate PyTorch dependencies from core project. (#2542) This is a first step towards the structure we discussed here: https://gist.github.com/stellaraccident/931b068aaf7fa56f34069426740ebf20 There are two primary goals: 1. Separate the core project (C++ dialects and conversions) from the hard PyTorch dependencies. We move all such things into projects/pt1 as a starting point since they are presently entangled with PT1-era APIs. Additional work can be done to disentangle components from that (specifically LTC is identified as likely ultimately living in a `projects/ltc`). 2. Create space for native PyTorch2 Dynamo-based infra to be upstreamed without needing to co-exist with the original TorchScript path. Very little changes in this path with respect to build layering or options. These can be updated in a followup without commingling directory structure changes. This also takes steps toward a couple of other layering enhancements: * Removes the llvm-external-projects/torch-mlir-dialects sub-project, collapsing it into the main tree. * Audits and fixes up the core C++ build to account for issues found while moving things. This is just an opportunistic pass through but roughly ~halves the number of build actions for the project from the high 4000's to the low 2000's. It deviates from the discussed plan by having a `projects/` tree instead of `compat/`. As I was thinking about it, this will better accommodate the follow-on code movement. Once things are roughly in place and the CI passing, followups will focus on more in-situ fixes and cleanups.
2023-11-03 10:45:55 +08:00
assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}"
self.generated_path = self.binary_dir.joinpath(
"projects", "ltc", "csrc", "base_lazy_backend", "generated"
)
self.generated_path.mkdir(parents=True, exist_ok=True)
# Create symlink to match doc structure
generated_path = self.backend_path.joinpath("generated").resolve()
if not generated_path.exists():
generated_path.symlink_to(
os.path.relpath(self.generated_path, generated_path.parent),
target_is_directory=True,
)
2022-06-10 22:25:10 +08:00
self.tensor_class = "torch::lazy::LazyTensor"
# Set the lazy value class
setValueT(BaseCppType("torch::lazy", "Value"))
def calculate_hash(self):
m = hashlib.sha256()
# Add file contents to hash
for path in (
self.script_path,
self.config_path,
self.torch_ops_file,
self.source_yaml,
self.backend_path.joinpath("shape_inference.cpp"),
TORCHGEN_DIR.joinpath("dest", "lazy_ir.py"),
TORCHGEN_DIR.joinpath("api", "lazy.py"),
TORCHGEN_DIR.joinpath("model.py"),
2022-06-10 22:25:10 +08:00
):
if path.exists():
m.update(path.read_bytes())
return m.hexdigest().strip()
def generate_native_functions(self):
logging.info("Generating Native Functions Yaml")
2022-06-10 22:25:10 +08:00
native_path = TORCHGEN_DIR.joinpath("packaged", "ATen", "native")
2022-06-10 22:25:10 +08:00
native_yaml_path = native_path.joinpath("native_functions.yaml")
tags_yaml_path = native_path.joinpath("tags.yaml")
ts_native_yaml_path = TORCH_DIR.joinpath(
"aten", "src", "ATen", "native", "ts_native_functions.yaml"
)
ts_native_yaml = None
if ts_native_yaml_path.exists():
ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), yaml.CLoader)
else:
logging.warning(f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}")
2022-06-10 22:25:10 +08:00
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
self.native_functions = parsed_yaml.native_functions
self.backend_indices = parsed_yaml.backend_indices
self.grouped_native_functions = get_grouped_native_functions(
self.native_functions
)
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
def get_native_function_name(f):
func = f if hasattr(f, "func") else f.functional
return str(func.func.name)
2022-06-10 22:25:10 +08:00
self.native_functions = {
get_native_function_name(f): f for f in self.native_functions
}
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
def get_opnames(ops):
opnames = defaultdict(set)
for op in ops:
opname = op.split(".")[0]
opnames[opname].add(op)
return opnames
aten_funcs = get_opnames(
map(get_native_function_name, self.grouped_native_functions)
2022-03-24 22:15:43 +08:00
)
2022-06-10 22:25:10 +08:00
with self.config_path.open() as f:
config = yaml.load(f, yaml.CLoader)
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
# List of unsupported ops in LTC autogen because of some error
blacklist = set(config.get("blacklist", []))
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
# List of supported ops that we don't want to do the full codegen for
# primarily view ops
supported = set(config.get("supported", []))
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
# List of non-native ops to do IR codegen for
non_native = config.get("non_native", [])
# use ripgrep if available as its much faster
if which("rg") is not None:
cmd = ["rg", "-o", "-N", r"aten::[0-9a-zA-Z_\.]+"]
else:
cmd = ["grep", "-o", r"aten::[0-9a-zA-Z_\.]\+"]
torch_ops = set(
op[6:]
for op in subprocess.check_output(
cmd + [str(self.torch_ops_file)],
encoding="utf-8",
2022-03-24 22:15:43 +08:00
)
2022-06-10 22:25:10 +08:00
.strip()
.split(os.linesep)
2022-03-24 22:15:43 +08:00
)
2022-06-10 22:25:10 +08:00
torch_opnames = get_opnames(torch_ops)
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
# process ops list
ops = set()
composite_implicit = set()
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
for op in torch_ops:
if op not in self.native_functions:
continue
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
func = self.native_functions[op]
base = func.func.name.name.base
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
if base in blacklist or op in blacklist:
continue
if base in supported or op in supported:
continue
# Blacklist new_/_like ops since they are non-differentiable.
if any(o.startswith("new_") or o.endswith("_like") for o in (base, op)):
continue
2022-06-10 22:25:10 +08:00
if func.has_composite_implicit_autograd_kernel:
composite_implicit.add(op)
elif func.func.name.name.inplace:
for autogen in func.autogen:
if "functional" in autogen.overload_name:
ops.add(str(autogen))
else:
ops.add(op)
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
skipped = set(torch_ops) - ops - supported - composite_implicit
# List of ops autogen even if not explicitly supported by Torch-MLIR explicitly
ops |= set(config.get("whitelist", []))
2022-06-10 22:25:10 +08:00
# Additional ops to support that are not supported by Torch-MLIR explicitly
supported |= set(config.get("additional_ops", []))
2022-03-24 22:15:43 +08:00
# List of ops that will take in symints for its size
symint = set(config.get("symint", []))
2022-06-10 22:25:10 +08:00
self.ops = sorted(ops)
with self.source_yaml.open("w") as f:
source_yaml = {
"backend": "Lazy",
"cpp_namespace": "torch::lazy",
"full_codegen": self.ops,
"supported": sorted(supported),
"symint": sorted(symint),
2022-06-10 22:25:10 +08:00
"non_native": non_native,
}
yaml.dump(source_yaml, f, default_flow_style=False)
f.write(
dedent(
"""
# Composite implicit ops (supported by Torch-MLIR but not differentiable)
{composite_implicit}
# Skipped ops (supported by Torch-MLIR but no equivalent native function)
{skipped}
"""
).format(
composite_implicit=os.linesep.join(
f"# - {op}" for op in sorted(composite_implicit)
),
skipped=os.linesep.join(f"# - {op}" for op in sorted(skipped)),
)
)
2022-03-24 22:15:43 +08:00
if ts_native_yaml:
ts_full_codegen = set(ts_native_yaml["full_codegen"])
ts_supported = set(ts_native_yaml["supported"])
mlir_full_codegen = set(self.ops)
if ts_full_codegen - mlir_full_codegen:
logging.debug(
"Full Codegen ops supported by the TorchScript backend "
"but not by the Torch-MLIR backend:\n {}".format(
"\n ".join(sorted(ts_full_codegen - mlir_full_codegen))
)
)
if mlir_full_codegen - ts_full_codegen:
logging.debug(
"Full Codegen ops supported by the Torch-MLIR backend "
"but not by the TorchScript backend:\n {}".format(
"\n ".join(sorted(mlir_full_codegen - ts_full_codegen))
)
)
if ts_supported - supported:
logging.debug(
"Ops supported by the TorchScript backend "
"but not by the Torch-MLIR backend:\n {}".format(
"\n ".join(sorted(ts_supported - supported))
)
)
if supported - ts_supported:
logging.debug(
"Ops supported by the Torch-MLIR backend "
"but not by the TorchScript backend:\n {}".format(
"\n ".join(sorted(supported - ts_supported))
)
)
2022-06-10 22:25:10 +08:00
def generate_shape_inference(self):
parsed_backend_yaml = parse_backend_yaml(
self.source_yaml,
self.grouped_native_functions,
self.backend_indices,
)
backend_index = self.backend_indices[parsed_backend_yaml.backend_key]
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
shape_gen = GenLazyShapeInferenceDefinition(backend_index, self.tensor_class)
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
sig_re = re.compile(
r"std::vector<torch::lazy::Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)"
)
global_signatures = {}
def extract_signatures(text):
signatures = set()
for name, args in sig_re.findall(text):
signature = re.sub(r"\s+", "", f"{name}({args})")
global_signatures[signature] = (name, args)
signatures.add(signature)
return signatures
shape_inference_decls = []
for op in self.ops:
f = self.native_functions[op]
shape_sig = shape_gen(f)
shape_inference_decls.extend(shape_sig)
self.generated_path.joinpath("shape_inference.h").write_text(
2022-03-24 22:15:43 +08:00
dedent(
"""
2022-06-10 22:25:10 +08:00
// This file contains autogenerated Lazy Shape Inference declarations
2022-03-24 22:15:43 +08:00
// for ops that dont have a corresponding structured kernel or shape definition
2022-06-10 22:25:10 +08:00
#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <vector>
2022-03-24 22:15:43 +08:00
namespace torch {{
namespace lazy {{
2022-06-10 22:25:10 +08:00
2022-03-24 22:15:43 +08:00
{}
2022-06-10 22:25:10 +08:00
2022-03-24 22:15:43 +08:00
}} // namespace lazy
}} // namespace torch
"""
2022-06-10 22:25:10 +08:00
).format(os.linesep.join(sorted(shape_inference_decls)))
)
shape_inference_decls = extract_signatures(
self.generated_path.joinpath("shape_inference.h").read_text()
2022-06-10 22:25:10 +08:00
)
assert len(shape_inference_decls) > 0
upstream_shape_inference_decls = extract_signatures(
TORCH_INCLUDE_DIR.joinpath(
2022-06-10 22:25:10 +08:00
"torch", "csrc", "lazy", "core", "shape_inference.h"
).read_text()
)
assert len(upstream_shape_inference_decls) > 0
shape_inference_defs = extract_signatures(
self.backend_path.joinpath("shape_inference.cpp").read_text()
)
assert len(shape_inference_decls) > len(shape_inference_defs)
missing_defs = (
shape_inference_decls
- upstream_shape_inference_decls
- shape_inference_defs
)
if missing_defs:
self.generated_path.joinpath("shape_inference.cpp").write_text(
2022-06-10 22:25:10 +08:00
dedent(
"""
// This file contains autogenerated Lazy Shape Inference placeholders
// for ops that dont have a corresponding structured kernel or shape definition
#include "shape_inference.h"
#include "base_lazy_backend/utils/exception.h"
2022-06-10 22:25:10 +08:00
namespace torch {{
namespace lazy {{
{}
}} // namespace lazy
}} // namespace torch
"""
).format(
"".join(
dedent(
f"""
std::vector<torch::lazy::Shape> {name}({args}) {{
UNIMPLEMENTED_FUNCTION_ERROR();
}}
"""
)
for name, args in map(
global_signatures.get, sorted(missing_defs)
)
2022-03-24 22:15:43 +08:00
)
)
)
2022-06-10 22:25:10 +08:00
unnecessary_defs = shape_inference_defs - shape_inference_decls
if unnecessary_defs:
unnecessary_defs = "\n\t".join(
f"{name}({args})"
for name, args in map(global_signatures.get, unnecessary_defs)
)
warnings.warn(
f"Unnecessary shape inference definitions found for:\n\t{unnecessary_defs}"
)
def generate_backend(self):
logging.info("Running Lazy Tensor Autogen")
2022-06-10 22:25:10 +08:00
# No fallback code allowed
def gen_fallback_code(*args, **kwargs):
return ""
torchgen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
torchgen.gen_lazy_tensor.run_gen_lazy_tensor(
backend_name="TorchMlir",
aten_path=str(TORCHGEN_DIR.joinpath("packaged", "ATen")),
2022-06-10 22:25:10 +08:00
source_yaml=str(self.source_yaml),
output_dir=str(self.generated_path),
2022-06-10 22:25:10 +08:00
dry_run=False,
impl_path=str(self.backend_path.joinpath("mlir_native_functions.cpp")),
node_base="torch::lazy::TorchMlirNode",
node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
tensor_class=self.tensor_class,
tensor_class_hdr="base_lazy_backend/tensor.h",
create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor",
shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
2022-06-10 22:25:10 +08:00
lazy_ir_generator=GenMlirLazyIr,
2022-03-24 22:15:43 +08:00
)
2022-06-10 22:25:10 +08:00
def __call__(self):
self.generate_native_functions()
self.generate_shape_inference()
self.generate_backend()
2022-03-24 22:15:43 +08:00
2022-06-10 22:25:10 +08:00
def main(args):
generator = GenTorchMlirLTC(args.binary_dir)
E2E HuggingFace Bert using LTC Backend (#912) * Update native function definitions * Add ops to support bert lowering - Add empty_strided and as_strided - Restore zeros_like to op blacklist (Without this, tensors will be unintentionally created with a CPU device rather than lazy) - Check for composite implicit ops and add device data IR - Also fix codegen for functionalization * Add autogen to CMakeList * Remove PyTorch submodule * Reduced BERT model size * Print Mark Step status in Torch MLIR LTC debug string * Apply fixes to work with latest upstream/main - Pass importOptions into getMlirTypeFromTorchType during NodeImporter::importNode Without this, the tensor type created may have a mismatched type as ImportOptions may cause vtensor to be used instead of tensor * Update shape inference functions - Fixed compute_shape_native_batch_norm when mean and var are uninitialized Previously, the number of shapes returned would be <3 if either mean or val was didn't exist. Instead, we now initialize them with a vector matching the number of channels. - Implemented compute_shape_mul - Fixed bug in reshape shape inference error message * Get MLIR backend more consistent with TS backend - Remove LazyNativeFunctions::_unsafe_view from autogen - Blacklist ops to make JIT graph more like output of TS backend - Print graph when SSA value has mismatch of types and results - Remove normalize_index from LazyShapeInference - Fix seeds for LTC example models * Update and clean up shape inference functions - Prune shape inference functions - Add shape inference function for GenerateSlice - Add shape inference function for GenerateCopy Co-authored-by: Henry Tu <henry.tu@cerebras.net>
2022-06-08 02:38:50 +08:00
hash_file = generator.binary_dir.joinpath("generated_backend.hash")
2022-03-24 22:15:43 +08:00
prev_hash = None
if hash_file.exists():
prev_hash = hash_file.read_text().strip()
2022-06-10 22:25:10 +08:00
new_hash = generator.calculate_hash()
2022-03-24 22:15:43 +08:00
if args.force or new_hash != prev_hash:
2022-06-10 22:25:10 +08:00
generator()
hash_file.write_text(new_hash)
2022-03-24 22:15:43 +08:00
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-b",
"--binary_dir",
type=str,
default=os.getenv(
"TORCH_MLIR_BINARY_DIR",
TORCH_MLIR_DIR.joinpath("build"),
),
)
2022-03-24 22:15:43 +08:00
parser.add_argument(
"-f",
"--force",
action="store_true",
)
parser.add_argument(
"-d",
"--debug",
help="Print lots of debugging statements",
action="store_const",
dest="loglevel",
const=logging.DEBUG,
default=logging.WARNING,
)
parser.add_argument(
"-v",
"--verbose",
help="Be verbose",
action="store_const",
dest="loglevel",
const=logging.INFO,
)
args = parser.parse_args()
logging.basicConfig(level=args.loglevel)
main(args)