Got LTC working until compile (#689)

pull/1125/head
Jae Hoon (Antonio) Kim 2022-03-24 10:15:43 -04:00 committed by Henry Tu
parent 58338f79a1
commit c3b20e444c
22 changed files with 1937 additions and 215 deletions

.gitignore vendored

@@ -29,7 +29,7 @@ bazel-*
/python/torch_mlir/csrc/backend/LazyLazyIr.h
/python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp
/python/torch_mlir/csrc/backend/LazyNativeFunctions.h
/python/torch_mlir/csrc/backend/LazyShapeInference.cpp
/python/torch_mlir/csrc/backend/GenLazyShapeInference.cpp
/python/torch_mlir/csrc/backend/RegisterLazy.cpp
# Libraries


@@ -0,0 +1,321 @@
import argparse
import hashlib
import os
import subprocess
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path
from shutil import which
from textwrap import dedent
import yaml
TORCH_MLIR_DIR = Path(__file__).parent.parent.resolve()
TORCH_DIR = TORCH_MLIR_DIR.parent.joinpath("pytorch")
sys.path.append(str(TORCH_DIR.joinpath("tools")))
# PyTorch's LTC backend autogen script
import codegen.dest.lazy_ir
import codegen.gen_lazy_tensor
from codegen.api.lazy import LazyIrSchema
from codegen.gen import get_grouped_native_functions, parse_native_yaml
from codegen.model import NativeFunctionsGroup
def generate_native_functions(
config_path: Path, torch_ops_file: Path, out_file: Path
):
print("Generating Native Functions Yaml")
native_yaml_path = TORCH_DIR.joinpath(
"aten", "src", "ATen", "native", "native_functions.yaml"
)
parsed_yaml = parse_native_yaml(native_yaml_path)
native_functions = parsed_yaml.native_functions
grouped_native_functions = get_grouped_native_functions(native_functions)
def get_native_function_name(f):
func = f.func if hasattr(f, "func") else f.functional.func
return str(func.name)
aten_funcs = set(map(get_native_function_name, grouped_native_functions))
with config_path.open() as f:
config = yaml.load(f, yaml.CLoader)
# List of unsupported ops in LTC autogen because of some error
blacklist = config.get("blacklist", [])
# List of supported ops that we don't want to do the full codegen for
# primarily view ops
supported = config.get("supported", [])
if which("rg") is not None: # use ripgrep if available as it's much faster
cmd = ["rg", "-o", "-N", r"aten::[0-9a-zA-Z_\.]+"]
else:
cmd = ["grep", "-o", r"aten::[0-9a-zA-Z_\.]\+"]
output = (
subprocess.check_output(
cmd + [str(torch_ops_file)],
encoding="utf-8",
)
.strip()
.split(os.linesep)
)
# process ops list
ops = []
supported_ops = []
skipped = []
for op in output:
op = op[6:] # strip the "aten::" prefix
opname = op.split(".")[0]
if opname in blacklist or op in blacklist:
continue
if opname in supported:
supported_ops.append(op)
continue
if op not in aten_funcs:
skipped.append(op)
continue
ops.append(op)
opnames = sorted(set(ops))
# Additional ops to support that are not explicitly supported by Torch-MLIR
supported_ops.extend(config.get("additional_ops", []))
with out_file.open("w") as f:
yaml.dump(
{
"backend": "Lazy",
"cpp_namespace": "torch_lazy_tensors",
"full_codegen": opnames,
"supported": sorted(supported_ops),
},
f,
default_flow_style=False,
)
f.write(
dedent(
"""
# Skipped ops (supported by Torch-MLIR but no equivalent native function)
"""
)
+ os.linesep.join(f"# - {op}" for op in sorted(skipped))
)
return parsed_yaml, grouped_native_functions
@dataclass(frozen=True)
class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
lowering_function_type: str = "torch::lazy::MlirFunction"
lowering_context_type: str = "torch::lazy::MlirLoweringContext*"
lowering_return_type: str = "torch::lazy::MlirOpVector"
def lowering_body(self, f):
func = (
f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
)
schema = LazyIrSchema(func)
return f"""
UNIMPLEMENTED_ERROR(
"'{func}' lowering not yet implemented"
);
""".rstrip()
def generate_backend(
source_yaml: Path,
backend_path: Path,
parsed_yaml: dict,
grouped_native_functions: list,
):
print("Running Lazy Tensor Autogen")
# No fallback code allowed
def gen_fallback_code(*args, **kwargs):
return ""
codegen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
codegen.gen_lazy_tensor.run(
backend_name="TorchMlir",
source_yaml=str(source_yaml),
output_dir=str(backend_path),
dry_run=False,
impl_path=str(backend_path.joinpath("aten_ltc_mlir_type.cpp")),
gen_ts_lowerings=False,
node_base="torch::lazy::MlirNode",
node_base_hdr=str(backend_path.joinpath("mlir_node.h")),
tensor_class="torch::lazy::LazyTensor",
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
shape_inference_hdr=str(backend_path.joinpath("LazyShapeInference.h")),
lazy_ir_cls=MlirLazyIr,
)
# Remove lazy_tensor_core imports
subprocess.check_call(
[
"sed",
"-i",
"/lazy_tensor_core/d",
str(backend_path.joinpath("LazyNativeFunctions.cpp")),
]
)
# programmatically check shape inference declarations
import re
sig_re = re.compile(
r"std::vector<Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)"
)
global_signatures = {}
def extract_signatures(path):
signatures = set()
for name, args in sig_re.findall(path.read_text()):
signature = re.sub(r"\s+", "", f"{name}({args})")
global_signatures[signature] = (name, args)
signatures.add(signature)
return signatures
upstream_shape_inference_decls = extract_signatures(
TORCH_DIR.joinpath("torch", "csrc", "lazy", "core", "shape_inference.h")
)
assert len(upstream_shape_inference_decls) > 0
shape_inference_decls = extract_signatures(
backend_path.joinpath("LazyShapeInference.h")
)
assert len(shape_inference_decls) > 0
shape_inference_defs = extract_signatures(
backend_path.joinpath("LazyShapeInference.cpp")
)
assert len(shape_inference_defs) > 0
assert len(shape_inference_decls) > len(shape_inference_defs)
missing_defs = (
shape_inference_decls
- upstream_shape_inference_decls
- shape_inference_defs
)
if missing_defs:
backend_path.joinpath("GenLazyShapeInference.cpp").write_text(
dedent(
"""
// This file contains autogenerated Lazy Shape Inference placeholders
// for ops that don't have a corresponding structured kernel or shape definition
#include "LazyShapeInference.h"
#include "../utils/exception.h"
namespace torch {{
namespace lazy {{
{}
}} // namespace lazy
}} // namespace torch
"""
).format(
"".join(
dedent(
f"""
std::vector<Shape> {name}({args}) {{
UNIMPLEMENTED_FUNCTION_ERROR();
}}
"""
)
for name, args in map(
global_signatures.get, sorted(missing_defs)
)
)
)
)
unnecessary_defs = shape_inference_defs - shape_inference_decls
if unnecessary_defs:
unnecessary_defs = "\n\t".join(
f"{name}({args})"
for name, args in map(global_signatures.get, unnecessary_defs)
)
warnings.warn(
f"Unnecessary shape inference definitions found for:\n\t{unnecessary_defs}"
)
def main(args):
script_path = Path(__file__).resolve()
config_path = (
Path(__file__).resolve().parent.joinpath("autogen_ltc_backend.yaml")
)
torch_ops_file = TORCH_MLIR_DIR.joinpath(
"include",
"torch-mlir",
"Dialect",
"Torch",
"IR",
"GeneratedTorchOps.td",
)
assert torch_ops_file.exists()
native_functions = TORCH_MLIR_DIR.joinpath(
"generated_native_functions.yaml"
)
backend_path = TORCH_MLIR_DIR.joinpath(
"python", "torch_mlir", "csrc", "backend"
)
assert backend_path.is_dir()
prev_hash = None
hash_file = TORCH_MLIR_DIR.joinpath("generated_backend.hash")
if hash_file.exists():
prev_hash = hash_file.read_text().strip()
m = hashlib.sha256()
m.update(script_path.read_bytes())
m.update(config_path.read_bytes())
m.update(torch_ops_file.read_bytes())
if native_functions.exists():
m.update(native_functions.read_bytes())
shape_inference_headers = backend_path.joinpath("LazyShapeInference.h")
if shape_inference_headers.exists():
m.update(shape_inference_headers.read_bytes())
shape_inference_defs = backend_path.joinpath("LazyShapeInference.cpp")
if shape_inference_defs.exists():
m.update(shape_inference_defs.read_bytes())
new_hash = m.hexdigest().strip()
if args.force or new_hash != prev_hash:
hash_file.write_text(new_hash)
parsed_yaml, grouped_native_functions = generate_native_functions(
config_path, torch_ops_file, native_functions
)
generate_backend(
native_functions,
backend_path,
parsed_yaml,
grouped_native_functions,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-f",
"--force",
action="store_true",
)
main(parser.parse_args())
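The script regenerates the backend only when the SHA-256 recorded in generated_backend.hash changes, or when invoked with -f/--force. For reference, the shape-inference consistency check in generate_backend keys declarations and definitions by a whitespace-stripped signature; below is a minimal, self-contained sketch of that normalization (not part of the commit, and the sample declaration is made up for illustration):

import re

sig_re = re.compile(r"std::vector<Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)")
sample = "TORCH_API std::vector<Shape> compute_shape_relu(const at::Tensor & self);"
for name, args in sig_re.findall(sample):
    # Strip all whitespace so declaration and definition spellings compare equal.
    print(re.sub(r"\s+", "", f"{name}({args})"))
# prints: compute_shape_relu(constat::Tensor&self)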


@@ -0,0 +1,52 @@
blacklist:
# List of unsupported ops in LTC autogen because of some error
- arange # Error: Code below assumes there is at least one tensor arg
- contiguous # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- empty_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- full # Error: Code below assumes there is at least one tensor arg
- index.Tensor # Error: TODO not sure if there are other valid types to handle here
- index_put # Error: TODO not sure if there are other valid types to handle here
- index_put_ # Error: TODO not sure if there are other valid types to handle here
- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here
- ones # Error: Code below assumes there is at least one tensor arg
- ones_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- resize_ # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- stack # Error: TODO not sure if there are other valid types to handle here
- to.dtype # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- to.other # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- uniform_ # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- zeros # Error: Code below assumes there is at least one tensor arg
- zeros_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
# Additional ops for which autogen is supported but which don't compile yet
- item
- size
- where
- copy_
- _to_copy
- log_softmax # Not inherently differentiable. Needs to be decomposed.
- linear # Not inherently differentiable. Needs to be decomposed.
# List of supported ops that we don't want to do the full codegen for
# primarily view ops
supported:
# - bernoulli
# - bernoulli_
- cat
- clone
- empty
- expand
- fill_
# - native_batch_norm_backward
- native_batch_norm
- permute
- repeat
- squeeze
- t
- unsqueeze
- view
additional_ops:
# Additional ops to support that are not explicitly supported by Torch-MLIR
- _copy_from
- _copy_from_and_resize


@@ -0,0 +1,4 @@
BasedOnStyle: LLVM
AlignAfterOpenBracket: AlwaysBreak # BlockIndent
PointerAlignment: Left
ReflowComments: false


@@ -25,9 +25,11 @@ add_library(torch_mlir_ltc_backend SHARED
backend/backend_impl.cpp
backend/LazyNativeFunctions.cpp
backend/LazyShapeInference.cpp
backend/GenLazyShapeInference.cpp
backend/mlir_lowering_context.cpp
backend/mlir_node.cpp
backend/RegisterLazy.cpp
tensor_aten_ops.cpp
)
target_link_libraries(torch_mlir_ltc_backend
@@ -40,10 +42,10 @@ target_link_libraries(torch_mlir_ltc_backend
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
set_target_properties(torch_mlir_ltc_backend PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/"
OUTPUT_NAME _MLIR_LTC
PREFIX "${PYTHON_MODULE_PREFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}"
OUTPUT_NAME lib_mlir_ltc
PREFIX ""
SUFFIX ".so"
CXX_VISIBILITY_PRESET "hidden"
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic"
LINK_FLAGS "-rdynamic"
)
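With PREFIX set to "" and SUFFIX to ".so", the target is emitted as a plain shared library (lib_mlir_ltc.so) rather than a Python extension module, so it would presumably be loaded explicitly instead of being imported. A hedged sketch, with the path under ${TORCH_MLIR_PYTHON_PACKAGES_DIR} assumed and depending on the build tree:

import torch

# Placeholder path; substitute the configured TORCH_MLIR_PYTHON_PACKAGES_DIR.
# Loading the library runs its static initializers, which register the LTC ops.
torch.ops.load_library("/path/to/python_packages/torch_mlir/lib_mlir_ltc.so")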


@@ -1,4 +1,4 @@
# Torch-MLIR Lazy Tensor Core Backend
#Torch - MLIR Lazy Tensor Core Backend
Contained within this directory are the components that implement the
Torch-MLIR LTC backend.


@@ -0,0 +1,21 @@
//===- LazyShapeInference.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "LazyShapeInference.h"
#include "../utils/exception.h"
namespace torch {
namespace lazy {
std::vector<Shape> compute_shape_detach(const at::Tensor& self) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
} // namespace lazy
} // namespace torch


@@ -0,0 +1,89 @@
//===- LazyShapeInference.h -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <vector>
namespace torch {
namespace lazy {
// clang-format off
TORCH_API std::vector<Shape> compute_shape___and__(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape__shape_as_tensor(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape__unsafe_view(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_abs(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size);
TORCH_API std::vector<Shape> compute_shape_add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_add_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps, bool cudnn_enabled);
TORCH_API std::vector<Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<Shape> compute_shape_bernoulli_(at::Tensor & self, const at::Tensor & p, c10::optional<at::Generator> generator);
TORCH_API std::vector<Shape> compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional<at::Generator> generator);
TORCH_API std::vector<Shape> compute_shape_bincount(const at::Tensor & self, const c10::optional<at::Tensor> & weights, int64_t minlength);
TORCH_API std::vector<Shape> compute_shape_broadcast_to(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
TORCH_API std::vector<Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_detach(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_div_(at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_dropout(const at::Tensor & input, double p, bool train);
TORCH_API std::vector<Shape> compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse);
TORCH_API std::vector<Shape> compute_shape_expand_as(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_flatten(const at::Tensor & self, int64_t start_dim, int64_t end_dim);
TORCH_API std::vector<Shape> compute_shape_floor_divide(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_fmod(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_full_like(const at::Tensor & self, const at::Scalar & fill_value, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<Shape> compute_shape_hardswish(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_hardtanh(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val);
TORCH_API std::vector<Shape> compute_shape_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index);
TORCH_API std::vector<Shape> compute_shape_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps, bool cudnn_enable);
TORCH_API std::vector<Shape> compute_shape_log_softmax(const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim);
TORCH_API std::vector<Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_masked_select(const at::Tensor & self, const at::Tensor & mask);
TORCH_API std::vector<Shape> compute_shape_matmul(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_max(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode);
TORCH_API std::vector<Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_mul(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_mul_(at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
TORCH_API std::vector<Shape> compute_shape_new_ones(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_new_zeros(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_rand_like(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<Shape> compute_shape_relu(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_relu_(at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_reshape(const at::Tensor & self, at::IntArrayRef shape);
TORCH_API std::vector<Shape> compute_shape_rsub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_select(const at::Tensor & self, int64_t dim, int64_t index);
TORCH_API std::vector<Shape> compute_shape_slice(const at::Tensor & self, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step);
TORCH_API std::vector<Shape> compute_shape_softmax(const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_square(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_std(const at::Tensor & self, bool unbiased);
TORCH_API std::vector<Shape> compute_shape_sub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_sub_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_sum(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1);
TORCH_API std::vector<Shape> compute_shape_type_as(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_var(const at::Tensor & self, bool unbiased);
// clang-format on
} // namespace lazy
} // namespace torch


@@ -0,0 +1,58 @@
//===- aten_eager_fallback.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_eager_fallback.cpp
//===----------------------------------------------------------------------===//
#include <iostream>
#include <unordered_map>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/metrics.h>
#include "../utils/exception.h"
#include "aten_eager_fallback.h"
namespace torch_lazy_tensors {
static std::unordered_map<std::string, ::torch::lazy::Counter*>
_eager_fallback_counters;
bool force_eager_fallback(c10::Symbol op) {
static char* force_str = std::getenv("LTC_FORCE_FALLBACK");
if (force_str != nullptr) {
static auto force_sym = c10::Symbol::fromQualString(std::string(force_str));
if (op == force_sym) {
std::cout << "MLIR force_eager_fallback(" << force_str << "): true"
<< std::endl;
return true;
}
}
std::cout << "MLIR force_eager_fallback(" << op.toQualString() << "): false"
<< std::endl;
return false;
}
void ltc_eager_fallback(
const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto name = c10::toString(op.operator_name());
UNSUPPORTED_ERROR(
"MLIR ltc_eager_fallback is not supported for op: " << name);
}
std::function<void(void)> register_mlir_ltc_eager_fallback;
TORCH_LIBRARY_IMPL(_, Lazy, m) {
register_mlir_ltc_eager_fallback = [&]() {
m.fallback(
torch::CppFunction::makeFromBoxedFunction<&ltc_eager_fallback>());
};
}
} // namespace torch_lazy_tensors
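A hedged usage sketch for the fallback gate above (not part of the commit): force_eager_fallback reads LTC_FORCE_FALLBACK once through std::getenv and caches it in a function-local static, so the variable has to be set before the first lazy op is dispatched, for example in the launching environment or early in the Python process.

import os

# Qualified op name, parsed on the C++ side by c10::Symbol::fromQualString.
# Must be set before the first lazy operation runs, since the value is cached.
os.environ["LTC_FORCE_FALLBACK"] = "aten::add"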


@@ -0,0 +1,27 @@
//===- aten_eager_fallback.h ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// Facilitates eager fallback behaviour
//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_eager_fallback.h
//===----------------------------------------------------------------------===//
#pragma once
#include <ATen/native/CPUFallback.h>
namespace torch_lazy_tensors {
bool force_eager_fallback(c10::Symbol op);
void ltc_eager_fallback(
const c10::OperatorHandle& op, torch::jit::Stack* stack);
extern TORCH_API std::function<void(void)> register_mlir_ltc_eager_fallback;
} // namespace torch_lazy_tensors


@@ -0,0 +1,399 @@
//===- aten_ltc_mlir_type.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_ltc_ts_type.cpp
//===----------------------------------------------------------------------===//
#include <ATen/Operators.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/result_type.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
#include <torch/library.h>
#include "ATen/MetaFunctions.h"
#include <torch/csrc/lazy/core/tensor_impl.h>
#include "../tensor_aten_ops.h"
#include "../utils/exception.h"
#include "../utils/sys_utils.h"
#include "LazyNativeFunctions.h"
#include "LazyShapeInference.h"
namespace torch_lazy_tensors {
namespace {
void CheckSubOperandTypes(at::ScalarType type1, at::ScalarType type2) {
CHECK(type1 != at::kBool || type2 != at::kBool)
<< "Subtraction, the `-` operator, with two bool tensors is not "
"supported. Use the `^` or `logical_xor()` operator instead.";
CHECK(type1 != at::kBool && type2 != at::kBool)
<< "Subtraction, the `-` operator, with a bool tensor is not "
"supported. If you are trying to invert a mask, use the `~` or "
"`logical_not()` operator instead.";
}
std::pair<torch::lazy::LazyTensorPtr, torch::lazy::LazyTensorPtr>
GetBinaryOperands(const at::Tensor& self, const at::Tensor& other) {
torch::lazy::LazyTensorPtr self_tensor;
torch::lazy::LazyTensorPtr other_tensor;
auto self_xtensor = torch::lazy::TryGetLtcTensor(self);
if (!self_xtensor) {
other_tensor = torch::lazy::TryGetLtcTensor(other);
self_tensor = GetOrCreateLtcTensor(self, other_tensor->GetDevice());
} else {
self_tensor = self_xtensor;
other_tensor = GetOrCreateLtcTensor(other, self_tensor->GetDevice());
}
return std::pair<torch::lazy::LazyTensorPtr, torch::lazy::LazyTensorPtr>(
self_tensor, other_tensor);
}
template <typename B>
at::Tensor
DoBinaryOp(const at::Tensor& self, const at::Tensor& other, const B& bin_op) {
at::ScalarType dtype = at::result_type(self, other);
std::pair<torch::lazy::LazyTensorPtr, torch::lazy::LazyTensorPtr> operands =
GetBinaryOperands(
torch::lazy::UnwrapNumber(self, dtype),
torch::lazy::UnwrapNumber(other, dtype));
torch::lazy::LazyTensorPtr result = bin_op(operands.first, operands.second);
return torch::lazy::CreateAtenFromLtcTensor(result);
}
template <typename B>
at::Tensor
DoBinaryOp(const at::Tensor& self, const at::Scalar& other, const B& bin_op) {
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::GetLtcTensor(self);
torch::lazy::LazyTensorPtr result = bin_op(self_tensor, other);
return torch::lazy::CreateAtenFromLtcTensor(result);
}
at::Tensor subtensor(const at::Tensor& tensor, int dim, int groups, int g) {
if (!tensor.defined()) {
return at::Tensor();
}
int64_t n = tensor.sizes()[dim] / groups;
return tensor.narrow(dim, n * g, n).contiguous();
}
at::Tensor CreateLtcTensor(
const at::Tensor& tensor,
const c10::optional<torch::lazy::BackendDevice>& device) {
if (tensor.defined() && device) {
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::LazyTensor::Create(tensor, *device));
}
return tensor;
}
c10::optional<torch::lazy::BackendDevice>
GetLtcDevice(const c10::optional<c10::Device>& device) {
if (!device) {
return c10::nullopt;
}
if (device->type() != at::kLazy) {
return c10::nullopt;
}
return torch::lazy::atenDeviceToBackendDevice(*device);
}
} // namespace
// at::Tensor LazyNativeFunctions::bernoulli(
// const at::Tensor& self, c10::optional<at::Generator> generator) {
// TORCH_LAZY_FN_COUNTER("lazy::");
// if (generator.has_value() && generator->defined()) {
// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli has generator value");
// }
// auto self_tensor = torch::lazy::TryGetLtcTensor(self);
// UNIMPLEMENTED_FUNCTION_ERROR();
// // return torch::lazy::CreateAtenFromLtcTensor(
// // lazy_tensor_aten_ops::bernoulli(self_tensor));
// }
// at::Tensor& LazyNativeFunctions::bernoulli_(
// at::Tensor& self, double p, c10::optional<at::Generator> generator) {
// TORCH_LAZY_FN_COUNTER("lazy::");
// if (generator.has_value() && generator->defined()) {
// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli_ has generator value");
// }
// auto self_tensor = torch::lazy::TryGetLtcTensor(self);
// UNIMPLEMENTED_FUNCTION_ERROR();
// // lazy_tensor_aten_ops::bernoulli_(self_tensor, p);
// // return self;
// }
at::Tensor LazyNativeFunctions::cat(at::TensorList tensors, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto lazy_tensors = torch::lazy::GetLtcTensors(tensors);
std::vector<torch::lazy::Value> values;
values.reserve(lazy_tensors.size());
for (auto& tensor : lazy_tensors) {
values.emplace_back(tensor->GetIrValue());
}
auto shapes = torch::lazy::compute_shape_cat(tensors, dim);
UNIMPLEMENTED_FUNCTION_ERROR();
// auto node =
// torch::lazy::MakeNode<ir::ops::Cat>(values, dim, std::move(shapes));
// auto result = torch::lazy::CreateAtenFromLtcTensor(
// torch::lazy::LazyTensor::Create(torch::lazy::Value(node, 0),
// lazy_tensors[0]->GetDevice()));
// return result;
}
at::Tensor LazyNativeFunctions::clone(
const at::Tensor& self, c10::optional<at::MemoryFormat> memory_format) {
auto self_lt = torch::lazy::TryGetLtcTensor(self);
return torch::lazy::CreateAtenFromLtcTensor(
self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice()));
}
at::Tensor LazyNativeFunctions::_copy_from(
const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
if (!self_tensor) {
// providing a new 'eager' value (self) for an existing lazy tensor (dst)
static bool sync_update =
sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true);
CHECK(dst_tensor);
dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update);
} else if (!dst_tensor) {
// materializing a lazy tensor (self) and copying its value into eager
// tensor (dst)
// detached=false lets us skip a copy in `ToTensor`, which should be safe
// because we are only going to use the tensor for dst.copy_()
CHECK(self_tensor);
at::Tensor tensor = self_tensor->ToTensor(/*detached=*/false);
at::Tensor typed_tensor =
torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
dst.resize_as_(typed_tensor).copy_(typed_tensor);
} else {
// Copying one lazy tensor to another
if (!dst_tensor->CurrentIrValue()) {
// if dest is not backed by IR (e.g. result of some lazy operation),
// then it should have at::Tensor data backing it instead
auto dst_tensor_data = dst_tensor->CurrentTensorData();
CHECK(dst_tensor_data);
auto src_tensor_data = self_tensor->CurrentTensorData();
if (src_tensor_data) {
// both src/dst are simply backed by at::Tensor data, no IR- do a
// straightforward copy
dst_tensor_data->copy_(*src_tensor_data);
} else {
// src needs to be materialized before its result can be used for a copy into dst;
// since we use the src tensor only for making a copy, we don't need to detach it.
// note: it would be even more efficient if we could cause ToTensor to materialize
// the value directly into dst's buffer (that would need to be detached though).
dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false));
}
} else {
lazy_tensor_aten_ops::copy_(dst_tensor, self_tensor);
auto* impl =
dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
impl->set_tensor(dst_tensor);
}
}
return dst;
}
at::Tensor LazyNativeFunctions::_copy_from_and_resize(
const at::Tensor& self, const at::Tensor& dst) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
if (!self_tensor) {
CHECK(dst_tensor);
dst_tensor->UpdateFromTensorOut(self);
} else if (!dst_tensor) {
CHECK(self_tensor);
at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true);
at::Tensor typed_tensor =
torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
dst.resize_as_(typed_tensor).copy_(typed_tensor);
} else {
// at this point we know dst is a lazy tensor
auto* dest_impl =
dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
dest_impl->tensor()->UpdateFromTensorOut(self_tensor);
dest_impl->force_refresh_sizes();
}
return dst;
}
at::Tensor LazyNativeFunctions::empty(
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory,
c10::optional<at::MemoryFormat> memory_format) {
const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType();
at::TensorOptions options = at::TensorOptions()
.device(c10::Device(device_type))
.layout(layout)
.pinned_memory(pin_memory)
.dtype(dtype);
auto x_result = at::empty(size, options, memory_format);
return CreateLtcTensor(x_result, GetLtcDevice(device));
}
at::Tensor LazyNativeFunctions::expand(
const at::Tensor& self, at::IntArrayRef size, bool implicit) {
TORCH_LAZY_FN_COUNTER("lazy::");
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::expand(
// torch::lazy::TryGetLtcTensor(self), size.vec()));
}
at::Tensor&
LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
lazy_tensor_aten_ops::fill_(self_tensor, value);
return self;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
LazyNativeFunctions::native_batch_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var, bool training,
double momentum, double eps) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr input_tensor = torch::lazy::TryGetLtcTensor(input);
const torch::lazy::BackendDevice& device = input_tensor->GetDevice();
torch::lazy::LazyTensorPtr running_mean_tensor =
GetOrCreateLtcTensor(running_mean, device);
torch::lazy::LazyTensorPtr running_var_tensor =
GetOrCreateLtcTensor(running_var, device);
UNIMPLEMENTED_FUNCTION_ERROR();
// auto outputs = lazy_tensor_aten_ops::ts_native_batch_norm(
// torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight,
// device),
// GetOrCreateLtcTensor(bias, device), running_mean_tensor,
// running_var_tensor, training, momentum, eps);
// return
// std::make_tuple(torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)),
// torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)),
// torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs)));
}
// std::tuple<at::Tensor, at::Tensor, at::Tensor>
// LazyNativeFunctions::native_batch_norm_backward(
// const at::Tensor& grad_out, const at::Tensor& input,
// const c10::optional<at::Tensor>& weight,
// const c10::optional<at::Tensor>& running_mean,
// const c10::optional<at::Tensor>& running_var,
// const c10::optional<at::Tensor>& save_mean,
// const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
// std::array<bool, 3> output_mask) {
// TORCH_LAZY_FN_COUNTER("lazy::");
// torch::lazy::LazyTensor grad_out_tensor =
// torch::lazy::TryGetLtcTensor(grad_out);
// const torch::lazy::BackendDevice& device = grad_out_tensor.GetDevice();
// torch::lazy::LazyTensor null_tensor;
// bool running_stats = running_mean && running_mean->defined();
// CHECK_EQ(running_var && running_var->defined(), running_stats);
// UNIMPLEMENTED_FUNCTION_ERROR();
// // auto gradients = lazy_tensor_aten_ops::ts_native_batch_norm_backward(
// // torch::lazy::TryGetLtcTensor(grad_out),
// torch::lazy::TryGetLtcTensor(input),
// // GetOrCreateLtcTensor(weight, device),
// // running_stats ? GetOrCreateLtcTensor(running_mean, device)
// // : null_tensor,
// // running_stats ? GetOrCreateLtcTensor(running_var, device)
// // : null_tensor,
// // GetOrCreateLtcTensor(save_mean, device),
// // GetOrCreateLtcTensor(save_invstd, device), train, eps,
// // output_mask);
// // at::Tensor undefined;
// // return std::make_tuple(
// // output_mask[0] ?
// torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients))
// // : undefined,
// // output_mask[1] ?
// torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients))
// // : undefined,
// // output_mask[2] ?
// torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients))
// // : undefined);
// }
at::Tensor
LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::permute(
// self_tensor, torch::lazy::ToI64Vector(dims)));
}
at::Tensor
LazyNativeFunctions::repeat(const at::Tensor& self, at::IntArrayRef repeats) {
TORCH_LAZY_FN_COUNTER("lazy::");
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::repeat(
// torch::lazy::TryGetLtcTensor(self),
// torch::lazy::ToI64Vector(repeats)));
}
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(
// lazy_tensor_aten_ops::squeeze(torch::lazy::TryGetLtcTensor(self)));
}
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(
// lazy_tensor_aten_ops::squeeze(torch::lazy::TryGetLtcTensor(self),
// dim));
}
at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::transpose(
torch::lazy::TryGetLtcTensor(self), 0, 1));
}
at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(
// lazy_tensor_aten_ops::unsqueeze(torch::lazy::TryGetLtcTensor(self),
// dim));
}
at::Tensor
LazyNativeFunctions::view(const at::Tensor& self, at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
return torch::lazy::CreateAtenFromLtcTensor(
lazy_tensor_aten_ops::view(self_tensor, torch::lazy::ToI64Vector(size)));
}
void InitializeAtenBindings() {}
} // namespace torch_lazy_tensors


@@ -0,0 +1,157 @@
//===- backend_impl.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.cpp
//===----------------------------------------------------------------------===//
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>
#include "../utils/debug.h"
#include "../utils/exception.h"
#include "backend_impl.h"
#include "mlir_lowering_context.h"
namespace torch {
namespace lazy {
MlirBackendData::MlirBackendData(BackendDevice device, Shape shape)
: BackendData(device, shape) {
PRINT_FUNCTION();
auto info = std::make_shared<MlirBackendData::Info>();
SetInfo(info);
}
MlirBackendData::MlirBackendData(const at::Scalar& scalar, BackendDevice device)
: BackendData(device, Shape(scalar.type(), {})) {
PRINT_FUNCTION();
auto info = std::make_shared<MlirBackendData::Info>(scalar);
SetInfo(info);
}
MlirBackendData::MlirBackendData(
const at::Tensor& tensor, BackendDevice device, Shape shape)
: BackendData(device, shape) {
PRINT_FUNCTION();
auto info = std::make_shared<MlirBackendData::Info>(tensor);
SetInfo(info);
}
BackendData::Handle MlirBackendData::GetHandle() {
return reinterpret_cast<int64_t>(this);
}
void MlirBackendData::Assign(const BackendData& data) {
MlirBackendData::Info* info =
dynamic_cast<MlirBackendData::Info*>(data.info());
TORCH_CHECK(
info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info.");
auto new_info = std::make_shared<MlirBackendData::Info>(*info);
SetInfo(new_info);
}
bool MlirBackendData::HasValue() const { return bool(info()); }
/**
* Initialization/Teardown
* */
void MlirBackendImpl::PrepareToExit() const {}
/**
* Data Transfer
* */
BackendDataPtr MlirBackendImpl::MakeComputationDataFromTensor(
const at::Tensor& tensor, const Shape& shape,
const BackendDevice& device) const {
PRINT_FUNCTION();
return std::make_shared<MlirBackendData>(tensor, device, shape);
}
BackendDataPtr MlirBackendImpl::MakeComputationDataFromScalar(
const at::Scalar& scalar, const BackendDevice& device) const {
PRINT_FUNCTION();
return std::make_shared<MlirBackendData>(scalar, device);
}
BackendDataPtr MlirBackendImpl::CreateDataPlaceholder(
const BackendDevice& device, const Shape& shape) const {
PRINT_FUNCTION();
return std::make_shared<MlirBackendData>(device, shape);
}
at::Tensor MlirBackendImpl::MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const {
PRINT_FUNCTION();
MlirBackendData::Info* info =
dynamic_cast<MlirBackendData::Info*>(data->info());
TORCH_CHECK(
info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info.");
return info->tensor;
}
/**
* Lowering, Compilation, Execution
* */
std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
const std::string& name, BackendDevice device,
c10::ArrayRef<Node*> post_order, Util::EmissionMap emit_status) const {
PRINT_FUNCTION();
return std::make_unique<MlirLoweringContext>(
name, std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status));
}
std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
const std::string& name, BackendDevice device) const {
PRINT_FUNCTION();
return std::make_unique<MlirLoweringContext>(
name, std::forward<BackendDevice>(device));
}
/**
* Device Configuration
* */
// Set or get the default device type.
// For backends used with virtual c10:: Devices, this configures what real
// device type the backend should use, and matters if the backend supports
// more than one type of real device.
// Specify which aten device should be used for eager fallback
// may change depending on current 'Default' DeviceType
at::DeviceType MlirBackendImpl::EagerFallbackDeviceType() const {
PRINT_FUNCTION();
return at::DeviceType::CPU;
}
// Query all available backend devices
std::vector<BackendDevice> MlirBackendImpl::GetBackendDevices() const {
PRINT_FUNCTION();
return {
GetBackendDevice(c10::Device(c10::kLazy, 0)),
GetBackendDevice(c10::Device(c10::kCPU, 0))};
}
// Map a particular c10:: device to a concrete backend device
// Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are
// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
// through a mode, in which case these APIs should still work, but should be
// identity mappings.
BackendDevice MlirBackendImpl::GetBackendDevice(c10::Device device) const {
PRINT_FUNCTION();
return BackendDevice(GetDefaultDeviceType(), device.index());
}
} // namespace lazy
} // namespace torch


@@ -23,129 +23,132 @@
namespace torch {
namespace lazy {
class MlirBackendData : public torch::lazy::BackendData {
public:
struct Info;
MlirBackendData(torch::lazy::BackendDevice device, torch::lazy::Shape shape);
MlirBackendData(const at::Scalar& scalar, torch::lazy::BackendDevice device);
MlirBackendData(const at::Tensor& tensor, torch::lazy::BackendDevice device, torch::lazy::Shape shape);
virtual torch::lazy::BackendData::Handle GetHandle() override;
virtual void Assign(const torch::lazy::BackendData& data) override;
virtual bool HasValue() const override;
};
class MlirBackendImpl : public torch::lazy::BackendImplInterface {
class TORCH_API MlirBackendData : public BackendData {
public:
/**
* Initialization/Teardown
* */
virtual void PrepareToExit() const override;
struct Info : public BackendData::Info {
at::Tensor tensor;
c10::optional<at::Scalar> scalar;
bool requires_grad;
/**
* Configuration
* */
// virtual void SetRngSeed(size_t seed) const = 0;
Info() {}
Info(const Info& other)
: tensor{other.tensor}, scalar{other.scalar},
requires_grad{other.requires_grad} {}
Info(const at::Tensor& tensor)
: tensor{tensor}, requires_grad{tensor.requires_grad()} {}
Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {}
};
/**
* Data Transfer
* */
MlirBackendData(BackendDevice device, Shape shape);
MlirBackendData(const at::Scalar& scalar, BackendDevice device);
MlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape);
virtual torch::lazy::BackendDataPtr MakeComputationDataFromTensor(
const at::Tensor& tensor,
const torch::lazy::Shape& shape,
const torch::lazy::BackendDevice& device
) const override;
virtual BackendData::Handle GetHandle() override;
virtual torch::lazy::BackendDataPtr MakeComputationDataFromScalar(
const at::Scalar& scalar,
const torch::lazy::BackendDevice& device
) const override;
virtual torch::lazy::BackendDataPtr CreateDataPlaceholder(
const torch::lazy::BackendDevice& device, const torch::lazy::Shape& shape
) const override;
virtual at::Tensor MakeTensorFromComputationData(
const torch::lazy::BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type
) const override;
/**
* Lowering, Compilation, Execution
* */
virtual std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
const std::string& name,
torch::lazy::BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status
) const override;
virtual std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
const std::string& name, torch::lazy::BackendDevice device
) const override;
// TODO(whc) need to keep this?
// virtual std::vector<std::string> GetCompilationDevices(
// const std::string& device, c10::ArrayRef<std::string> devices
// ) const = 0;
// virtual std::vector<torch::lazy::ComputationPtr> Compile(
// std::vector<torch::lazy::ComputationPtr> instances
// ) const = 0;
// virtual std::vector<torch::lazy::BackendDataPtr> ExecuteComputation(
// torch::lazy::Computation& computation,
// c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
// const torch::lazy::BackendDevice& device
// ) const = 0;
/**
* Device Configuration
* */
// Set or get the default device type.
// For backends used with virtual c10:: Devices, this configures what real
// device type the backend should use, and matters if the backend supports
// more than one type of real device.
// virtual std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType() const = 0;
// virtual void SetDefaultDeviceType(std::string device_type) = 0;
// Specify which aten device should be used for eager fallback
// may change depending on current 'Default' DeviceType
virtual at::DeviceType EagerFallbackDeviceType() const override;
// Query all available backend devices
virtual std::vector<torch::lazy::BackendDevice> GetBackendDevices() const override;
// Map a particular c10:: device to a concrete backend device
// Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are
// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
// through a mode, in which case these APIs should still work, but should be
// identity mappings.
virtual torch::lazy::BackendDevice GetBackendDevice(c10::Device device) const override;
/**
* Debug/Metrics
* */
// virtual std::map<std::string, Metric> GetMetrics() const = 0;
// virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0;
// virtual std::string GetComputationBackendText(
// const torch::lazy::ComputationPtr computation
// ) const = 0;
virtual void Assign(const BackendData& data) override;
virtual bool HasValue() const override;
};
} // lazy
} // torch
class TORCH_API MlirBackendImpl : public BackendImplInterface {
public:
virtual ~MlirBackendImpl() = default;
/**
* Initialization/Teardown
* */
virtual void PrepareToExit() const override;
/**
* Configuration
* */
// virtual void SetRngSeed(size_t seed) const = 0;
/**
* Data Transfer
* */
virtual BackendDataPtr MakeComputationDataFromTensor(
const at::Tensor& tensor, const Shape& shape,
const BackendDevice& device) const override;
virtual BackendDataPtr MakeComputationDataFromScalar(
const at::Scalar& scalar, const BackendDevice& device) const override;
virtual BackendDataPtr CreateDataPlaceholder(
const BackendDevice& device, const Shape& shape) const override;
virtual at::Tensor MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const override;
/**
* Lowering, Compilation, Execution
* */
virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
const std::string& name, BackendDevice device,
c10::ArrayRef<Node*> post_order,
Util::EmissionMap emit_status) const override;
virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
const std::string& name, BackendDevice device) const override;
// TODO(whc) need to keep this?
// virtual std::vector<std::string> GetCompilationDevices(
// const std::string& device, c10::ArrayRef<std::string> devices
// ) const = 0;
// virtual std::vector<ComputationPtr> Compile(
// std::vector<ComputationPtr> instances
// ) const = 0;
// virtual std::vector<BackendDataPtr> ExecuteComputation(
// Computation& computation,
// c10::ArrayRef<BackendDataPtr> arguments,
// const BackendDevice& device
// ) const = 0;
/**
* Device Configuration
* */
// Set or get the default device type.
// For backends used with virtual c10:: Devices, this configures what real
// device type the backend should use, and matters if the backend supports
// more than one type of real device.
// virtual std::shared_ptr<BackendDeviceType> GetDefaultDeviceType() const =
// 0;
// virtual void SetDefaultDeviceType(std::string device_type) = 0;
// Specify which aten device should be used for eager fallback
// may change depending on current 'Default' DeviceType
virtual at::DeviceType EagerFallbackDeviceType() const override;
// Query all available backend devices
virtual std::vector<BackendDevice> GetBackendDevices() const override;
// Map a particular c10:: device to a concrete backend device
// Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are
// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
// through a mode, in which case these APIs should still work, but should be
// identity mappings.
virtual BackendDevice GetBackendDevice(c10::Device device) const override;
/**
* Debug/Metrics
* */
// virtual std::map<std::string, Metric> GetMetrics() const = 0;
// virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0;
// virtual std::string GetComputationBackendText(
// const ComputationPtr computation
// ) const = 0;
};
} // namespace lazy
} // namespace torch


@@ -0,0 +1,93 @@
//===- mlir_lowering_context.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/ts_lowering_context.cpp
//===----------------------------------------------------------------------===//
#include <iostream>
#include "../utils/exception.h"
#include "mlir_lowering_context.h"
namespace torch {
namespace lazy {
MlirLoweringContext::MlirLoweringContext(
const std::string& name, BackendDevice device)
: LoweringContext(name, std::forward<BackendDevice>(device)) {}
MlirLoweringContext::MlirLoweringContext(
const std::string& name, BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
: LoweringContext(
name, std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)) {}
int MlirComputation::parameters_size() const { UNIMPLEMENTED_FUNCTION_ERROR(); }
const std::vector<torch::lazy::Shape>&
MlirComputation::parameter_shapes() const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
const std::vector<std::string>& MlirComputation::parameter_names() const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
const torch::lazy::Shape& MlirComputation::result_shape() const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
// Get the shape of the result tuple component, given by index.
torch::lazy::Shape MlirLoweringContext::GetResultShape(size_t index) const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
// Adds the given output as a component of the result tuple and returns its
// assigned position within the tuple.
size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) {
const torch::lazy::Node* node;
auto it = emitted_outputs_.find(output);
if (it == emitted_outputs_.end()) {
node = output.node;
auto post_order = Util::ComputePostOrder(node, &emit_status_);
for (auto po_node : post_order) {
// TODO: uncomment after lowering is implemented
// bool ok = lowering_->Lower(node);
// TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
}
emitted_outputs_[output] = node;
} else {
node = it->second;
}
result_tuple_.emplace_back(node);
return result_tuple_.size() - 1;
}
// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
void MlirLoweringContext::AddParameter(
const torch::lazy::Output& output, size_t index,
const torch::lazy::Shape& shape, const std::string& name) {
UNIMPLEMENTED_FUNCTION_ERROR();
}
// Build the computation capturing all the operations created with the
// embedded builder (returned by the builder() API).
ComputationPtr MlirLoweringContext::Build() {
for (const torch::lazy::Node* output : result_tuple_) {
}
return std::make_shared<MlirComputation>();
}
} // namespace lazy
} // namespace torch


@@ -10,60 +10,58 @@
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_lowering_context.h
//===----------------------------------------------------------------------===//
#pragma once
#include <vector>
#include <torch/csrc/lazy/backend/lowering_context.h>
namespace torch {
namespace lazy {
class MlirComputation : public torch::lazy::Computation {
public:
int parameters_size() const override;
class TORCH_API MlirComputation : public torch::lazy::Computation {
public:
int parameters_size() const override;
virtual const std::vector<torch::lazy::Shape>& parameter_shapes() const override;
virtual const std::vector<torch::lazy::Shape>&
parameter_shapes() const override;
virtual const std::vector<std::string>& parameter_names() const override;
virtual const std::vector<std::string>& parameter_names() const override;
virtual const torch::lazy::Shape& result_shape() const override;
virtual const torch::lazy::Shape& result_shape() const override;
};
class MlirLoweringContext : public torch::lazy::LoweringContext {
public:
class TORCH_API MlirLoweringContext : public torch::lazy::LoweringContext {
public:
MlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device);
MlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status);
MlirLoweringContext(const std::string& name, torch::lazy::BackendDevice device);
MlirLoweringContext(const std::string& name,
torch::lazy::BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status);
// Get the shape of the result tuple component, given by index.
virtual torch::lazy::Shape GetResultShape(size_t index) const override;
// Get the shape of the result tuple component, given by index.
virtual torch::lazy::Shape GetResultShape(size_t index) const override;
// Adds the given output as a component of the result tuple and returns its
// assigned position within the tuple.
virtual size_t AddResult(const torch::lazy::Output& output) override;
// Adds the given output as a component of the result tuple and returns its
// assigned position within the tuple.
virtual size_t AddResult(const torch::lazy::Output& output) override;
// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
virtual void AddParameter(const torch::lazy::Output& output,
size_t index,
const torch::lazy::Shape& shape,
const std::string& name) override;
// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
virtual void AddParameter(
const torch::lazy::Output& output, size_t index,
const torch::lazy::Shape& shape, const std::string& name) override;
// Build the computation capturing all the operations created with the
// embedded builder (returned by the builder() API).
virtual torch::lazy::ComputationPtr Build() override;
// Build the computation capturing all the operations created with the
// embedded builder (returned by the builder() API).
virtual torch::lazy::ComputationPtr Build() override;
private:
std::vector<const torch::lazy::Node*> result_tuple_;
torch::lazy::OutputMap<const torch::lazy::Node*> emitted_outputs_;
private:
std::vector<const torch::lazy::Node*> result_tuple_;
torch::lazy::OutputMap<const torch::lazy::Node*> emitted_outputs_;
};
} // namespace lazy
} // namespace torch
} // namespace lazy
} // namespace torch

View File

@ -0,0 +1,130 @@
//===- mlir_node.cpp ------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_node.cpp
//===----------------------------------------------------------------------===//
#include <torch/csrc/lazy/core/cache.h>
#include "../utils/exception.h"
#include "mlir_node.h"
namespace torch {
namespace lazy {
namespace {
hash_t OperandHashes(
const OpList& operands, const hash_t& seed, const bool bakeInSizes) {
hash_t hash = seed;
for (auto& operand : operands) {
if (!operand) {
hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
continue;
}
auto operand_hash =
bakeInSizes ? operand.hash_with_sizes() : operand.hash_without_sizes();
hash = HashCombine(hash, operand_hash);
}
return hash;
}
hash_t GetOpHash(
OpKind op, const Shape& shape, hash_t hash_seed, const bool bakeInSizes) {
hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes));
return HashCombine(h, hash_seed);
}
} // namespace
MlirNode::MlirNode(
OpKind op, OpList operands, std::vector<Shape>&& shapes, size_t num_outputs,
hash_t hash_seed)
: Node(
op, num_outputs,
/* node_hash */ HashCombine(op.hash(), hash_seed),
/* dag_hash */
[&](bool bakeInSizes) -> hash_t {
return OperandHashes(
operands, HashCombine(op.hash(), hash_seed), bakeInSizes);
}),
shapes_(std::move(shapes)) {
for (auto& operand : operands) {
// Ideally, optional operands should be filtered by the leaf node classes,
// but it's just much easier to do it here.
if (!operand) {
continue;
}
AddOperand(operand.node, operand.index);
}
}
MlirNode::MlirNode(
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
size_t num_outputs, hash_t hash_seed)
: MlirNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
shapes_.push_back(GetOpShape(shape_fn));
}
MlirNode::MlirNode(
OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
: MlirNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
void MlirNode::SetShapeDeferred(const std::function<Shape()>& shape_fn) {
shapes_.push_back(GetOpShape(shape_fn));
}
MlirNode::MlirNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
: Node(op, num_outputs, [&](bool bakeInSizes) -> hash_t {
return GetOpHash(op, shape, hash_seed, bakeInSizes);
}) {
shapes_.push_back(std::move(shape));
}
using ShapeCache = Cache<hash_t, Shape, HashReducer>;
constexpr int torch_lazy_shape_cache_size = 4096;
ShapeCache* GetShapeCache() {
static ShapeCache* cache = new ShapeCache(torch_lazy_shape_cache_size);
return cache;
}
Shape MlirNode::GetOpShape(const std::function<Shape()>& shape_fn) const {
ShapeCache* shape_cache = GetShapeCache();
auto shape = shape_cache->Get(hash());
if (shape == nullptr) {
shape = shape_cache->Add(hash(), std::make_shared<Shape>(shape_fn()));
}
return *shape;
}
c10::ArrayRef<Shape> MlirNode::shapes() const { return shapes_; }
const Shape& MlirNode::shape(size_t output_index) const {
return shapes_.at(output_index);
}
const std::vector<Output>& MlirNode::operands() const {
return operands_as_outputs_;
}
const Output& MlirNode::operand(size_t i) const {
return operands_as_outputs_.at(i);
}
void MlirNode::AddOperand(NodePtr node, size_t index) {
CHECK_LT(index, node->num_outputs());
operands_.push_back(std::move(node));
operands_as_outputs_.emplace_back(operands_.back().get(), index);
}
} // namespace lazy
} // namespace torch
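
As a reading aid, a sketch of how a hand-written op could derive from MlirNode and use the shape_fn constructor, and therefore the shape cache keyed by the node hash; the ExampleRelu class and the aten::relu op kind are illustrative assumptions, and real ops are expected to come from the codegen.

// Illustrative sketch only -- not part of this diff.
#include "mlir_node.h"

namespace torch {
namespace lazy {

class ExampleRelu : public MlirNode {
 public:
  explicit ExampleRelu(const Value& input)
      : MlirNode(
            OpKind(c10::Symbol::fromQualString("aten::relu")),
            /*operands=*/{input},
            // shape_fn runs only on a shape-cache miss for this node's hash;
            // relu preserves the input shape.
            /*shape_fn=*/[shape = input.shape()]() { return shape; },
            /*num_outputs=*/1) {}

  MlirOpVector Lower(
      MlirFunction function, MlirLoweringContext* loctx) const override {
    // Lowering to MLIR is not implemented at this stage of the backend.
    UNIMPLEMENTED_FUNCTION_ERROR();
  }
};

} // namespace lazy
} // namespace torch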

View File

@ -14,12 +14,12 @@
#include <ATen/core/interned_strings.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/shape.h>
#include "../utils/exception.h"
#include "aten_eager_fallback.h"
#include "mlir_lowering_context.h"
#include "../utils/exception.h"
namespace torch {
namespace lazy {
@ -27,58 +27,57 @@ namespace lazy {
typedef std::vector<NodePtr> MlirOpVector;
typedef NodePtr MlirFunction;
class TORCH_API MlirNode : public torch::lazy::Node {
class MlirNode : public torch::lazy::Node {
public:
MlirNode(
OpKind op, OpList operands, std::vector<Shape>&& shapes,
size_t num_outputs = 1, hash_t hash_seed = kHashSeed);
public:
MlirNode(
OpKind op, OpList operands, std::vector<Shape>&& shapes,
size_t num_outputs = 1, hash_t hash_seed = kHashSeed
);
// Same as the constructor above, but the shape is generated by a function,
// only if needed (shape cache miss).
MlirNode(
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
size_t num_outputs = 1, hash_t hash_seed = kHashSeed);
// Same as the constructor above, but the shape is generated by a function,
// only if needed (shape cache miss).
MlirNode(
OpKind op, OpList operands,
const std::function<Shape()>& shape_fn,
size_t num_outputs = 1, hash_t hash_seed = kHashSeed
);
// The shape is set later.
MlirNode(
OpKind op, OpList operands, size_t num_outputs = 1,
hash_t hash_seed = kHashSeed);
// The shape is set later.
MlirNode(
OpKind op, OpList operands, size_t num_outputs = 1,
hash_t hash_seed = kHashSeed
);
void SetShapeDeferred(const std::function<Shape()>& shape_fn);
void SetShapeDeferred(const std::function<Shape()>& shape_fn);
// Constructor used to create leaf nodes.
MlirNode(
OpKind op, Shape shape, size_t num_outputs = 1,
hash_t hash_seed = kHashSeed);
// Constructor used to create leaf nodes.
MlirNode(
OpKind op, Shape shape, size_t num_outputs = 1, hash_t hash_seed = kHashSeed
);
Shape GetOpShape(const std::function<Shape()>& shape_fn) const;
Shape GetOpShape(const std::function<Shape()>& shape_fn) const;
// Retrieves the full shape of the IR Node.
c10::ArrayRef<Shape> shapes() const override;
const std::vector<Output>& operands() const override;
// Retrieves the shape of the output at a given index.
const Shape& shape(size_t output_index = 0) const override;
const Output& operand(size_t i) const override;
const std::vector<Output>& operands() const override;
virtual MlirOpVector Lower(
MlirFunction function,
MlirLoweringContext* loctx
) const = 0;
const Output& operand(size_t i) const override;
private:
// Adds node's index output number as operand.
void AddOperand(NodePtr node, size_t index = 0);
virtual MlirOpVector
Lower(MlirFunction function, MlirLoweringContext* loctx) const = 0;
std::vector<Shape> shapes_;
// A node holds a real reference to its operands.
std::vector<NodePtr> operands_;
// Outputs do not hold references on the nodes, and neither do the uses, since
// otherwise we get into circular reference counting.
std::vector<Output> operands_as_outputs_;
private:
// Adds node's index output number as operand.
void AddOperand(NodePtr node, size_t index = 0);
std::vector<Shape> shapes_;
// A node holds a real reference to its operands.
std::vector<NodePtr> operands_;
// Outputs do not hold references on the nodes, and neither do the uses, since
// otherwise we get into circular reference counting.
std::vector<Output> operands_as_outputs_;
};
} // namespace lazy
} // namespace torch
} // namespace lazy
} // namespace torch

View File

@ -0,0 +1,242 @@
//===- tensor_aten_ops.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/tensor_aten_ops.cpp
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <functional>
#include <ATen/InferSize.h>
#include <c10/util/Optional.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/util.h>
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
#include <torch/csrc/lazy/core/view_ops/permute.h>
#include <torch/csrc/lazy/core/view_ops/view.h>
#include <torch/csrc/lazy/ts_backend/ops/cast.h>
#include <torch/csrc/lazy/ts_backend/ops/expand.h>
#include "tensor_aten_ops.h"
namespace torch_lazy_tensors {
namespace lazy_tensor_aten_ops {
namespace {
// To enable operators +, -, *, and / on torch::lazy::Value.
using namespace torch::lazy;
torch::lazy::Value MaybeExpand(
const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) {
if (input.shape().sizes() == target_shape.sizes()) {
return input;
}
return torch::lazy::MakeNode<torch::lazy::Expand>(
input, target_shape.sizes().vec(),
/*is_scalar_expand=*/false);
}
std::vector<int64_t> GetExpandDimensions(
const torch::lazy::Shape& shape, std::vector<int64_t> dimensions) {
CHECK_GE(dimensions.size(), shape.dim()) << shape;
int64_t base = dimensions.size() - shape.dim();
for (size_t i = 0; i < shape.dim(); ++i) {
if (dimensions[base + i] == -1) {
dimensions[base + i] = shape.size(i);
}
}
return dimensions;
}
// Returns a 1-D shape for batch norm weight or bias based on the input shape.
torch::lazy::Shape
BatchNormFeaturesShape(const torch::lazy::LazyTensorPtr& input) {
CHECK(input);
auto input_shape = input->shape().Get();
return torch::lazy::Shape(input_shape.scalar_type(), input_shape.sizes()[1]);
}
// Returns the IR value for the given input or, if the input is undefined, the
// provided default value broadcast to the default shape.
torch::lazy::Value GetIrValueOrDefault(
const torch::lazy::LazyTensorPtr& input, const at::Scalar& default_value,
const torch::lazy::Shape& default_shape,
const torch::lazy::BackendDevice& device) {
return input ? input->GetIrValue()
: torch::lazy::LazyGraphExecutor::Get()
->GetIrValueForExpandedScalar(
default_value, default_shape, device);
}
torch::lazy::ViewInfo CreateAsStridedViewInfo(
const torch::lazy::Shape& input_shape, std::vector<int64_t> size,
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
torch::lazy::Shape result_shape =
torch::lazy::Shape(input_shape.scalar_type(), size);
torch::lazy::AsStridedInfo as_strided_info;
as_strided_info.stride = std::move(stride);
if (storage_offset) {
as_strided_info.offset = *storage_offset;
}
return torch::lazy::ViewInfo(
torch::lazy::ViewInfo::Type::kAsStrided, std::move(result_shape),
input_shape, std::move(as_strided_info));
}
} // namespace
//////////////////////////////////////////////////////////////////////////////
// ATen operators follow here, listed in alphabetical order.
//////////////////////////////////////////////////////////////////////////////
torch::lazy::LazyTensorPtr as_strided(
const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
auto input_shape = input->shape();
return input->CreateViewTensor(CreateAsStridedViewInfo(
input_shape, std::move(size), std::move(stride), storage_offset));
}
void as_strided_(
torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
if (input->data()->view == nullptr) {
input->SetIrValue(torch::lazy::MakeNode<torch::lazy::AsStrided>(
input->GetIrValue(), std::move(size), std::move(stride),
storage_offset.value_or(0)));
} else {
auto input_shape = input->shape();
input->SetSubView(CreateAsStridedViewInfo(
input_shape, std::move(size), std::move(stride), storage_offset));
}
}
torch::lazy::LazyTensorPtr
expand(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size) {
auto input_shape = input->shape();
return torch::lazy::LazyTensor::Create(
torch::lazy::MakeNode<torch::lazy::Expand>(
input->GetIrValue(),
GetExpandDimensions(input_shape.Get(), std::move(size)),
/*is_scalar_expand=*/false),
input->GetDevice());
}
void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value) {
torch::lazy::Value constant =
torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(
value, input->shape(), input->GetDevice());
input->SetInPlaceIrValue(std::move(constant));
}
torch::lazy::LazyTensorPtr narrow(
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
int64_t length) {
auto input_shape = input->shape();
dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
torch::lazy::Shape narrow_shape = input_shape;
narrow_shape.set_size(dim, length);
torch::lazy::ViewInfo::Type view_type =
(input_shape.Get().numel() == narrow_shape.numel())
? torch::lazy::ViewInfo::Type::kReshape
: torch::lazy::ViewInfo::Type::kNarrow;
torch::lazy::ViewInfo view_info(
view_type, std::move(narrow_shape), input_shape);
view_info.indices[dim] =
torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start);
return input->CreateViewTensor(std::move(view_info));
}
torch::lazy::LazyTensorPtr
permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef<int64_t> dims) {
auto input_shape = input->shape();
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kPermute, input_shape,
torch::lazy::GetCanonicalDimensionIndices(dims, input_shape.Get().dim()));
return input->CreateViewTensor(std::move(view_info));
}
void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
if (input->GetDevice() == src->GetDevice()) {
torch::lazy::Value copy_value;
if (input->dtype() == src->dtype()) {
copy_value = src->GetIrValue();
} else {
copy_value = torch::lazy::MakeNode<torch::lazy::Cast>(
src->GetIrValue(), input->dtype(), src->dtype());
}
input->SetIrValue(MaybeExpand(copy_value, input->shape()));
} else {
auto input_shape = input->shape();
at::Tensor src_tensor = src->ToTensor(/*detached=*/true);
if (src_tensor.sizes() != input_shape.Get().sizes()) {
src_tensor = src_tensor.expand(input_shape.Get().sizes().vec());
}
input->UpdateFromTensor(std::move(src_tensor), /*sync=*/false);
}
}
torch::lazy::LazyTensorPtr slice(
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
int64_t end, int64_t step) {
auto input_shape = input->shape();
dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
start =
torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start);
end = torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, end);
// PyTorch allows tensor[-1:0], which yields an empty (zero-length) slice.
if (start > end) {
end = start;
}
step = std::min(step, end - start);
torch::lazy::SelectInfo select = {dim, start, end, step};
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kSelect, input_shape, std::move(select));
return input->CreateViewTensor(std::move(view_info));
}
torch::lazy::LazyTensorPtr
transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) {
auto input_shape = input->shape();
auto permute_dims = torch::lazy::MakeTransposePermutation(
/*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim());
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims);
return input->CreateViewTensor(std::move(view_info));
}
void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) {
auto input_shape = input->shape();
auto permute_dims = torch::lazy::MakeTransposePermutation(
/*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim());
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims);
return input->ModifyCurrentView(std::move(view_info));
}
torch::lazy::LazyTensorPtr view(
const torch::lazy::LazyTensorPtr& input,
c10::ArrayRef<int64_t> output_size) {
auto input_shape = input->shape().Get();
torch::lazy::Shape shape = torch::lazy::Shape(
input_shape.scalar_type(),
at::infer_size(output_size, input_shape.numel()));
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape);
return input->CreateViewTensor(std::move(view_info));
}
} // namespace lazy_tensor_aten_ops
} // namespace torch_lazy_tensors
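
A short, assumed usage sketch of the view helpers above; FirstRowTransposed is a hypothetical free function, and input is presumed to be a valid LazyTensorPtr obtained from the LTC machinery. Both calls only record view metadata in the lazy IR; no data moves until the graph executes.

// Illustrative sketch only -- not part of this diff.
#include "tensor_aten_ops.h"

namespace torch_lazy_tensors {

// Returns the first row of a 2-D lazy tensor, transposed into a column view.
torch::lazy::LazyTensorPtr FirstRowTransposed(
    const torch::lazy::LazyTensorPtr& input) {
  // Narrow dimension 0 down to a single row...
  auto row = lazy_tensor_aten_ops::narrow(
      input, /*dim=*/0, /*start=*/0, /*length=*/1);
  // ...then swap the two dimensions; both are pure view operations.
  return lazy_tensor_aten_ops::transpose(row, /*dim0=*/0, /*dim1=*/1);
}

} // namespace torch_lazy_tensors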

View File

@ -0,0 +1,79 @@
//===- tensor_aten_ops.h --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/tensor_aten_ops.h
//===----------------------------------------------------------------------===//
#pragma once
#include <torch/csrc/lazy/core/tensor.h>
namespace torch_lazy_tensors {
namespace lazy_tensor_aten_ops {
//////////////////////////////////////////////////////////////////////////////
// ATen operators follow here, listed in alphabetical order.
//////////////////////////////////////////////////////////////////////////////
// Takes a slice from the input, viewed as a rank-1 (R1) tensor, at the
// specified storage offset and reshapes it into the provided size.
torch::lazy::LazyTensorPtr as_strided(
const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset);
// In-place version of the method above.
void as_strided_(
torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset);
torch::lazy::LazyTensorPtr
expand(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size);
// Fills the input with the given value.
void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value);
// Returns a new tensor that is a narrowed view of the input in the given
// dimension.
torch::lazy::LazyTensorPtr narrow(
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
int64_t length);
// Permutes the dimensions of the input tensor according to the given permutation.
torch::lazy::LazyTensorPtr
permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef<int64_t> dims);
// Repeats the input tensor along each dimension by the given number of
// repeats.
torch::lazy::LazyTensorPtr
repeat(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> repeats);
void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src);
torch::lazy::LazyTensorPtr slice(
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
int64_t end, int64_t step);
std::tuple<
torch::lazy::LazyTensorPtr, torch::lazy::LazyTensorPtr,
torch::lazy::LazyTensorPtr>
svd(const torch::lazy::LazyTensorPtr& input, bool some, bool compute_uv);
// Swaps the given dimensions of the input.
torch::lazy::LazyTensorPtr
transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1);
// In-place version of the method above.
void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1);
// Like reshape, but it returns a view into the original tensor.
torch::lazy::LazyTensorPtr view(
const torch::lazy::LazyTensorPtr& input,
c10::ArrayRef<int64_t> output_size);
} // namespace lazy_tensor_aten_ops
} // namespace torch_lazy_tensors

View File

@ -0,0 +1,23 @@
//===- debug.h ------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <iostream>
#include "sys_utils.h"
static const bool verbose_print_function =
sys_util::GetEnvBool("VERBOSE_PRINT_FUNCTION", false);
#define PRINT_FUNCTION() \
if (verbose_print_function) { \
std::cout << __PRETTY_FUNCTION__ << " (" << __FILE__ << ":" << __LINE__ \
<< ")" << std::endl; \
}
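
A small assumed example of the tracing macro in use: with VERBOSE_PRINT_FUNCTION=true (or any non-zero value) set in the environment, each annotated entry point prints its signature, file, and line; the function below is hypothetical.

// Illustrative sketch only -- not part of this diff.
#include "debug.h"

int ExampleBackendHook(int x) {
  PRINT_FUNCTION(); // prints __PRETTY_FUNCTION__ (file:line) when enabled
  return x * 2;
}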

View File

@ -13,17 +13,20 @@
#include <sstream>
#include <string>
#define UNIMPLEMENTED_ERROR(msg) \
{ \
std::ostringstream err; \
err << "Unimplemented Error: " << msg; \
throw std::runtime_error(err.str()); \
}
#define UNIMPLEMENTED_ERROR(msg) \
{ \
std::ostringstream err; \
err << "Unimplemented Error: " << msg; \
throw std::runtime_error(err.str()); \
}
#define UNIMPLEMENTED_FUNCTION_ERROR() \
UNIMPLEMENTED_ERROR( \
"\n\t" << __FILE__ << ":" << __LINE__ << " " << __PRETTY_FUNCTION__)
#define UNSUPPORTED_ERROR(msg) \
{ \
std::ostringstream err; \
err << "Unsupported Error: " << msg; \
throw std::runtime_error(err.str()); \
}
#define UNSUPPORTED_ERROR(msg) \
{ \
std::ostringstream err; \
err << "Unsupported Error: " << msg; \
throw std::runtime_error(err.str()); \
}
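
A brief sketch of how these macros are intended to be used in backend stubs, mirroring the unimplemented methods earlier in this commit; LowerCustomOp is a hypothetical example.

// Illustrative sketch only -- not part of this diff.
#include "exception.h"

void LowerCustomOp(bool supported) {
  if (!supported) {
    UNSUPPORTED_ERROR("custom-op lowering is not supported yet");
  }
  // Placeholder until the real lowering lands; the macro throws a
  // std::runtime_error carrying file, line, and function information.
  UNIMPLEMENTED_FUNCTION_ERROR();
}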

View File

@ -0,0 +1,22 @@
#pragma once
#include <cstdlib>
#include <cstring>
namespace sys_util {
static bool GetEnvBool(const char* name, bool defval) {
const char* env = std::getenv(name);
if (env == nullptr) {
return defval;
}
if (std::strcmp(env, "true") == 0) {
return true;
}
if (std::strcmp(env, "false") == 0) {
return false;
}
return std::atoi(env) != 0;
}
} // namespace sys_util
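
Finally, an assumed usage sketch for GetEnvBool; the LTC_DUMP_IR variable name is made up for illustration. The helper accepts the literals "true"/"false" and otherwise falls back to atoi, so "1" enables the flag while "0" or an unset variable yields the default.

// Illustrative sketch only -- not part of this diff.
#include <iostream>
#include "sys_utils.h"

int main() {
  // Hypothetical flag name; unset or "0"/"false" keeps the default (false).
  const bool dump_ir = sys_util::GetEnvBool("LTC_DUMP_IR", /*defval=*/false);
  std::cout << "dump_ir=" << std::boolalpha << dump_ir << std::endl;
  return 0;
}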