mirror of https://github.com/llvm/torch-mlir
Got LTC working until compile (#689)
parent 58338f79a1
commit c3b20e444c
@@ -29,7 +29,7 @@ bazel-*
 /python/torch_mlir/csrc/backend/LazyLazyIr.h
 /python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp
 /python/torch_mlir/csrc/backend/LazyNativeFunctions.h
-/python/torch_mlir/csrc/backend/LazyShapeInference.cpp
+/python/torch_mlir/csrc/backend/GenLazyShapeInference.cpp
 /python/torch_mlir/csrc/backend/RegisterLazy.cpp

 # Libraries
@@ -0,0 +1,321 @@
import argparse
import hashlib
import os
import subprocess
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path
from shutil import which
from textwrap import dedent

import yaml

TORCH_MLIR_DIR = Path(__file__).parent.parent.resolve()
TORCH_DIR = TORCH_MLIR_DIR.parent.joinpath("pytorch")

sys.path.append(str(TORCH_DIR.joinpath("tools")))

# PyTorch's LTC backend autogen script
import codegen.dest.lazy_ir
import codegen.gen_lazy_tensor
from codegen.api.lazy import LazyIrSchema
from codegen.gen import get_grouped_native_functions, parse_native_yaml
from codegen.model import NativeFunctionsGroup


def generate_native_functions(
    config_path: Path, torch_ops_file: Path, out_file: Path
):
    print("Generating Native Functions Yaml")

    native_yaml_path = TORCH_DIR.joinpath(
        "aten", "src", "ATen", "native", "native_functions.yaml"
    )

    parsed_yaml = parse_native_yaml(native_yaml_path)
    native_functions = parsed_yaml.native_functions
    grouped_native_functions = get_grouped_native_functions(native_functions)

    def get_native_function_name(f):
        func = f.func if hasattr(f, "func") else f.functional.func
        return str(func.name)

    aten_funcs = set(map(get_native_function_name, grouped_native_functions))

    with config_path.open() as f:
        config = yaml.load(f, yaml.CLoader)

    # List of unsupported ops in LTC autogen because of some error
    blacklist = config.get("blacklist", [])

    # List of supported ops that we don't want to do the full codegen for
    # primarily view ops
    supported = config.get("supported", [])

    if which("rg") is not None:  # use ripgrep if available as its much faster
        cmd = ["rg", "-o", "-N", r"aten::[0-9a-zA-Z_\.]+"]
    else:
        cmd = ["grep", "-o", r"aten::[0-9a-zA-Z_\.]\+"]

    output = (
        subprocess.check_output(
            cmd + [str(torch_ops_file)],
            encoding="utf-8",
        )
        .strip()
        .split(os.linesep)
    )

    # process ops list
    ops = []
    supported_ops = []
    skipped = []

    for op in output:
        op = op[6:]
        opname = op.split(".")[0]

        if opname in blacklist or op in blacklist:
            continue

        if opname in supported:
            supported_ops.append(op)
            continue

        if op not in aten_funcs:
            skipped.append(op)
            continue

        ops.append(op)

    opnames = sorted(set(ops))

    # Additional ops to support that are not supported by Torch-MLIR explicitly
    supported_ops.extend(config.get("additional_ops", []))

    with out_file.open("w") as f:
        yaml.dump(
            {
                "backend": "Lazy",
                "cpp_namespace": "torch_lazy_tensors",
                "full_codegen": opnames,
                "supported": sorted(supported_ops),
            },
            f,
            default_flow_style=False,
        )
        f.write(
            dedent(
                """

                # Skipped ops (supported by Torch-MLIR but no equivalent native function)
                """
            )
            + os.linesep.join(f"# - {op}" for op in sorted(skipped))
        )

    return parsed_yaml, grouped_native_functions


@dataclass(frozen=True)
class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
    lowering_function_type: str = "torch::lazy::MlirFunction"
    lowering_context_type: str = "torch::lazy::MlirLoweringContext*"
    lowering_return_type: str = "torch::lazy::MlirOpVector"

    def lowering_body(self, f):
        func = (
            f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        )
        schema = LazyIrSchema(func)

        return f"""
        UNIMPLEMENTED_ERROR(
            "'{func}' lowering not yet implemented"
        );
        """.rstrip()


def generate_backend(
    source_yaml: Path,
    backend_path: Path,
    parsed_yaml: dict,
    grouped_native_functions: list,
):
    print("Running Lazy Tensor Autogen")

    # No fallback code allowed
    def gen_fallback_code(*args, **kwargs):
        return ""

    codegen.dest.lazy_ir.gen_fallback_code = gen_fallback_code

    codegen.gen_lazy_tensor.run(
        backend_name="TorchMlir",
        source_yaml=str(source_yaml),
        output_dir=str(backend_path),
        dry_run=False,
        impl_path=str(backend_path.joinpath("aten_ltc_mlir_type.cpp")),
        gen_ts_lowerings=False,
        node_base="torch::lazy::MlirNode",
        node_base_hdr=str(backend_path.joinpath("mlir_node.h")),
        tensor_class="torch::lazy::LazyTensor",
        tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
        shape_inference_hdr=str(backend_path.joinpath("LazyShapeInference.h")),
        lazy_ir_cls=MlirLazyIr,
    )

    # Remove lazy_tensor_core imports
    subprocess.check_call(
        [
            "sed",
            "-i",
            "/lazy_tensor_core/d",
            str(backend_path.joinpath("LazyNativeFunctions.cpp")),
        ]
    )

    # programmatically check shape inference declarations
    import re

    sig_re = re.compile(
        r"std::vector<Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)"
    )
    global_signatures = {}

    def extract_signatures(path):
        signatures = set()
        for name, args in sig_re.findall(path.read_text()):
            signature = re.sub(r"\s+", "", f"{name}({args})")
            global_signatures[signature] = (name, args)
            signatures.add(signature)
        return signatures

    upstream_shape_inference_decls = extract_signatures(
        TORCH_DIR.joinpath("torch", "csrc", "lazy", "core", "shape_inference.h")
    )
    assert len(upstream_shape_inference_decls) > 0
    shape_inference_decls = extract_signatures(
        backend_path.joinpath("LazyShapeInference.h")
    )
    assert len(shape_inference_decls) > 0
    shape_inference_defs = extract_signatures(
        backend_path.joinpath("LazyShapeInference.cpp")
    )
    assert len(shape_inference_defs) > 0
    assert len(shape_inference_decls) > len(shape_inference_defs)

    missing_defs = (
        shape_inference_decls
        - upstream_shape_inference_decls
        - shape_inference_defs
    )
    if missing_defs:
        backend_path.joinpath("GenLazyShapeInference.cpp").write_text(
            dedent(
                """
                // This file contains autogenerated Lazy Shape Inference placeholders
                // for ops that dont have a corresponding structured kernel or shape definition

                #include "LazyShapeInference.h"
                #include "../utils/exception.h"
                namespace torch {{
                namespace lazy {{
                {}
                }}  // namespace lazy
                }}  // namespace torch
                """
            ).format(
                "".join(
                    dedent(
                        f"""
                        std::vector<Shape> {name}({args}) {{
                            UNIMPLEMENTED_FUNCTION_ERROR();
                        }}
                        """
                    )
                    for name, args in map(
                        global_signatures.get, sorted(missing_defs)
                    )
                )
            )
        )

    unnecessary_defs = shape_inference_defs - shape_inference_decls
    if unnecessary_defs:
        unnecessary_defs = "\n\t".join(
            f"{name}({args})"
            for name, args in map(global_signatures.get, unnecessary_defs)
        )
        warnings.warn(
            f"Unnecessary shape inference definitions found for:\n\t{unnecessary_defs}"
        )


def main(args):
    script_path = Path(__file__).resolve()
    config_path = (
        Path(__file__).resolve().parent.joinpath("autogen_ltc_backend.yaml")
    )
    torch_ops_file = TORCH_MLIR_DIR.joinpath(
        "include",
        "torch-mlir",
        "Dialect",
        "Torch",
        "IR",
        "GeneratedTorchOps.td",
    )
    assert torch_ops_file.exists()
    native_functions = TORCH_MLIR_DIR.joinpath(
        "generated_native_functions.yaml"
    )
    backend_path = TORCH_MLIR_DIR.joinpath(
        "python", "torch_mlir", "csrc", "backend"
    )
    assert backend_path.is_dir()

    prev_hash = None
    hash_file = TORCH_MLIR_DIR.joinpath("generated_backend.hash")
    if hash_file.exists():
        prev_hash = hash_file.read_text().strip()

    m = hashlib.sha256()
    m.update(script_path.read_bytes())
    m.update(config_path.read_bytes())
    m.update(torch_ops_file.read_bytes())
    if native_functions.exists():
        m.update(native_functions.read_bytes())

    shape_inference_headers = backend_path.joinpath("LazyShapeInference.h")
    if shape_inference_headers.exists():
        m.update(shape_inference_headers.read_bytes())

    shape_inference_defs = backend_path.joinpath("LazyShapeInference.cpp")
    if shape_inference_defs.exists():
        m.update(shape_inference_defs.read_bytes())

    new_hash = m.hexdigest().strip()

    if args.force or new_hash != prev_hash:
        hash_file.write_text(new_hash)
        parsed_yaml, grouped_native_functions = generate_native_functions(
            config_path, torch_ops_file, native_functions
        )

        generate_backend(
            native_functions,
            backend_path,
            parsed_yaml,
            grouped_native_functions,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
    )
    main(parser.parse_args())
@@ -0,0 +1,52 @@
blacklist:
# List of unsupported ops in LTC autogen because of some error
- arange  # Error: Code below assumes there is at least one tensor arg
- contiguous  # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- empty_like  # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- full  # Error: Code below assumes there is at least one tensor arg
- index.Tensor  # Error: TODO not sure if there are other valid types to handle here
- index_put  # Error: TODO not sure if there are other valid types to handle here
- index_put_  # Error: TODO not sure if there are other valid types to handle here
- _index_put_impl_  # Error: TODO not sure if there are other valid types to handle here
- ones  # Error: Code below assumes there is at least one tensor arg
- ones_like  # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- resize_  # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- stack  # Error: TODO not sure if there are other valid types to handle here
- to.dtype  # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- to.other  # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- uniform_  # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- zeros  # Error: Code below assumes there is at least one tensor arg
- zeros_like  # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)

# Additional ops which autogen is supported for but don't compile yet
- item
- size
- where
- copy_
- _to_copy
- log_softmax  # Not inherently differentiable. Needs to be decomposed.
- linear  # Not inherently differentiable. Needs to be decomposed.

# List of supported ops that we don't want to do the full codegen for
# primarily view ops
supported:
# - bernoulli
# - bernoulli_
- cat
- clone
- empty
- expand
- fill_
# - native_batch_norm_backward
- native_batch_norm
- permute
- repeat
- squeeze
- t
- unsqueeze
- view

additional_ops:
# Additional ops to support that are not supported by Torch-MLIR explicitly
- _copy_from
- _copy_from_and_resize
@@ -0,0 +1,4 @@
BasedOnStyle: LLVM
AlignAfterOpenBracket: AlwaysBreak # BlockIndent
PointerAlignment: Left
ReflowComments: false
@@ -25,9 +25,11 @@ add_library(torch_mlir_ltc_backend SHARED
   backend/backend_impl.cpp
   backend/LazyNativeFunctions.cpp
   backend/LazyShapeInference.cpp
+  backend/GenLazyShapeInference.cpp
   backend/mlir_lowering_context.cpp
   backend/mlir_node.cpp
   backend/RegisterLazy.cpp
+  tensor_aten_ops.cpp
 )

 target_link_libraries(torch_mlir_ltc_backend
@@ -40,10 +42,10 @@ target_link_libraries(torch_mlir_ltc_backend
 message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
 set_target_properties(torch_mlir_ltc_backend PROPERTIES
   LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/"
-  OUTPUT_NAME _MLIR_LTC
-  PREFIX "${PYTHON_MODULE_PREFIX}"
-  SUFFIX "${PYTHON_MODULE_EXTENSION}"
+  OUTPUT_NAME lib_mlir_ltc
+  PREFIX ""
+  SUFFIX ".so"
   CXX_VISIBILITY_PRESET "hidden"
   COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic"
+  LINK_FLAGS "-rdynamic"
 )
@@ -1,4 +1,4 @@
-# Torch-MLIR Lazy Tensor Core Backend
+#Torch - MLIR Lazy Tensor Core Backend

 Contained within this directory are the components that implements the
 Torch-MLIR LTC backend.
@@ -0,0 +1,21 @@
//===- LazyShapeInference.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "LazyShapeInference.h"
#include "../utils/exception.h"

namespace torch {
namespace lazy {

std::vector<Shape> compute_shape_detach(const at::Tensor& self) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

} // namespace lazy
} // namespace torch
@@ -0,0 +1,89 @@
//===- LazyShapeInference.h -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <vector>

namespace torch {
namespace lazy {

// clang-format off

TORCH_API std::vector<Shape> compute_shape___and__(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape__shape_as_tensor(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape__unsafe_view(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_abs(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size);
TORCH_API std::vector<Shape> compute_shape_add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_add_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps, bool cudnn_enabled);
TORCH_API std::vector<Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<Shape> compute_shape_bernoulli_(at::Tensor & self, const at::Tensor & p, c10::optional<at::Generator> generator);
TORCH_API std::vector<Shape> compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional<at::Generator> generator);
TORCH_API std::vector<Shape> compute_shape_bincount(const at::Tensor & self, const c10::optional<at::Tensor> & weights, int64_t minlength);
TORCH_API std::vector<Shape> compute_shape_broadcast_to(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
TORCH_API std::vector<Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_detach(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_div_(at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_dropout(const at::Tensor & input, double p, bool train);
TORCH_API std::vector<Shape> compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse);
TORCH_API std::vector<Shape> compute_shape_expand_as(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_flatten(const at::Tensor & self, int64_t start_dim, int64_t end_dim);
TORCH_API std::vector<Shape> compute_shape_floor_divide(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_fmod(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_full_like(const at::Tensor & self, const at::Scalar & fill_value, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<Shape> compute_shape_hardswish(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_hardtanh(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val);
TORCH_API std::vector<Shape> compute_shape_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index);
TORCH_API std::vector<Shape> compute_shape_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps, bool cudnn_enable);
TORCH_API std::vector<Shape> compute_shape_log_softmax(const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim);
TORCH_API std::vector<Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_masked_select(const at::Tensor & self, const at::Tensor & mask);
TORCH_API std::vector<Shape> compute_shape_matmul(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_max(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode);
TORCH_API std::vector<Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_mul(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_mul_(at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
TORCH_API std::vector<Shape> compute_shape_new_ones(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_new_zeros(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_rand_like(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<Shape> compute_shape_relu(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_relu_(at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_reshape(const at::Tensor & self, at::IntArrayRef shape);
TORCH_API std::vector<Shape> compute_shape_rsub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_select(const at::Tensor & self, int64_t dim, int64_t index);
TORCH_API std::vector<Shape> compute_shape_slice(const at::Tensor & self, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step);
TORCH_API std::vector<Shape> compute_shape_softmax(const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_square(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_std(const at::Tensor & self, bool unbiased);
TORCH_API std::vector<Shape> compute_shape_sub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_sub_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_sum(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1);
TORCH_API std::vector<Shape> compute_shape_type_as(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_var(const at::Tensor & self, bool unbiased);

// clang-format on

} // namespace lazy
} // namespace torch
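For declarations above that are not already defined upstream or in LazyShapeInference.cpp, the autogen script earlier in this commit emits UNIMPLEMENTED placeholders into GenLazyShapeInference.cpp. A unary op whose output shape matches its input could presumably be given a hand-written body mirroring compute_shape_detach from the previous file; a minimal sketch (not part of this commit, shown here for compute_shape_relu):

#include "LazyShapeInference.h"

namespace torch {
namespace lazy {

// Sketch only: relu preserves the input's dtype and sizes, so the shape
// function can copy them straight through, just like compute_shape_detach.
std::vector<Shape> compute_shape_relu(const at::Tensor& self) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

} // namespace lazy
} // namespace torch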
@@ -0,0 +1,58 @@
//===- aten_eager_fallback.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_eager_fallback.cpp
//===----------------------------------------------------------------------===//

#include <iostream>
#include <unordered_map>

#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/metrics.h>

#include "../utils/exception.h"
#include "aten_eager_fallback.h"

namespace torch_lazy_tensors {

static std::unordered_map<std::string, ::torch::lazy::Counter*>
    _eager_fallback_counters;

bool force_eager_fallback(c10::Symbol op) {
  static char* force_str = std::getenv("LTC_FORCE_FALLBACK");
  if (force_str != nullptr) {
    static auto force_sym = c10::Symbol::fromQualString(std::string(force_str));
    if (op == force_sym) {
      std::cout << "MLIR force_eager_fallback(" << force_str << "): true"
                << std::endl;
      return true;
    }
  }
  std::cout << "MLIR force_eager_fallback(" << op.toQualString() << "): false"
            << std::endl;
  return false;
}

void ltc_eager_fallback(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto name = c10::toString(op.operator_name());
  UNSUPPORTED_ERROR(
      "MLIR ltc_eager_fallback is not supported for op: " << name);
}

std::function<void(void)> register_mlir_ltc_eager_fallback;

TORCH_LIBRARY_IMPL(_, Lazy, m) {
  register_mlir_ltc_eager_fallback = [&]() {
    m.fallback(
        torch::CppFunction::makeFromBoxedFunction<&ltc_eager_fallback>());
  };
}

} // namespace torch_lazy_tensors
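Note that the TORCH_LIBRARY_IMPL block above only captures a lambda; the boxed ltc_eager_fallback kernel is not installed until something later invokes the stored register_mlir_ltc_eager_fallback function, presumably during backend initialization. A minimal, hypothetical call site (not part of this commit):

#include "aten_eager_fallback.h"

// Hypothetical initialization hook: run the registration lambda captured by
// TORCH_LIBRARY_IMPL so the Lazy dispatch key actually gets the boxed fallback.
void InitializeEagerFallbackSketch() {
  if (torch_lazy_tensors::register_mlir_ltc_eager_fallback) {
    torch_lazy_tensors::register_mlir_ltc_eager_fallback();
  }
}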
@@ -0,0 +1,27 @@
//===- aten_eager_fallback.h ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// Facilitates eager fallback behaviour
//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_eager_fallback.h
//===----------------------------------------------------------------------===//

#pragma once

#include <ATen/native/CPUFallback.h>

namespace torch_lazy_tensors {

bool force_eager_fallback(c10::Symbol op);
void ltc_eager_fallback(
    const c10::OperatorHandle& op, torch::jit::Stack* stack);

extern TORCH_API std::function<void(void)> register_mlir_ltc_eager_fallback;

} // namespace torch_lazy_tensors
@@ -0,0 +1,399 @@
//===- aten_ltc_mlir_type.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_ltc_ts_type.cpp
//===----------------------------------------------------------------------===//

#include <ATen/Operators.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/result_type.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
#include <torch/library.h>

#include "ATen/MetaFunctions.h"
#include <torch/csrc/lazy/core/tensor_impl.h>

#include "../tensor_aten_ops.h"
#include "../utils/exception.h"
#include "../utils/sys_utils.h"
#include "LazyNativeFunctions.h"
#include "LazyShapeInference.h"

namespace torch_lazy_tensors {

namespace {

void CheckSubOperandTypes(at::ScalarType type1, at::ScalarType type2) {
  CHECK(type1 != at::kBool || type2 != at::kBool)
      << "Subtraction, the `-` operator, with two bool tensors is not "
         "supported. Use the `^` or `logical_xor()` operator instead.";
  CHECK(type1 != at::kBool && type2 != at::kBool)
      << "Subtraction, the `-` operator, with a bool tensor is not "
         "supported. If you are trying to invert a mask, use the `~` or "
         "`logical_not()` operator instead.";
}

std::pair<torch::lazy::LazyTensorPtr, torch::lazy::LazyTensorPtr>
GetBinaryOperands(const at::Tensor& self, const at::Tensor& other) {
  torch::lazy::LazyTensorPtr self_tensor;
  torch::lazy::LazyTensorPtr other_tensor;
  auto self_xtensor = torch::lazy::TryGetLtcTensor(self);
  if (!self_xtensor) {
    other_tensor = torch::lazy::TryGetLtcTensor(other);
    self_tensor = GetOrCreateLtcTensor(self, other_tensor->GetDevice());
  } else {
    self_tensor = self_xtensor;
    other_tensor = GetOrCreateLtcTensor(other, self_tensor->GetDevice());
  }
  return std::pair<torch::lazy::LazyTensorPtr, torch::lazy::LazyTensorPtr>(
      self_tensor, other_tensor);
}

template <typename B>
at::Tensor
DoBinaryOp(const at::Tensor& self, const at::Tensor& other, const B& bin_op) {
  at::ScalarType dtype = at::result_type(self, other);
  std::pair<torch::lazy::LazyTensorPtr, torch::lazy::LazyTensorPtr> operands =
      GetBinaryOperands(
          torch::lazy::UnwrapNumber(self, dtype),
          torch::lazy::UnwrapNumber(other, dtype));
  torch::lazy::LazyTensorPtr result = bin_op(operands.first, operands.second);
  return torch::lazy::CreateAtenFromLtcTensor(result);
}

template <typename B>
at::Tensor
DoBinaryOp(const at::Tensor& self, const at::Scalar& other, const B& bin_op) {
  torch::lazy::LazyTensorPtr self_tensor = torch::lazy::GetLtcTensor(self);
  torch::lazy::LazyTensorPtr result = bin_op(self_tensor, other);
  return torch::lazy::CreateAtenFromLtcTensor(result);
}

at::Tensor subtensor(const at::Tensor& tensor, int dim, int groups, int g) {
  if (!tensor.defined()) {
    return at::Tensor();
  }
  int64_t n = tensor.sizes()[dim] / groups;
  return tensor.narrow(dim, n * g, n).contiguous();
}

at::Tensor CreateLtcTensor(
    const at::Tensor& tensor,
    const c10::optional<torch::lazy::BackendDevice>& device) {
  if (tensor.defined() && device) {
    return torch::lazy::CreateAtenFromLtcTensor(
        torch::lazy::LazyTensor::Create(tensor, *device));
  }
  return tensor;
}

c10::optional<torch::lazy::BackendDevice>
GetLtcDevice(const c10::optional<c10::Device>& device) {
  if (!device) {
    return c10::nullopt;
  }
  if (device->type() != at::kLazy) {
    return c10::nullopt;
  }
  return torch::lazy::atenDeviceToBackendDevice(*device);
}

} // namespace

// at::Tensor LazyNativeFunctions::bernoulli(
//     const at::Tensor& self, c10::optional<at::Generator> generator) {
//   TORCH_LAZY_FN_COUNTER("lazy::");
//   if (generator.has_value() && generator->defined()) {
//     UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli has generator value");
//   }
//   auto self_tensor = torch::lazy::TryGetLtcTensor(self);
//
//   UNIMPLEMENTED_FUNCTION_ERROR();
//   // return torch::lazy::CreateAtenFromLtcTensor(
//   //     lazy_tensor_aten_ops::bernoulli(self_tensor));
// }

// at::Tensor& LazyNativeFunctions::bernoulli_(
//     at::Tensor& self, double p, c10::optional<at::Generator> generator) {
//   TORCH_LAZY_FN_COUNTER("lazy::");
//   if (generator.has_value() && generator->defined()) {
//     UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli_ has generator value");
//   }
//   auto self_tensor = torch::lazy::TryGetLtcTensor(self);
//
//   UNIMPLEMENTED_FUNCTION_ERROR();
//   // lazy_tensor_aten_ops::bernoulli_(self_tensor, p);
//   // return self;
// }

at::Tensor LazyNativeFunctions::cat(at::TensorList tensors, int64_t dim) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto lazy_tensors = torch::lazy::GetLtcTensors(tensors);
  std::vector<torch::lazy::Value> values;
  values.reserve(lazy_tensors.size());
  for (auto& tensor : lazy_tensors) {
    values.emplace_back(tensor->GetIrValue());
  }

  auto shapes = torch::lazy::compute_shape_cat(tensors, dim);
  UNIMPLEMENTED_FUNCTION_ERROR();
  // auto node =
  //     torch::lazy::MakeNode<ir::ops::Cat>(values, dim, std::move(shapes));
  // auto result = torch::lazy::CreateAtenFromLtcTensor(
  //     torch::lazy::LazyTensor::Create(torch::lazy::Value(node, 0),
  //                                     lazy_tensors[0]->GetDevice()));
  // return result;
}

at::Tensor LazyNativeFunctions::clone(
    const at::Tensor& self, c10::optional<at::MemoryFormat> memory_format) {
  auto self_lt = torch::lazy::TryGetLtcTensor(self);
  return torch::lazy::CreateAtenFromLtcTensor(
      self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice()));
}

at::Tensor LazyNativeFunctions::_copy_from(
    const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  if (!self_tensor) {
    // providing a new 'eager' value (self) for an existing lazy tensor (dst)
    static bool sync_update =
        sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true);
    CHECK(dst_tensor);
    dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update);
  } else if (!dst_tensor) {
    // materializing a lazy tensor (self) and copying its value into eager
    // tensor (dst)
    // detached=false lets us skip a copy in `ToTensor`, which should be safe
    // becuase we are only going to use the tensor for dst.copy_()
    CHECK(self_tensor);
    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/false);
    at::Tensor typed_tensor =
        torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
    dst.resize_as_(typed_tensor).copy_(typed_tensor);
  } else {
    // Copying one lazy tensor to another
    if (!dst_tensor->CurrentIrValue()) {
      // if dest is not backed by IR (e.g. result of some lazy operation),
      // then it should have at::Tensor data backing it instead
      auto dst_tensor_data = dst_tensor->CurrentTensorData();
      CHECK(dst_tensor_data);
      auto src_tensor_data = self_tensor->CurrentTensorData();
      if (src_tensor_data) {
        // both src/dst are simply backed by at::Tensor data, no IR- do a
        // straightforward copy
        dst_tensor_data->copy_(*src_tensor_data);
      } else {
        // src needs to be materialized before its result can be used for a copy
        // into dst
        // since we use the src tensor only for making a copy, we don't need to
        // detach it
        // note: it would be even more efficient if we could cause ToTensor to
        // materialize the
        // value directly into dst's buffer (that would need to be detached
        // though).
        dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false));
      }
    } else {
      lazy_tensor_aten_ops::copy_(dst_tensor, self_tensor);
      auto* impl =
          dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
      impl->set_tensor(dst_tensor);
    }
  }
  return dst;
}

at::Tensor LazyNativeFunctions::_copy_from_and_resize(
    const at::Tensor& self, const at::Tensor& dst) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  if (!self_tensor) {
    CHECK(dst_tensor);
    dst_tensor->UpdateFromTensorOut(self);
  } else if (!dst_tensor) {
    CHECK(self_tensor);
    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true);
    at::Tensor typed_tensor =
        torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
    dst.resize_as_(typed_tensor).copy_(typed_tensor);
  } else {
    // at this point we know dst is a lazy tensor
    auto* dest_impl =
        dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
    dest_impl->tensor()->UpdateFromTensorOut(self_tensor);
    dest_impl->force_refresh_sizes();
  }
  return dst;
}

at::Tensor LazyNativeFunctions::empty(
    at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout, c10::optional<at::Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<at::MemoryFormat> memory_format) {
  const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType();
  at::TensorOptions options = at::TensorOptions()
                                  .device(c10::Device(device_type))
                                  .layout(layout)
                                  .pinned_memory(pin_memory)
                                  .dtype(dtype);
  auto x_result = at::empty(size, options, memory_format);
  return CreateLtcTensor(x_result, GetLtcDevice(device));
}

at::Tensor LazyNativeFunctions::expand(
    const at::Tensor& self, at::IntArrayRef size, bool implicit) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  UNIMPLEMENTED_FUNCTION_ERROR();
  // return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::expand(
  //     torch::lazy::TryGetLtcTensor(self), size.vec()));
}

at::Tensor&
LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  lazy_tensor_aten_ops::fill_(self_tensor, value);
  return self;
}

std::tuple<at::Tensor, at::Tensor, at::Tensor>
LazyNativeFunctions::native_batch_norm(
    const at::Tensor& input, const c10::optional<at::Tensor>& weight,
    const c10::optional<at::Tensor>& bias,
    const c10::optional<at::Tensor>& running_mean,
    const c10::optional<at::Tensor>& running_var, bool training,
    double momentum, double eps) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  torch::lazy::LazyTensorPtr input_tensor = torch::lazy::TryGetLtcTensor(input);
  const torch::lazy::BackendDevice& device = input_tensor->GetDevice();
  torch::lazy::LazyTensorPtr running_mean_tensor =
      GetOrCreateLtcTensor(running_mean, device);
  torch::lazy::LazyTensorPtr running_var_tensor =
      GetOrCreateLtcTensor(running_var, device);
  UNIMPLEMENTED_FUNCTION_ERROR();
  // auto outputs = lazy_tensor_aten_ops::ts_native_batch_norm(
  //     torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight,
  //     device),
  //     GetOrCreateLtcTensor(bias, device), running_mean_tensor,
  //     running_var_tensor, training, momentum, eps);
  // return
  // std::make_tuple(torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)),
  //     torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)),
  //     torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs)));
}

// std::tuple<at::Tensor, at::Tensor, at::Tensor>
// LazyNativeFunctions::native_batch_norm_backward(
//     const at::Tensor& grad_out, const at::Tensor& input,
//     const c10::optional<at::Tensor>& weight,
//     const c10::optional<at::Tensor>& running_mean,
//     const c10::optional<at::Tensor>& running_var,
//     const c10::optional<at::Tensor>& save_mean,
//     const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
//     std::array<bool, 3> output_mask) {
//   TORCH_LAZY_FN_COUNTER("lazy::");
//   torch::lazy::LazyTensor grad_out_tensor =
//       torch::lazy::TryGetLtcTensor(grad_out);
//   const torch::lazy::BackendDevice& device = grad_out_tensor.GetDevice();
//   torch::lazy::LazyTensor null_tensor;
//   bool running_stats = running_mean && running_mean->defined();
//   CHECK_EQ(running_var && running_var->defined(), running_stats);
//   UNIMPLEMENTED_FUNCTION_ERROR();
//   // auto gradients = lazy_tensor_aten_ops::ts_native_batch_norm_backward(
//   //     torch::lazy::TryGetLtcTensor(grad_out),
//   //     torch::lazy::TryGetLtcTensor(input),
//   //     GetOrCreateLtcTensor(weight, device),
//   //     running_stats ? GetOrCreateLtcTensor(running_mean, device)
//   //                   : null_tensor,
//   //     running_stats ? GetOrCreateLtcTensor(running_var, device)
//   //                   : null_tensor,
//   //     GetOrCreateLtcTensor(save_mean, device),
//   //     GetOrCreateLtcTensor(save_invstd, device), train, eps,
//   //     output_mask);
//   // at::Tensor undefined;
//   // return std::make_tuple(
//   //     output_mask[0] ?
//   //     torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients))
//   //                    : undefined,
//   //     output_mask[1] ?
//   //     torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients))
//   //                    : undefined,
//   //     output_mask[2] ?
//   //     torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients))
//   //                    : undefined);
// }

at::Tensor
LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
  UNIMPLEMENTED_FUNCTION_ERROR();
  // return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::permute(
  //     self_tensor, torch::lazy::ToI64Vector(dims)));
}

at::Tensor
LazyNativeFunctions::repeat(const at::Tensor& self, at::IntArrayRef repeats) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  UNIMPLEMENTED_FUNCTION_ERROR();
  // return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::repeat(
  //     torch::lazy::TryGetLtcTensor(self),
  //     torch::lazy::ToI64Vector(repeats)));
}

at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  UNIMPLEMENTED_FUNCTION_ERROR();
  // return torch::lazy::CreateAtenFromLtcTensor(
  //     lazy_tensor_aten_ops::squeeze(torch::lazy::TryGetLtcTensor(self)));
}

at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self, int64_t dim) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  UNIMPLEMENTED_FUNCTION_ERROR();
  // return torch::lazy::CreateAtenFromLtcTensor(
  //     lazy_tensor_aten_ops::squeeze(torch::lazy::TryGetLtcTensor(self),
  //     dim));
}

at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::transpose(
      torch::lazy::TryGetLtcTensor(self), 0, 1));
}

at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  UNIMPLEMENTED_FUNCTION_ERROR();
  // return torch::lazy::CreateAtenFromLtcTensor(
  //     lazy_tensor_aten_ops::unsqueeze(torch::lazy::TryGetLtcTensor(self),
  //     dim));
}

at::Tensor
LazyNativeFunctions::view(const at::Tensor& self, at::IntArrayRef size) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
  return torch::lazy::CreateAtenFromLtcTensor(
      lazy_tensor_aten_ops::view(self_tensor, torch::lazy::ToI64Vector(size)));
}

void InitializeAtenBindings() {}

} // namespace torch_lazy_tensors
@ -0,0 +1,157 @@
|
||||||
|
//===- backend_impl.cpp ---------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// This file is adapted from pytorch/pytorch
|
||||||
|
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.cpp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include <torch/csrc/lazy/backend/backend_data.h>
|
||||||
|
#include <torch/csrc/lazy/backend/backend_device.h>
|
||||||
|
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||||
|
#include <torch/csrc/lazy/core/shape.h>
|
||||||
|
|
||||||
|
#include "../utils/debug.h"
|
||||||
|
#include "../utils/exception.h"
|
||||||
|
#include "backend_impl.h"
|
||||||
|
#include "mlir_lowering_context.h"
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace lazy {
|
||||||
|
|
||||||
|
MlirBackendData::MlirBackendData(BackendDevice device, Shape shape)
|
||||||
|
: BackendData(device, shape) {
|
||||||
|
PRINT_FUNCTION();
|
||||||
|
auto info = std::make_shared<MlirBackendData::Info>();
|
||||||
|
SetInfo(info);
|
||||||
|
}
|
||||||
|
MlirBackendData::MlirBackendData(const at::Scalar& scalar, BackendDevice device)
|
||||||
|
: BackendData(device, Shape(scalar.type(), {})) {
|
||||||
|
PRINT_FUNCTION();
|
||||||
|
auto info = std::make_shared<MlirBackendData::Info>(scalar);
|
||||||
|
SetInfo(info);
|
||||||
|
}
|
||||||
|
MlirBackendData::MlirBackendData(
|
||||||
|
const at::Tensor& tensor, BackendDevice device, Shape shape)
|
||||||
|
: BackendData(device, shape) {
|
||||||
|
PRINT_FUNCTION();
|
||||||
|
auto info = std::make_shared<MlirBackendData::Info>(tensor);
|
||||||
|
SetInfo(info);
|
||||||
|
}
|
||||||
|
|
||||||
|
BackendData::Handle MlirBackendData::GetHandle() {
|
||||||
|
return reinterpret_cast<int64_t>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MlirBackendData::Assign(const BackendData& data) {
|
||||||
|
MlirBackendData::Info* info =
|
||||||
|
dynamic_cast<MlirBackendData::Info*>(data.info());
|
||||||
|
TORCH_CHECK(
|
||||||
|
info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info.");
|
||||||
|
auto new_info = std::make_shared<MlirBackendData::Info>(*info);
|
||||||
|
SetInfo(new_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MlirBackendData::HasValue() const { return bool(info()); }
|
||||||
|
|
||||||
|
/**
 * Initialization/Teardown
 * */
void MlirBackendImpl::PrepareToExit() const {}

/**
 * Data Transfer
 * */

BackendDataPtr MlirBackendImpl::MakeComputationDataFromTensor(
    const at::Tensor& tensor, const Shape& shape,
    const BackendDevice& device) const {
  PRINT_FUNCTION();
  return std::make_shared<MlirBackendData>(tensor, device, shape);
}

BackendDataPtr MlirBackendImpl::MakeComputationDataFromScalar(
    const at::Scalar& scalar, const BackendDevice& device) const {
  PRINT_FUNCTION();
  return std::make_shared<MlirBackendData>(scalar, device);
}

BackendDataPtr MlirBackendImpl::CreateDataPlaceholder(
    const BackendDevice& device, const Shape& shape) const {
  PRINT_FUNCTION();
  return std::make_shared<MlirBackendData>(device, shape);
}

at::Tensor MlirBackendImpl::MakeTensorFromComputationData(
    const BackendDataPtr data,
    c10::optional<at::ScalarType> logical_scalar_type) const {
  PRINT_FUNCTION();
  MlirBackendData::Info* info =
      dynamic_cast<MlirBackendData::Info*>(data->info());
  TORCH_CHECK(
      info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info.");
  return info->tensor;
}

/**
 * Lowering, Compilation, Execution
 * */

std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
    const std::string& name, BackendDevice device,
    c10::ArrayRef<Node*> post_order, Util::EmissionMap emit_status) const {
  PRINT_FUNCTION();
  return std::make_unique<MlirLoweringContext>(
      name, std::forward<BackendDevice>(device),
      std::forward<c10::ArrayRef<Node*>>(post_order),
      std::forward<Util::EmissionMap>(emit_status));
}

std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
    const std::string& name, BackendDevice device) const {
  PRINT_FUNCTION();
  return std::make_unique<MlirLoweringContext>(
      name, std::forward<BackendDevice>(device));
}

/**
 * Device Configuration
 * */

// Set or get the default device type.
// For backends used with virtual c10:: Devices, this configures what real
// device type the backend should use, and matters if the backend supports
// more than one type of real device.

// Specify which aten device should be used for eager fallback
// may change depending on current 'Default' DeviceType
at::DeviceType MlirBackendImpl::EagerFallbackDeviceType() const {
  PRINT_FUNCTION();
  return at::DeviceType::CPU;
}

// Query all available backend devices
std::vector<BackendDevice> MlirBackendImpl::GetBackendDevices() const {
  PRINT_FUNCTION();
  return {
      GetBackendDevice(c10::Device(c10::kLazy, 0)),
      GetBackendDevice(c10::Device(c10::kCPU, 0))};
}

// Map a particular c10:: device to a concrete backend device
// Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are
// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
// through a mode, in which case these APIs should still work, but should be
// identity mappings.
BackendDevice MlirBackendImpl::GetBackendDevice(c10::Device device) const {
  PRINT_FUNCTION();
  return BackendDevice(GetDefaultDeviceType(), device.index());
}

} // namespace lazy
} // namespace torch
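For orientation only (not part of this commit): a minimal sketch of how a caller might exercise the data-transfer interface implemented above, assuming the MLIR backend has already been registered so that torch::lazy::getBackend() returns it. The function name RoundTripExample is hypothetical.

#include <ATen/ATen.h>
#include <torch/csrc/lazy/backend/backend_interface.h>

void RoundTripExample(const at::Tensor& t) {
  const torch::lazy::BackendImplInterface* backend = torch::lazy::getBackend();
  // Map the virtual lazy device 0 to a concrete backend device.
  torch::lazy::BackendDevice device =
      backend->GetBackendDevice(c10::Device(c10::kLazy, 0));
  torch::lazy::Shape shape(t.scalar_type(), t.sizes());
  // Wrap the tensor in backend data (an MlirBackendData underneath)...
  torch::lazy::BackendDataPtr data =
      backend->MakeComputationDataFromTensor(t, shape, device);
  // ...and read back the tensor recorded in MlirBackendData::Info.
  at::Tensor back = backend->MakeTensorFromComputationData(data, c10::nullopt);
}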
@@ -23,23 +23,37 @@
 namespace torch {
 namespace lazy {

-class MlirBackendData : public torch::lazy::BackendData {
+class TORCH_API MlirBackendData : public BackendData {
 public:
-  struct Info;
+  struct Info : public BackendData::Info {
+    at::Tensor tensor;
+    c10::optional<at::Scalar> scalar;
+    bool requires_grad;

-  MlirBackendData(torch::lazy::BackendDevice device, torch::lazy::Shape shape);
-  MlirBackendData(const at::Scalar& scalar, torch::lazy::BackendDevice device);
-  MlirBackendData(const at::Tensor& tensor, torch::lazy::BackendDevice device, torch::lazy::Shape shape);
+    Info() {}
+    Info(const Info& other)
+        : tensor{other.tensor}, scalar{other.scalar},
+          requires_grad{other.requires_grad} {}
+    Info(const at::Tensor& tensor)
+        : tensor{tensor}, requires_grad{tensor.requires_grad()} {}
+    Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {}
+  };

-  virtual torch::lazy::BackendData::Handle GetHandle() override;
+  MlirBackendData(BackendDevice device, Shape shape);
+  MlirBackendData(const at::Scalar& scalar, BackendDevice device);
+  MlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape);

-  virtual void Assign(const torch::lazy::BackendData& data) override;
+  virtual BackendData::Handle GetHandle() override;
+
+  virtual void Assign(const BackendData& data) override;

   virtual bool HasValue() const override;
 };

-class MlirBackendImpl : public torch::lazy::BackendImplInterface {
+class TORCH_API MlirBackendImpl : public BackendImplInterface {
 public:
+  virtual ~MlirBackendImpl() = default;

   /**
    * Initialization/Teardown
    * */
@@ -54,54 +68,45 @@ public:
    * Data Transfer
    * */

-  virtual torch::lazy::BackendDataPtr MakeComputationDataFromTensor(
-      const at::Tensor& tensor,
-      const torch::lazy::Shape& shape,
-      const torch::lazy::BackendDevice& device
-  ) const override;
+  virtual BackendDataPtr MakeComputationDataFromTensor(
+      const at::Tensor& tensor, const Shape& shape,
+      const BackendDevice& device) const override;

-  virtual torch::lazy::BackendDataPtr MakeComputationDataFromScalar(
-      const at::Scalar& scalar,
-      const torch::lazy::BackendDevice& device
-  ) const override;
+  virtual BackendDataPtr MakeComputationDataFromScalar(
+      const at::Scalar& scalar, const BackendDevice& device) const override;

-  virtual torch::lazy::BackendDataPtr CreateDataPlaceholder(
-      const torch::lazy::BackendDevice& device, const torch::lazy::Shape& shape
-  ) const override;
+  virtual BackendDataPtr CreateDataPlaceholder(
+      const BackendDevice& device, const Shape& shape) const override;

   virtual at::Tensor MakeTensorFromComputationData(
-      const torch::lazy::BackendDataPtr data,
-      c10::optional<at::ScalarType> logical_scalar_type
-  ) const override;
+      const BackendDataPtr data,
+      c10::optional<at::ScalarType> logical_scalar_type) const override;

   /**
    * Lowering, Compilation, Execution
    * */

-  virtual std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
-      const std::string& name,
-      torch::lazy::BackendDevice device,
-      c10::ArrayRef<torch::lazy::Node*> post_order,
-      torch::lazy::Util::EmissionMap emit_status
-  ) const override;
+  virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
+      const std::string& name, BackendDevice device,
+      c10::ArrayRef<Node*> post_order,
+      Util::EmissionMap emit_status) const override;

-  virtual std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
-      const std::string& name, torch::lazy::BackendDevice device
-  ) const override;
+  virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
+      const std::string& name, BackendDevice device) const override;

   // TODO(whc) need to keep this?
   // virtual std::vector<std::string> GetCompilationDevices(
   //     const std::string& device, c10::ArrayRef<std::string> devices
   // ) const = 0;

-  // virtual std::vector<torch::lazy::ComputationPtr> Compile(
-  //     std::vector<torch::lazy::ComputationPtr> instances
+  // virtual std::vector<ComputationPtr> Compile(
+  //     std::vector<ComputationPtr> instances
   // ) const = 0;

-  // virtual std::vector<torch::lazy::BackendDataPtr> ExecuteComputation(
-  //     torch::lazy::Computation& computation,
-  //     c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
-  //     const torch::lazy::BackendDevice& device
+  // virtual std::vector<BackendDataPtr> ExecuteComputation(
+  //     Computation& computation,
+  //     c10::ArrayRef<BackendDataPtr> arguments,
+  //     const BackendDevice& device
   // ) const = 0;

   /**
@@ -113,16 +118,16 @@ public:
   // device type the backend should use, and matters if the backend supports
   // more than one type of real device.

-  // virtual std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType() const = 0;
+  // virtual std::shared_ptr<BackendDeviceType> GetDefaultDeviceType() const =
+  // 0;
   // virtual void SetDefaultDeviceType(std::string device_type) = 0;

   // Specify which aten device should be used for eager fallback
   // may change depending on current 'Default' DeviceType
   virtual at::DeviceType EagerFallbackDeviceType() const override;


   // Query all available backend devices
-  virtual std::vector<torch::lazy::BackendDevice> GetBackendDevices() const override;
+  virtual std::vector<BackendDevice> GetBackendDevices() const override;

   // Map a particular c10:: device to a concrete backend device
   // Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are
@@ -130,8 +135,7 @@ public:
   // scenes. In the future, non-virtual c10:: devices may also use lazy tensors
   // through a mode, in which case these APIs should still work, but should be
   // identity mappings.
-  virtual torch::lazy::BackendDevice GetBackendDevice(c10::Device device) const override;
+  virtual BackendDevice GetBackendDevice(c10::Device device) const override;


   /**
    * Debug/Metrics
@@ -142,10 +146,9 @@ public:
   // virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0;

   // virtual std::string GetComputationBackendText(
-  //     const torch::lazy::ComputationPtr computation
+  //     const ComputationPtr computation
   // ) const = 0;

 };

-} // lazy
-} // torch
+} // namespace lazy
+} // namespace torch
@@ -0,0 +1,93 @@
//===- mlir_lowering_context.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/ts_lowering_context.cpp
//===----------------------------------------------------------------------===//

#include <iostream>

#include "../utils/exception.h"
#include "mlir_lowering_context.h"

namespace torch {
namespace lazy {

MlirLoweringContext::MlirLoweringContext(
    const std::string& name, BackendDevice device)
    : LoweringContext(name, std::forward<BackendDevice>(device)) {}

MlirLoweringContext::MlirLoweringContext(
    const std::string& name, BackendDevice device,
    c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
    : LoweringContext(
          name, std::forward<BackendDevice>(device),
          std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
          std::forward<Util::EmissionMap>(emit_status)) {}

int MlirComputation::parameters_size() const { UNIMPLEMENTED_FUNCTION_ERROR(); }

const std::vector<torch::lazy::Shape>&
MlirComputation::parameter_shapes() const {
  UNIMPLEMENTED_FUNCTION_ERROR();
}

const std::vector<std::string>& MlirComputation::parameter_names() const {
  UNIMPLEMENTED_FUNCTION_ERROR();
}

const torch::lazy::Shape& MlirComputation::result_shape() const {
  UNIMPLEMENTED_FUNCTION_ERROR();
}

// Get the shape of the result tuple component, given by index.
torch::lazy::Shape MlirLoweringContext::GetResultShape(size_t index) const {
  UNIMPLEMENTED_FUNCTION_ERROR();
}

// Adds the given output as a component of the result tuple and returns its
// assigned position within the tuple.
size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) {
  const torch::lazy::Node* node;
  auto it = emitted_outputs_.find(output);
  if (it == emitted_outputs_.end()) {
    node = output.node;

    auto post_order = Util::ComputePostOrder(node, &emit_status_);
    for (auto po_node : post_order) {
      // TODO: uncomment after lowering is implemented
      // bool ok = lowering_->Lower(node);
      // TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
    }
    emitted_outputs_[output] = node;
  } else {
    node = it->second;
  }
  result_tuple_.emplace_back(node);
  return result_tuple_.size() - 1;
}

// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
void MlirLoweringContext::AddParameter(
    const torch::lazy::Output& output, size_t index,
    const torch::lazy::Shape& shape, const std::string& name) {
  UNIMPLEMENTED_FUNCTION_ERROR();
}

// Build the computation capturing all the operations created with the
// embedded builder (returned by the builder() API).
ComputationPtr MlirLoweringContext::Build() {
  for (const torch::lazy::Node* output : result_tuple_) {
  }
  return std::make_shared<MlirComputation>();
}

} // namespace lazy
} // namespace torch
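An illustrative note, not part of the commit: AddResult above appends a slot to the result tuple on every call, while emitted_outputs_ memoizes the node per output, so requesting the same output twice yields two tuple slots backed by one node. Assuming ctx is an MlirLoweringContext and out a torch::lazy::Output (both hypothetical names):

size_t first = ctx.AddResult(out);   // e.g. 0; walks the post-order the first time
size_t second = ctx.AddResult(out);  // 1; found in emitted_outputs_, node reused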
@@ -10,34 +10,33 @@
 // https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_lowering_context.h
 //===----------------------------------------------------------------------===//


 #pragma once

 #include <vector>

 #include <torch/csrc/lazy/backend/lowering_context.h>


 namespace torch {
 namespace lazy {

-class MlirComputation : public torch::lazy::Computation {
+class TORCH_API MlirComputation : public torch::lazy::Computation {
 public:
   int parameters_size() const override;

-  virtual const std::vector<torch::lazy::Shape>& parameter_shapes() const override;
+  virtual const std::vector<torch::lazy::Shape>&
+  parameter_shapes() const override;

   virtual const std::vector<std::string>& parameter_names() const override;

   virtual const torch::lazy::Shape& result_shape() const override;
 };

-class MlirLoweringContext : public torch::lazy::LoweringContext {
+class TORCH_API MlirLoweringContext : public torch::lazy::LoweringContext {
 public:
-  MlirLoweringContext(const std::string& name, torch::lazy::BackendDevice device);
-  MlirLoweringContext(const std::string& name,
-      torch::lazy::BackendDevice device,
+  MlirLoweringContext(
+      const std::string& name, torch::lazy::BackendDevice device);
+  MlirLoweringContext(
+      const std::string& name, torch::lazy::BackendDevice device,
       c10::ArrayRef<torch::lazy::Node*> post_order,
       torch::lazy::Util::EmissionMap emit_status);

@@ -51,16 +50,15 @@ class MlirLoweringContext : public torch::lazy::LoweringContext {
   // Associates the given output with the input parameter of the given index and
   // shape. Only used for the operator-by-operator execution, mostly for
   // debugging purposes.
-  virtual void AddParameter(const torch::lazy::Output& output,
-      size_t index,
-      const torch::lazy::Shape& shape,
-      const std::string& name) override;
+  virtual void AddParameter(
+      const torch::lazy::Output& output, size_t index,
+      const torch::lazy::Shape& shape, const std::string& name) override;

   // Build the computation capturing all the operations created with the
   // embedded builder (returned by the builder() API).
   virtual torch::lazy::ComputationPtr Build() override;

 private:
   std::vector<const torch::lazy::Node*> result_tuple_;
   torch::lazy::OutputMap<const torch::lazy::Node*> emitted_outputs_;
 };
@@ -0,0 +1,130 @@
//===- mlir_node.cpp ------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_node.cpp
//===----------------------------------------------------------------------===//

#include <torch/csrc/lazy/core/cache.h>

#include "../utils/exception.h"
#include "mlir_node.h"

namespace torch {
namespace lazy {

namespace {

hash_t OperandHashes(
    const OpList& operands, const hash_t& seed, const bool bakeInSizes) {
  hash_t hash = seed;
  for (auto& operand : operands) {
    if (!operand) {
      hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
      continue;
    }
    auto operand_hash =
        bakeInSizes ? operand.hash_with_sizes() : operand.hash_without_sizes();
    hash = HashCombine(hash, operand_hash);
  }
  return hash;
}

hash_t GetOpHash(
    OpKind op, const Shape& shape, hash_t hash_seed, const bool bakeInSizes) {
  hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes));
  return HashCombine(h, hash_seed);
}

} // namespace

MlirNode::MlirNode(
    OpKind op, OpList operands, std::vector<Shape>&& shapes, size_t num_outputs,
    hash_t hash_seed)
    : Node(
          op, num_outputs,
          /* node_hash */ HashCombine(op.hash(), hash_seed),
          /* dag_hash */
          [&](bool bakeInSizes) -> hash_t {
            return OperandHashes(
                operands, HashCombine(op.hash(), hash_seed), bakeInSizes);
          }),
      shapes_(std::move(shapes)) {
  for (auto& operand : operands) {
    // Ideally, optional operands should be filtered by the leaf node classes,
    // but it's just much easier to do it here.
    if (!operand) {
      continue;
    }

    AddOperand(operand.node, operand.index);
  }
}

MlirNode::MlirNode(
    OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
    size_t num_outputs, hash_t hash_seed)
    : MlirNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
  shapes_.push_back(GetOpShape(shape_fn));
}

MlirNode::MlirNode(
    OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
    : MlirNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}

void MlirNode::SetShapeDeferred(const std::function<Shape()>& shape_fn) {
  shapes_.push_back(GetOpShape(shape_fn));
}

MlirNode::MlirNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
    : Node(op, num_outputs, [&](bool bakeInSizes) -> hash_t {
        return GetOpHash(op, shape, hash_seed, bakeInSizes);
      }) {
  shapes_.push_back(std::move(shape));
}

using ShapeCache = Cache<hash_t, Shape, HashReducer>;

constexpr const int torch_lazy_shape_cache_size = 4096;

ShapeCache* GetShapeCache() {
  static ShapeCache* cache = new ShapeCache(torch_lazy_shape_cache_size);
  return cache;
}

Shape MlirNode::GetOpShape(const std::function<Shape()>& shape_fn) const {
  ShapeCache* shape_cache = GetShapeCache();
  auto shape = shape_cache->Get(hash());
  if (shape == nullptr) {
    shape = shape_cache->Add(hash(), std::make_shared<Shape>(shape_fn()));
  }
  return *shape;
}

c10::ArrayRef<Shape> MlirNode::shapes() const { return shapes_; }

const Shape& MlirNode::shape(size_t output_index) const {
  return shapes_.at(output_index);
}

const std::vector<Output>& MlirNode::operands() const {
  return operands_as_outputs_;
}

const Output& MlirNode::operand(size_t i) const {
  return operands_as_outputs_.at(i);
}

void MlirNode::AddOperand(NodePtr node, size_t index) {
  CHECK_LT(index, node->num_outputs());
  operands_.push_back(std::move(node));
  operands_as_outputs_.emplace_back(operands_.back().get(), index);
}

} // namespace lazy
} // namespace torch
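To show how these pieces are meant to be used, here is a hypothetical op node (not part of this commit) built on MlirNode. It uses the deferred-shape constructor declared in mlir_node.h and, like the rest of this patch, leaves the actual MLIR lowering unimplemented. ReluNode is an invented example name.

#include "mlir_node.h"

class ReluNode : public torch::lazy::MlirNode {
public:
  explicit ReluNode(const torch::lazy::Value& input)
      : MlirNode(
            torch::lazy::OpKind::Get("aten::relu"), {input},
            /*shape_fn=*/[&]() { return input.shape(); },
            /*num_outputs=*/1) {}

  torch::lazy::MlirOpVector Lower(
      torch::lazy::MlirFunction function,
      torch::lazy::MlirLoweringContext* loctx) const override {
    // Lowering to MLIR is not implemented at this point in the patch series.
    UNIMPLEMENTED_FUNCTION_ERROR();
  }
};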
@@ -14,12 +14,12 @@

 #include <ATen/core/interned_strings.h>
 #include <torch/csrc/lazy/backend/lowering_context.h>
-#include <torch/csrc/lazy/core/shape.h>
 #include <torch/csrc/lazy/core/ir.h>
+#include <torch/csrc/lazy/core/shape.h>

+#include "../utils/exception.h"
 #include "aten_eager_fallback.h"
 #include "mlir_lowering_context.h"
-#include "../utils/exception.h"

 namespace torch {
 namespace lazy {
@@ -27,48 +27,47 @@ namespace lazy {
 typedef std::vector<NodePtr> MlirOpVector;
 typedef NodePtr MlirFunction;

-class MlirNode : public torch::lazy::Node {
+class TORCH_API MlirNode : public torch::lazy::Node {

 public:

   MlirNode(
       OpKind op, OpList operands, std::vector<Shape>&& shapes,
-      size_t num_outputs = 1, hash_t hash_seed = kHashSeed
-  );
+      size_t num_outputs = 1, hash_t hash_seed = kHashSeed);

   // Same as the constructor above, but the shape is generated by a function,
   // only if needed (shape cache miss).
   MlirNode(
-      OpKind op, OpList operands,
-      const std::function<Shape()>& shape_fn,
-      size_t num_outputs = 1, hash_t hash_seed = kHashSeed
-  );
+      OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
+      size_t num_outputs = 1, hash_t hash_seed = kHashSeed);

   // The shape is set later.
   MlirNode(
       OpKind op, OpList operands, size_t num_outputs = 1,
-      hash_t hash_seed = kHashSeed
-  );
+      hash_t hash_seed = kHashSeed);

   void SetShapeDeferred(const std::function<Shape()>& shape_fn);

   // Contructor used to create leaf nodes.
   MlirNode(
-      OpKind op, Shape shape, size_t num_outputs = 1, hash_t hash_seed = kHashSeed
-  );
+      OpKind op, Shape shape, size_t num_outputs = 1,
+      hash_t hash_seed = kHashSeed);

   Shape GetOpShape(const std::function<Shape()>& shape_fn) const;

+  // Retrieves the full shape of the IR Node.
+  c10::ArrayRef<Shape> shapes() const override;
+
+  // Retrieves the shape of the output at a given index.
+  const Shape& shape(size_t output_index = 0) const override;
+
   const std::vector<Output>& operands() const override;

   const Output& operand(size_t i) const override;

-  virtual MlirOpVector Lower(
-      MlirFunction function,
-      MlirLoweringContext* loctx
-  ) const = 0;
+  virtual MlirOpVector
+  Lower(MlirFunction function, MlirLoweringContext* loctx) const = 0;

 private:
   // Adds node's index output number as operand.
   void AddOperand(NodePtr node, size_t index = 0);
@@ -0,0 +1,242 @@
//===- tensor_aten_ops.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/tensor_aten_ops.cpp
//===----------------------------------------------------------------------===//

#include <algorithm>
#include <functional>

#include <ATen/InferSize.h>
#include <c10/util/Optional.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/util.h>
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
#include <torch/csrc/lazy/core/view_ops/permute.h>
#include <torch/csrc/lazy/core/view_ops/view.h>
#include <torch/csrc/lazy/ts_backend/ops/cast.h>
#include <torch/csrc/lazy/ts_backend/ops/expand.h>

#include "tensor_aten_ops.h"

namespace torch_lazy_tensors {
namespace lazy_tensor_aten_ops {
namespace {

// to enable operator+-*/ for Value
using namespace torch::lazy;

torch::lazy::Value MaybeExpand(
    const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) {
  if (input.shape().sizes() == target_shape.sizes()) {
    return input;
  }
  return torch::lazy::MakeNode<torch::lazy::Expand>(
      input, target_shape.sizes().vec(),
      /*is_scalar_expand=*/false);
}

std::vector<int64_t> GetExpandDimensions(
    const torch::lazy::Shape& shape, std::vector<int64_t> dimensions) {
  CHECK_GE(dimensions.size(), shape.dim()) << shape;
  int64_t base = dimensions.size() - shape.dim();
  for (size_t i = 0; i < shape.dim(); ++i) {
    if (dimensions[base + i] == -1) {
      dimensions[base + i] = shape.size(i);
    }
  }
  return dimensions;
}

// Returns a 1-D shape for batch norm weight or bias based on the input shape.
torch::lazy::Shape
BatchNormFeaturesShape(const torch::lazy::LazyTensorPtr& input) {
  CHECK(input);
  auto input_shape = input->shape().Get();
  return torch::lazy::Shape(input_shape.scalar_type(), input_shape.sizes()[1]);
}

// Returns the IR for the given input or the provided default value broadcasted
// to the default shape, if the input is undefined.
torch::lazy::Value GetIrValueOrDefault(
    const torch::lazy::LazyTensorPtr& input, const at::Scalar& default_value,
    const torch::lazy::Shape& default_shape,
    const torch::lazy::BackendDevice& device) {
  return input ? input->GetIrValue()
               : torch::lazy::LazyGraphExecutor::Get()
                     ->GetIrValueForExpandedScalar(
                         default_value, default_shape, device);
}

torch::lazy::ViewInfo CreateAsStridedViewInfo(
    const torch::lazy::Shape& input_shape, std::vector<int64_t> size,
    std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
  torch::lazy::Shape result_shape =
      torch::lazy::Shape(input_shape.scalar_type(), size);
  torch::lazy::AsStridedInfo as_strided_info;
  as_strided_info.stride = std::move(stride);
  if (storage_offset) {
    as_strided_info.offset = *storage_offset;
  }
  return torch::lazy::ViewInfo(
      torch::lazy::ViewInfo::Type::kAsStrided, std::move(result_shape),
      input_shape, std::move(as_strided_info));
}

} // namespace

//////////////////////////////////////////////////////////////////////////////
// ATEN operators follows here, listed in alphabetical order.
//////////////////////////////////////////////////////////////////////////////
torch::lazy::LazyTensorPtr as_strided(
    const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
    std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
  auto input_shape = input->shape();
  return input->CreateViewTensor(CreateAsStridedViewInfo(
      input_shape, std::move(size), std::move(stride), storage_offset));
}

void as_strided_(
    torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
    std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
  if (input->data()->view == nullptr) {
    input->SetIrValue(torch::lazy::MakeNode<torch::lazy::AsStrided>(
        input->GetIrValue(), std::move(size), std::move(stride),
        storage_offset.value_or(0)));
  } else {
    auto input_shape = input->shape();
    input->SetSubView(CreateAsStridedViewInfo(
        input_shape, std::move(size), std::move(stride), storage_offset));
  }
}

torch::lazy::LazyTensorPtr
expand(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size) {
  auto input_shape = input->shape();
  return torch::lazy::LazyTensor::Create(
      torch::lazy::MakeNode<torch::lazy::Expand>(
          input->GetIrValue(),
          GetExpandDimensions(input_shape.Get(), std::move(size)),
          /*is_scalar_expand=*/false),
      input->GetDevice());
}

void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value) {
  torch::lazy::Value constant =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(
          value, input->shape(), input->GetDevice());
  input->SetInPlaceIrValue(std::move(constant));
}

torch::lazy::LazyTensorPtr narrow(
    const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
    int64_t length) {
  auto input_shape = input->shape();
  dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
  torch::lazy::Shape narrow_shape = input_shape;
  narrow_shape.set_size(dim, length);

  torch::lazy::ViewInfo::Type view_type =
      (input_shape.Get().numel() == narrow_shape.numel())
          ? torch::lazy::ViewInfo::Type::kReshape
          : torch::lazy::ViewInfo::Type::kNarrow;
  torch::lazy::ViewInfo view_info(
      view_type, std::move(narrow_shape), input_shape);
  view_info.indices[dim] =
      torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start);
  return input->CreateViewTensor(std::move(view_info));
}

torch::lazy::LazyTensorPtr
permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef<int64_t> dims) {
  auto input_shape = input->shape();
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kPermute, input_shape,
      torch::lazy::GetCanonicalDimensionIndices(dims, input_shape.Get().dim()));
  return input->CreateViewTensor(std::move(view_info));
}

void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
  if (input->GetDevice() == src->GetDevice()) {
    torch::lazy::Value copy_value;
    if (input->dtype() == src->dtype()) {
      copy_value = src->GetIrValue();
    } else {
      copy_value = torch::lazy::MakeNode<torch::lazy::Cast>(
          src->GetIrValue(), input->dtype(), src->dtype());
    }
    input->SetIrValue(MaybeExpand(copy_value, input->shape()));
  } else {
    auto input_shape = input->shape();
    at::Tensor src_tensor = src->ToTensor(/*detached=*/true);
    if (src_tensor.sizes() != input_shape.Get().sizes()) {
      src_tensor = src_tensor.expand(input_shape.Get().sizes().vec());
    }
    input->UpdateFromTensor(std::move(src_tensor), /*sync=*/false);
  }
}

torch::lazy::LazyTensorPtr slice(
    const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
    int64_t end, int64_t step) {
  auto input_shape = input->shape();
  dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
  start =
      torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start);
  end = torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, end);
  // PyTorch allows tensor[-1:0] to return a 0-dim tensor.
  if (start > end) {
    end = start;
  }
  step = std::min(step, end - start);

  torch::lazy::SelectInfo select = {dim, start, end, step};
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kSelect, input_shape, std::move(select));
  return input->CreateViewTensor(std::move(view_info));
}

torch::lazy::LazyTensorPtr
transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) {
  auto input_shape = input->shape();
  auto permute_dims = torch::lazy::MakeTransposePermutation(
      /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim());
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims);
  return input->CreateViewTensor(std::move(view_info));
}

void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) {
  auto input_shape = input->shape();
  auto permute_dims = torch::lazy::MakeTransposePermutation(
      /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim());
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims);
  return input->ModifyCurrentView(std::move(view_info));
}

torch::lazy::LazyTensorPtr view(
    const torch::lazy::LazyTensorPtr& input,
    c10::ArrayRef<int64_t> output_size) {
  auto input_shape = input->shape().Get();
  torch::lazy::Shape shape = torch::lazy::Shape(
      input_shape.scalar_type(),
      at::infer_size(output_size, input_shape.numel()));
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape);
  return input->CreateViewTensor(std::move(view_info));
}

} // namespace lazy_tensor_aten_ops
} // namespace torch_lazy_tensors
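As an illustration only (not part of the commit), the view helpers above compose without materializing data; assuming t is a torch::lazy::LazyTensorPtr of shape [2, 3, 4]:

namespace ops = torch_lazy_tensors::lazy_tensor_aten_ops;
auto tt = ops::transpose(t, 0, 1);                                // permute view, shape [3, 2, 4]
auto nt = ops::narrow(tt, /*dim=*/2, /*start=*/0, /*length=*/2);  // narrow view, shape [3, 2, 2]
auto vt = ops::view(nt, {3, 4});                                  // reshape view, shape [3, 4]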
@@ -0,0 +1,79 @@
//===- tensor_aten_ops.h --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/tensor_aten_ops.h
//===----------------------------------------------------------------------===//

#pragma once

#include <torch/csrc/lazy/core/tensor.h>

namespace torch_lazy_tensors {
namespace lazy_tensor_aten_ops {

//////////////////////////////////////////////////////////////////////////////
// ATEN operators follows here, listed in alphabetical order.
//////////////////////////////////////////////////////////////////////////////
// Takes a slice from the input as R1 at the specified offset and reshapes it
// into the provided size.
torch::lazy::LazyTensorPtr as_strided(
    const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
    std::vector<int64_t> stride, c10::optional<int64_t> storage_offset);

// In-place version of the method above.
void as_strided_(
    torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
    std::vector<int64_t> stride, c10::optional<int64_t> storage_offset);

torch::lazy::LazyTensorPtr
expand(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size);

// Fills the input with the given value.
void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value);

// Returns a new tensor that is a narrowed view of the input in the given
// dimension.
torch::lazy::LazyTensorPtr narrow(
    const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
    int64_t length);

// Permute the dimensions of this tensor according to the given permutation.
torch::lazy::LazyTensorPtr
permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef<int64_t> dims);

// Repeats the input tensor along each dimension by the given number of
// repeats.
torch::lazy::LazyTensorPtr
repeat(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> repeats);

void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src);

torch::lazy::LazyTensorPtr slice(
    const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
    int64_t end, int64_t step);

std::tuple<
    torch::lazy::LazyTensorPtr, torch::lazy::LazyTensorPtr,
    torch::lazy::LazyTensorPtr>
svd(const torch::lazy::LazyTensorPtr& input, bool some, bool compute_uv);

// Swap given dimensions of the input.
torch::lazy::LazyTensorPtr
transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1);

// In-place version of the method above.
void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1);

// Like reshape, but it returns a view into the original tensor.
torch::lazy::LazyTensorPtr view(
    const torch::lazy::LazyTensorPtr& input,
    c10::ArrayRef<int64_t> output_size);

} // namespace lazy_tensor_aten_ops
} // namespace torch_lazy_tensors
@@ -0,0 +1,23 @@
//===- debug.h ------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <iostream>

#include "sys_utils.h"

static const bool verbose_print_function =
    sys_util::GetEnvBool("VERBOSE_PRINT_FUNCTION", false);

#define PRINT_FUNCTION() \
  if (verbose_print_function) { \
    std::cout << __PRETTY_FUNCTION__ << " (" << __FILE__ << ":" << __LINE__ \
              << ")" << std::endl; \
  }
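A usage sketch (illustrative, not part of the commit): the macro prints nothing unless the VERBOSE_PRINT_FUNCTION environment variable is set to "true" or a non-zero integer before the process starts. The helper function below is hypothetical.

// export VERBOSE_PRINT_FUNCTION=true
#include "debug.h"

void SomeBackendHelper() {
  PRINT_FUNCTION(); // prints "void SomeBackendHelper() (<file>:<line>)" when enabled
}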
@@ -20,6 +20,9 @@
     throw std::runtime_error(err.str()); \
   }

+#define UNIMPLEMENTED_FUNCTION_ERROR() \
+  UNIMPLEMENTED_ERROR( \
+      "\n\t" << __FILE__ << ":" << __LINE__ << " " << __PRETTY_FUNCTION__)
+
 #define UNSUPPORTED_ERROR(msg) \
   { \
@@ -0,0 +1,22 @@
#pragma once

#include <cstdlib>
#include <cstring>

namespace sys_util {

static bool GetEnvBool(const char* name, bool defval) {
  const char* env = std::getenv(name);
  if (env == nullptr) {
    return defval;
  }
  if (std::strcmp(env, "true") == 0) {
    return true;
  }
  if (std::strcmp(env, "false") == 0) {
    return false;
  }
  return std::atoi(env) != 0;
}

} // namespace sys_util
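Illustrative behavior of sys_util::GetEnvBool, not part of the commit; the environment variable name FOO is arbitrary.

// unsetenv("FOO");            GetEnvBool("FOO", true)  == true   // falls back to defval
// setenv("FOO", "false", 1);  GetEnvBool("FOO", true)  == false  // literal "true"/"false" recognized
// setenv("FOO", "1", 1);      GetEnvBool("FOO", false) == true   // otherwise std::atoi(value) != 0
// setenv("FOO", "yes", 1);    GetEnvBool("FOO", false) == false  // atoi("yes") == 0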