mirror of https://github.com/llvm/torch-mlir
Fix Torch-MLIR LTC Backend based off latest PyTorch master (#723)
* Changes as a result of the LTC TS backend decoupling * Fix bugs in BackendImpl and codegen * Fix based on latest PyTorch master
pull/1125/head
parent
c3b20e444c
commit
65cf1465ef
|
@ -26,7 +26,7 @@ bazel-*
|
||||||
# Autogenerated files
|
# Autogenerated files
|
||||||
/generated_native_functions.yaml
|
/generated_native_functions.yaml
|
||||||
/generated_backend.hash
|
/generated_backend.hash
|
||||||
/python/torch_mlir/csrc/backend/LazyLazyIr.h
|
/python/torch_mlir/csrc/backend/LazyIr.h
|
||||||
/python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp
|
/python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp
|
||||||
/python/torch_mlir/csrc/backend/LazyNativeFunctions.h
|
/python/torch_mlir/csrc/backend/LazyNativeFunctions.h
|
||||||
/python/torch_mlir/csrc/backend/GenLazyShapeInference.cpp
|
/python/torch_mlir/csrc/backend/GenLazyShapeInference.cpp
|
||||||
|
|
|
@ -24,6 +24,10 @@ from codegen.gen import get_grouped_native_functions, parse_native_yaml
|
||||||
from codegen.model import NativeFunctionsGroup
|
from codegen.model import NativeFunctionsGroup
|
||||||
|
|
||||||
|
|
||||||
|
def isOptionalCType(arg):
|
||||||
|
return str(type(arg)) == "<class 'tools.codegen.api.types.OptionalCType'>"
|
||||||
|
|
||||||
|
|
||||||
def generate_native_functions(
|
def generate_native_functions(
|
||||||
config_path: Path, torch_ops_file: Path, out_file: Path
|
config_path: Path, torch_ops_file: Path, out_file: Path
|
||||||
):
|
):
|
||||||
|
@ -98,7 +102,7 @@ def generate_native_functions(
|
||||||
yaml.dump(
|
yaml.dump(
|
||||||
{
|
{
|
||||||
"backend": "Lazy",
|
"backend": "Lazy",
|
||||||
"cpp_namespace": "torch_lazy_tensors",
|
"cpp_namespace": "torch::lazy",
|
||||||
"full_codegen": opnames,
|
"full_codegen": opnames,
|
||||||
"supported": sorted(supported_ops),
|
"supported": sorted(supported_ops),
|
||||||
},
|
},
|
||||||
|
@ -120,21 +124,46 @@ def generate_native_functions(
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
|
class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
|
||||||
lowering_function_type: str = "torch::lazy::MlirFunction"
|
|
||||||
lowering_context_type: str = "torch::lazy::MlirLoweringContext*"
|
|
||||||
lowering_return_type: str = "torch::lazy::MlirOpVector"
|
|
||||||
|
|
||||||
def lowering_body(self, f):
|
def lowering_function(self, f):
|
||||||
func = (
|
func = (
|
||||||
f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
|
f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
|
||||||
)
|
)
|
||||||
schema = LazyIrSchema(func)
|
schema = LazyIrSchema(func)
|
||||||
|
|
||||||
|
emplace_arguments = []
|
||||||
|
for arg in schema.positional_args:
|
||||||
|
if arg.is_lazy_value:
|
||||||
|
if isOptionalCType(arg.lazy_type):
|
||||||
|
emplace_arguments.append(f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr")
|
||||||
|
continue
|
||||||
|
emplace_arguments.append('loctx->GetOutputOp(operand(i++))')
|
||||||
|
continue
|
||||||
|
emplace_arguments.append(f'"{arg.name}", {arg.name}')
|
||||||
|
|
||||||
|
emplace_arguments_str = "\n ".join(
|
||||||
|
[f"arguments.emplace_back({a});" for a in emplace_arguments])
|
||||||
|
emplace_kwarg_values = [f'"{t.name}", loctx->GetOutputOp(operand(i++))' for t in schema.keyword_values]
|
||||||
|
emplace_kwarg_scalars = [f'"{t.name}", {t.name}' for t in schema.keyword_scalars]
|
||||||
|
emplace_kwarguments = "\n ".join(
|
||||||
|
[f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars])
|
||||||
|
|
||||||
return f"""
|
return f"""
|
||||||
UNIMPLEMENTED_ERROR(
|
TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override {{
|
||||||
"'{func}' lowering not yet implemented"
|
PRINT_FUNCTION();
|
||||||
);
|
std::vector<torch::jit::NamedValue> arguments;
|
||||||
""".rstrip()
|
std::vector<torch::jit::NamedValue> kwarguments;
|
||||||
|
arguments.reserve({len(emplace_arguments)});
|
||||||
|
kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
|
||||||
|
size_t i = 0;
|
||||||
|
{emplace_arguments_str}
|
||||||
|
{emplace_kwarguments}
|
||||||
|
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, arguments, kwarguments);
|
||||||
|
CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
|
||||||
|
|
||||||
|
return {schema.aten_name}_out;
|
||||||
|
}}
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
def generate_backend(
|
def generate_backend(
|
||||||
|
@ -151,14 +180,14 @@ def generate_backend(
|
||||||
|
|
||||||
codegen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
|
codegen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
|
||||||
|
|
||||||
codegen.gen_lazy_tensor.run(
|
codegen.gen_lazy_tensor.run_gen_lazy_tensor(
|
||||||
backend_name="TorchMlir",
|
backend_name="TorchMlir",
|
||||||
|
aten_path=str(TORCH_DIR.joinpath("aten", "src", "ATen")),
|
||||||
source_yaml=str(source_yaml),
|
source_yaml=str(source_yaml),
|
||||||
output_dir=str(backend_path),
|
output_dir=str(backend_path),
|
||||||
dry_run=False,
|
dry_run=False,
|
||||||
impl_path=str(backend_path.joinpath("aten_ltc_mlir_type.cpp")),
|
impl_path=str(backend_path.joinpath("mlir_native_functions.cpp")),
|
||||||
gen_ts_lowerings=False,
|
node_base="torch::lazy::TorchMlirNode",
|
||||||
node_base="torch::lazy::MlirNode",
|
|
||||||
node_base_hdr=str(backend_path.joinpath("mlir_node.h")),
|
node_base_hdr=str(backend_path.joinpath("mlir_node.h")),
|
||||||
tensor_class="torch::lazy::LazyTensor",
|
tensor_class="torch::lazy::LazyTensor",
|
||||||
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
|
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
|
||||||
|
@ -298,7 +327,6 @@ def main(args):
|
||||||
new_hash = m.hexdigest().strip()
|
new_hash = m.hexdigest().strip()
|
||||||
|
|
||||||
if args.force or new_hash != prev_hash:
|
if args.force or new_hash != prev_hash:
|
||||||
hash_file.write_text(new_hash)
|
|
||||||
parsed_yaml, grouped_native_functions = generate_native_functions(
|
parsed_yaml, grouped_native_functions = generate_native_functions(
|
||||||
config_path, torch_ops_file, native_functions
|
config_path, torch_ops_file, native_functions
|
||||||
)
|
)
|
||||||
|
@ -310,6 +338,8 @@ def main(args):
|
||||||
grouped_native_functions,
|
grouped_native_functions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
hash_file.write_text(new_hash)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
|
@ -37,10 +37,9 @@ supported:
|
||||||
- empty
|
- empty
|
||||||
- expand
|
- expand
|
||||||
- fill_
|
- fill_
|
||||||
# - native_batch_norm_backward
|
|
||||||
- native_batch_norm
|
- native_batch_norm
|
||||||
|
# - native_batch_norm_backward
|
||||||
- permute
|
- permute
|
||||||
- repeat
|
|
||||||
- squeeze
|
- squeeze
|
||||||
- t
|
- t
|
||||||
- unsqueeze
|
- unsqueeze
|
||||||
|
@ -50,3 +49,4 @@ additional_ops:
|
||||||
# Additional ops to support that are not supported by Torch-MLIR explicitly
|
# Additional ops to support that are not supported by Torch-MLIR explicitly
|
||||||
- _copy_from
|
- _copy_from
|
||||||
- _copy_from_and_resize
|
- _copy_from_and_resize
|
||||||
|
- native_batch_norm_backward
|
||||||
|
|
|
@ -20,16 +20,14 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
|
||||||
|
|
||||||
add_library(torch_mlir_ltc_backend SHARED
|
add_library(torch_mlir_ltc_backend SHARED
|
||||||
backend/aten_eager_fallback.cpp
|
|
||||||
backend/aten_ltc_mlir_type.cpp
|
|
||||||
backend/backend_impl.cpp
|
backend/backend_impl.cpp
|
||||||
backend/LazyNativeFunctions.cpp
|
backend/LazyNativeFunctions.cpp
|
||||||
backend/LazyShapeInference.cpp
|
backend/LazyShapeInference.cpp
|
||||||
backend/GenLazyShapeInference.cpp
|
backend/GenLazyShapeInference.cpp
|
||||||
backend/mlir_lowering_context.cpp
|
backend/mlir_lowering_context.cpp
|
||||||
|
backend/mlir_native_functions.cpp
|
||||||
backend/mlir_node.cpp
|
backend/mlir_node.cpp
|
||||||
backend/RegisterLazy.cpp
|
backend/RegisterLazy.cpp
|
||||||
tensor_aten_ops.cpp
|
|
||||||
)
|
)
|
||||||
|
|
||||||
target_link_libraries(torch_mlir_ltc_backend
|
target_link_libraries(torch_mlir_ltc_backend
|
||||||
|
|
|
@ -69,6 +69,7 @@ TORCH_API std::vector<Shape> compute_shape_new_zeros(const at::Tensor & self, at
|
||||||
TORCH_API std::vector<Shape> compute_shape_rand_like(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
|
TORCH_API std::vector<Shape> compute_shape_rand_like(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
|
||||||
TORCH_API std::vector<Shape> compute_shape_relu(const at::Tensor & self);
|
TORCH_API std::vector<Shape> compute_shape_relu(const at::Tensor & self);
|
||||||
TORCH_API std::vector<Shape> compute_shape_relu_(at::Tensor & self);
|
TORCH_API std::vector<Shape> compute_shape_relu_(at::Tensor & self);
|
||||||
|
TORCH_API std::vector<Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
|
||||||
TORCH_API std::vector<Shape> compute_shape_reshape(const at::Tensor & self, at::IntArrayRef shape);
|
TORCH_API std::vector<Shape> compute_shape_reshape(const at::Tensor & self, at::IntArrayRef shape);
|
||||||
TORCH_API std::vector<Shape> compute_shape_rsub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
|
TORCH_API std::vector<Shape> compute_shape_rsub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
|
||||||
TORCH_API std::vector<Shape> compute_shape_select(const at::Tensor & self, int64_t dim, int64_t index);
|
TORCH_API std::vector<Shape> compute_shape_select(const at::Tensor & self, int64_t dim, int64_t index);
|
||||||
|
|
|
@ -1,58 +0,0 @@
|
||||||
//===- aten_eager_fallback.cpp --------------------------------------------===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// This file is adapted from pytorch/pytorch
|
|
||||||
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_eager_fallback.cpp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include <torch/csrc/lazy/backend/backend_interface.h>
|
|
||||||
#include <torch/csrc/lazy/core/metrics.h>
|
|
||||||
|
|
||||||
#include "../utils/exception.h"
|
|
||||||
#include "aten_eager_fallback.h"
|
|
||||||
|
|
||||||
namespace torch_lazy_tensors {
|
|
||||||
|
|
||||||
static std::unordered_map<std::string, ::torch::lazy::Counter*>
|
|
||||||
_eager_fallback_counters;
|
|
||||||
|
|
||||||
bool force_eager_fallback(c10::Symbol op) {
|
|
||||||
static char* force_str = std::getenv("LTC_FORCE_FALLBACK");
|
|
||||||
if (force_str != nullptr) {
|
|
||||||
static auto force_sym = c10::Symbol::fromQualString(std::string(force_str));
|
|
||||||
if (op == force_sym) {
|
|
||||||
std::cout << "MLIR force_eager_fallback(" << force_str << "): true"
|
|
||||||
<< std::endl;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::cout << "MLIR force_eager_fallback(" << op.toQualString() << "): false"
|
|
||||||
<< std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ltc_eager_fallback(
|
|
||||||
const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
|
||||||
const auto name = c10::toString(op.operator_name());
|
|
||||||
UNSUPPORTED_ERROR(
|
|
||||||
"MLIR ltc_eager_fallback is not supported for op: " << name);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::function<void(void)> register_mlir_ltc_eager_fallback;
|
|
||||||
|
|
||||||
TORCH_LIBRARY_IMPL(_, Lazy, m) {
|
|
||||||
register_mlir_ltc_eager_fallback = [&]() {
|
|
||||||
m.fallback(
|
|
||||||
torch::CppFunction::makeFromBoxedFunction<&ltc_eager_fallback>());
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace torch_lazy_tensors
|
|
|
@ -1,27 +0,0 @@
|
||||||
//===- aten_eager_fallback.h ----------------------------------------------===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Facilitates eager fallback behaviour
|
|
||||||
//
|
|
||||||
// This file is adapted from pytorch/pytorch
|
|
||||||
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_eager_fallback.h
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <ATen/native/CPUFallback.h>
|
|
||||||
|
|
||||||
namespace torch_lazy_tensors {
|
|
||||||
|
|
||||||
bool force_eager_fallback(c10::Symbol op);
|
|
||||||
void ltc_eager_fallback(
|
|
||||||
const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
|
||||||
|
|
||||||
extern TORCH_API std::function<void(void)> register_mlir_ltc_eager_fallback;
|
|
||||||
|
|
||||||
} // namespace torch_lazy_tensors
|
|
|
@ -7,7 +7,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// This file is adapted from pytorch/pytorch
|
// This file is adapted from pytorch/pytorch
|
||||||
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.cpp
|
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include <torch/csrc/lazy/backend/backend_data.h>
|
#include <torch/csrc/lazy/backend/backend_data.h>
|
||||||
|
@ -23,77 +23,79 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
MlirBackendData::MlirBackendData(BackendDevice device, Shape shape)
|
TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape)
|
||||||
: BackendData(device, shape) {
|
: BackendData(device, shape),
|
||||||
|
info_(std::make_unique<TorchMlirBackendData::Info>()) {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
auto info = std::make_shared<MlirBackendData::Info>();
|
|
||||||
SetInfo(info);
|
|
||||||
}
|
}
|
||||||
MlirBackendData::MlirBackendData(const at::Scalar& scalar, BackendDevice device)
|
TorchMlirBackendData::TorchMlirBackendData(
|
||||||
: BackendData(device, Shape(scalar.type(), {})) {
|
const at::Scalar& scalar, BackendDevice device)
|
||||||
|
: BackendData(device, Shape(scalar.type(), {})),
|
||||||
|
info_(std::make_unique<TorchMlirBackendData::Info>(scalar)) {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
auto info = std::make_shared<MlirBackendData::Info>(scalar);
|
|
||||||
SetInfo(info);
|
|
||||||
}
|
}
|
||||||
MlirBackendData::MlirBackendData(
|
TorchMlirBackendData::TorchMlirBackendData(
|
||||||
const at::Tensor& tensor, BackendDevice device, Shape shape)
|
const at::Tensor& tensor, BackendDevice device, Shape shape)
|
||||||
: BackendData(device, shape) {
|
: BackendData(device, shape),
|
||||||
|
info_(std::make_unique<TorchMlirBackendData::Info>(tensor)) {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
auto info = std::make_shared<MlirBackendData::Info>(tensor);
|
|
||||||
SetInfo(info);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BackendData::Handle MlirBackendData::GetHandle() {
|
BackendData::Handle TorchMlirBackendData::GetHandle() {
|
||||||
return reinterpret_cast<int64_t>(this);
|
return reinterpret_cast<int64_t>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MlirBackendData::Assign(const BackendData& data) {
|
void TorchMlirBackendData::Assign(const BackendData& data) {
|
||||||
MlirBackendData::Info* info =
|
TorchMlirBackendData::Info* info =
|
||||||
dynamic_cast<MlirBackendData::Info*>(data.info());
|
dynamic_cast<TorchMlirBackendData::Info*>(data.info());
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info.");
|
info,
|
||||||
auto new_info = std::make_shared<MlirBackendData::Info>(*info);
|
"Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
|
||||||
SetInfo(new_info);
|
info_ = std::make_unique<TorchMlirBackendData::Info>(*info);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MlirBackendData::HasValue() const { return bool(info()); }
|
bool TorchMlirBackendData::HasValue() const { return bool(info_); }
|
||||||
|
|
||||||
|
TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const {
|
||||||
|
return info_.get();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialization/Teardown
|
* Initialization/Teardown
|
||||||
* */
|
* */
|
||||||
void MlirBackendImpl::PrepareToExit() const {}
|
void TorchMlirBackendImpl::PrepareToExit() const {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Data Transfer
|
* Data Transfer
|
||||||
* */
|
* */
|
||||||
|
|
||||||
BackendDataPtr MlirBackendImpl::MakeComputationDataFromTensor(
|
BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromTensor(
|
||||||
const at::Tensor& tensor, const Shape& shape,
|
const at::Tensor& tensor, const Shape& shape,
|
||||||
const BackendDevice& device) const {
|
const BackendDevice& device) const {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
return std::make_shared<MlirBackendData>(tensor, device, shape);
|
return std::make_shared<TorchMlirBackendData>(tensor, device, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
BackendDataPtr MlirBackendImpl::MakeComputationDataFromScalar(
|
BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromScalar(
|
||||||
const at::Scalar& scalar, const BackendDevice& device) const {
|
const at::Scalar& scalar, const BackendDevice& device) const {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
return std::make_shared<MlirBackendData>(scalar, device);
|
return std::make_shared<TorchMlirBackendData>(scalar, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
BackendDataPtr MlirBackendImpl::CreateDataPlaceholder(
|
BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder(
|
||||||
const BackendDevice& device, const Shape& shape) const {
|
const BackendDevice& device, const Shape& shape) const {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
return std::make_shared<MlirBackendData>(device, shape);
|
return std::make_shared<TorchMlirBackendData>(device, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor MlirBackendImpl::MakeTensorFromComputationData(
|
at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
|
||||||
const BackendDataPtr data,
|
const BackendDataPtr data,
|
||||||
c10::optional<at::ScalarType> logical_scalar_type) const {
|
c10::optional<at::ScalarType> logical_scalar_type) const {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
MlirBackendData::Info* info =
|
TorchMlirBackendData::Info* info =
|
||||||
dynamic_cast<MlirBackendData::Info*>(data->info());
|
dynamic_cast<TorchMlirBackendData::Info*>(data->info());
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info.");
|
info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
|
||||||
return info->tensor;
|
return info->tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,20 +103,20 @@ at::Tensor MlirBackendImpl::MakeTensorFromComputationData(
|
||||||
* Lowering, Compilation, Execution
|
* Lowering, Compilation, Execution
|
||||||
* */
|
* */
|
||||||
|
|
||||||
std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
|
std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
|
||||||
const std::string& name, BackendDevice device,
|
const std::string& name, BackendDevice device,
|
||||||
c10::ArrayRef<Node*> post_order, Util::EmissionMap emit_status) const {
|
c10::ArrayRef<Node*> post_order, Util::EmissionMap emit_status) const {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
return std::make_unique<MlirLoweringContext>(
|
return std::make_unique<TorchMlirLoweringContext>(
|
||||||
name, std::forward<BackendDevice>(device),
|
name, std::forward<BackendDevice>(device),
|
||||||
std::forward<c10::ArrayRef<Node*>>(post_order),
|
std::forward<c10::ArrayRef<Node*>>(post_order),
|
||||||
std::forward<Util::EmissionMap>(emit_status));
|
std::forward<Util::EmissionMap>(emit_status));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
|
std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
|
||||||
const std::string& name, BackendDevice device) const {
|
const std::string& name, BackendDevice device) const {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
return std::make_unique<MlirLoweringContext>(
|
return std::make_unique<TorchMlirLoweringContext>(
|
||||||
name, std::forward<BackendDevice>(device));
|
name, std::forward<BackendDevice>(device));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,13 +131,13 @@ std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
|
||||||
|
|
||||||
// Specify which aten device should be used for eager fallback
|
// Specify which aten device should be used for eager fallback
|
||||||
// may change depending on current 'Default' DeviceType
|
// may change depending on current 'Default' DeviceType
|
||||||
at::DeviceType MlirBackendImpl::EagerFallbackDeviceType() const {
|
at::DeviceType TorchMlirBackendImpl::EagerFallbackDeviceType() const {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
return at::DeviceType::CPU;
|
return at::DeviceType::CPU;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query all available backend devices
|
// Query all available backend devices
|
||||||
std::vector<BackendDevice> MlirBackendImpl::GetBackendDevices() const {
|
std::vector<BackendDevice> TorchMlirBackendImpl::GetBackendDevices() const {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
return {
|
return {
|
||||||
GetBackendDevice(c10::Device(c10::kLazy, 0)),
|
GetBackendDevice(c10::Device(c10::kLazy, 0)),
|
||||||
|
@ -148,7 +150,7 @@ std::vector<BackendDevice> MlirBackendImpl::GetBackendDevices() const {
|
||||||
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
|
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
|
||||||
// through a mode, in which case these APIs should still work, but should be
|
// through a mode, in which case these APIs should still work, but should be
|
||||||
// identity mappings.
|
// identity mappings.
|
||||||
BackendDevice MlirBackendImpl::GetBackendDevice(c10::Device device) const {
|
BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
return BackendDevice(GetDefaultDeviceType(), device.index());
|
return BackendDevice(GetDefaultDeviceType(), device.index());
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
// using the Torch-MLIR ATen dialect
|
// using the Torch-MLIR ATen dialect
|
||||||
//
|
//
|
||||||
// This file is adapted from pytorch/pytorch
|
// This file is adapted from pytorch/pytorch
|
||||||
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.h
|
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.h
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
@ -23,7 +23,7 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
class TORCH_API MlirBackendData : public BackendData {
|
class TORCH_API TorchMlirBackendData : public BackendData {
|
||||||
public:
|
public:
|
||||||
struct Info : public BackendData::Info {
|
struct Info : public BackendData::Info {
|
||||||
at::Tensor tensor;
|
at::Tensor tensor;
|
||||||
|
@ -39,20 +39,25 @@ public:
|
||||||
Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {}
|
Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
MlirBackendData(BackendDevice device, Shape shape);
|
TorchMlirBackendData(BackendDevice device, Shape shape);
|
||||||
MlirBackendData(const at::Scalar& scalar, BackendDevice device);
|
TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device);
|
||||||
MlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape);
|
TorchMlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape);
|
||||||
|
|
||||||
virtual BackendData::Handle GetHandle() override;
|
virtual BackendData::Handle GetHandle() override;
|
||||||
|
|
||||||
virtual void Assign(const BackendData& data) override;
|
virtual void Assign(const BackendData& data) override;
|
||||||
|
|
||||||
virtual bool HasValue() const override;
|
virtual bool HasValue() const override;
|
||||||
|
|
||||||
|
TorchMlirBackendData::Info* mlir_info() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<TorchMlirBackendData::Info> info_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TORCH_API MlirBackendImpl : public BackendImplInterface {
|
class TORCH_API TorchMlirBackendImpl : public BackendImplInterface {
|
||||||
public:
|
public:
|
||||||
virtual ~MlirBackendImpl() = default;
|
virtual ~TorchMlirBackendImpl() = default;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialization/Teardown
|
* Initialization/Teardown
|
||||||
|
|
|
@ -7,22 +7,23 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// This file is adapted from pytorch/pytorch
|
// This file is adapted from pytorch/pytorch
|
||||||
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/ts_lowering_context.cpp
|
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "../utils/debug.h"
|
||||||
#include "../utils/exception.h"
|
#include "../utils/exception.h"
|
||||||
#include "mlir_lowering_context.h"
|
#include "mlir_lowering_context.h"
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
MlirLoweringContext::MlirLoweringContext(
|
TorchMlirLoweringContext::TorchMlirLoweringContext(
|
||||||
const std::string& name, BackendDevice device)
|
const std::string& name, BackendDevice device)
|
||||||
: LoweringContext(name, std::forward<BackendDevice>(device)) {}
|
: LoweringContext(name, std::forward<BackendDevice>(device)) {}
|
||||||
|
|
||||||
MlirLoweringContext::MlirLoweringContext(
|
TorchMlirLoweringContext::TorchMlirLoweringContext(
|
||||||
const std::string& name, BackendDevice device,
|
const std::string& name, BackendDevice device,
|
||||||
c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
|
c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
|
||||||
: LoweringContext(
|
: LoweringContext(
|
||||||
|
@ -30,29 +31,34 @@ MlirLoweringContext::MlirLoweringContext(
|
||||||
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
|
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
|
||||||
std::forward<Util::EmissionMap>(emit_status)) {}
|
std::forward<Util::EmissionMap>(emit_status)) {}
|
||||||
|
|
||||||
int MlirComputation::parameters_size() const { UNIMPLEMENTED_FUNCTION_ERROR(); }
|
int TorchMlirComputation::parameters_size() const { UNIMPLEMENTED_FUNCTION_ERROR(); }
|
||||||
|
|
||||||
const std::vector<torch::lazy::Shape>&
|
const std::vector<torch::lazy::Shape>&
|
||||||
MlirComputation::parameter_shapes() const {
|
TorchMlirComputation::parameter_shapes() const {
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<std::string>& MlirComputation::parameter_names() const {
|
const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
const torch::lazy::Shape& MlirComputation::result_shape() const {
|
const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
|
||||||
|
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string TorchMlirComputation::to_string() const {
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the shape of the result tuple component, given by index.
|
// Get the shape of the result tuple component, given by index.
|
||||||
torch::lazy::Shape MlirLoweringContext::GetResultShape(size_t index) const {
|
torch::lazy::Shape TorchMlirLoweringContext::GetResultShape(size_t index) const {
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adds the given output as a component of the result tuple and returns its
|
// Adds the given output as a component of the result tuple and returns its
|
||||||
// assigned position within the tuple.
|
// assigned position within the tuple.
|
||||||
size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) {
|
size_t TorchMlirLoweringContext::AddResult(const torch::lazy::Output& output) {
|
||||||
|
PRINT_FUNCTION();
|
||||||
const torch::lazy::Node* node;
|
const torch::lazy::Node* node;
|
||||||
auto it = emitted_outputs_.find(output);
|
auto it = emitted_outputs_.find(output);
|
||||||
if (it == emitted_outputs_.end()) {
|
if (it == emitted_outputs_.end()) {
|
||||||
|
@ -75,7 +81,7 @@ size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) {
|
||||||
// Associates the given output with the input parameter of the given index and
|
// Associates the given output with the input parameter of the given index and
|
||||||
// shape. Only used for the operator-by-operator execution, mostly for
|
// shape. Only used for the operator-by-operator execution, mostly for
|
||||||
// debugging purposes.
|
// debugging purposes.
|
||||||
void MlirLoweringContext::AddParameter(
|
void TorchMlirLoweringContext::AddParameter(
|
||||||
const torch::lazy::Output& output, size_t index,
|
const torch::lazy::Output& output, size_t index,
|
||||||
const torch::lazy::Shape& shape, const std::string& name) {
|
const torch::lazy::Shape& shape, const std::string& name) {
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
|
@ -83,10 +89,18 @@ void MlirLoweringContext::AddParameter(
|
||||||
|
|
||||||
// Build the computation capturing all the operations created with the
|
// Build the computation capturing all the operations created with the
|
||||||
// embedded builder (returned by the builder() API).
|
// embedded builder (returned by the builder() API).
|
||||||
ComputationPtr MlirLoweringContext::Build() {
|
ComputationPtr TorchMlirLoweringContext::Build() {
|
||||||
|
PRINT_FUNCTION()
|
||||||
for (const torch::lazy::Node* output : result_tuple_) {
|
for (const torch::lazy::Node* output : result_tuple_) {
|
||||||
}
|
}
|
||||||
return std::make_shared<MlirComputation>();
|
return std::make_shared<TorchMlirComputation>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieves the lowered operation for an output. If the requested output is
|
||||||
|
// not available yet, the graph behind the output's Node is lowered, and the
|
||||||
|
// corresponding MLIR operation returned.
|
||||||
|
torch::jit::Value* GetOutputOp(const Output& output) {
|
||||||
|
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// This file is adapted from pytorch/pytorch
|
// This file is adapted from pytorch/pytorch
|
||||||
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_lowering_context.h
|
// https://github.com/pytorch/pytorch/blob/torch/csrc/lazy/ts_backend/ts_lowering_context.h
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
@ -19,7 +19,7 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
class TORCH_API MlirComputation : public torch::lazy::Computation {
|
class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
|
||||||
public:
|
public:
|
||||||
int parameters_size() const override;
|
int parameters_size() const override;
|
||||||
|
|
||||||
|
@ -31,11 +31,11 @@ public:
|
||||||
virtual const torch::lazy::Shape& result_shape() const override;
|
virtual const torch::lazy::Shape& result_shape() const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TORCH_API MlirLoweringContext : public torch::lazy::LoweringContext {
|
class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
|
||||||
public:
|
public:
|
||||||
MlirLoweringContext(
|
TorchMlirLoweringContext(
|
||||||
const std::string& name, torch::lazy::BackendDevice device);
|
const std::string& name, torch::lazy::BackendDevice device);
|
||||||
MlirLoweringContext(
|
TorchMlirLoweringContext(
|
||||||
const std::string& name, torch::lazy::BackendDevice device,
|
const std::string& name, torch::lazy::BackendDevice device,
|
||||||
c10::ArrayRef<torch::lazy::Node*> post_order,
|
c10::ArrayRef<torch::lazy::Node*> post_order,
|
||||||
torch::lazy::Util::EmissionMap emit_status);
|
torch::lazy::Util::EmissionMap emit_status);
|
||||||
|
@ -58,6 +58,11 @@ public:
|
||||||
// embedded builder (returned by the builder() API).
|
// embedded builder (returned by the builder() API).
|
||||||
virtual torch::lazy::ComputationPtr Build() override;
|
virtual torch::lazy::ComputationPtr Build() override;
|
||||||
|
|
||||||
|
// Retrieves the lowered operation for an output. If the requested output is
|
||||||
|
// not available yet, the graph behind the output's Node is lowered, and the
|
||||||
|
// corresponding MLIR operation returned.
|
||||||
|
torch::jit::Value* GetOutputOp(const Output& output);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<const torch::lazy::Node*> result_tuple_;
|
std::vector<const torch::lazy::Node*> result_tuple_;
|
||||||
torch::lazy::OutputMap<const torch::lazy::Node*> emitted_outputs_;
|
torch::lazy::OutputMap<const torch::lazy::Node*> emitted_outputs_;
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// This file is adapted from pytorch/pytorch
|
// This file is adapted from pytorch/pytorch
|
||||||
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_ltc_ts_type.cpp
|
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include <ATen/Operators.h>
|
#include <ATen/Operators.h>
|
||||||
|
@ -17,6 +17,7 @@
|
||||||
#include <ATen/ops/result_type.h>
|
#include <ATen/ops/result_type.h>
|
||||||
#include <torch/csrc/lazy/core/helpers.h>
|
#include <torch/csrc/lazy/core/helpers.h>
|
||||||
#include <torch/csrc/lazy/core/metrics.h>
|
#include <torch/csrc/lazy/core/metrics.h>
|
||||||
|
#include <torch/csrc/lazy/core/tensor_aten_ops.h>
|
||||||
#include <torch/csrc/lazy/core/tensor_util.h>
|
#include <torch/csrc/lazy/core/tensor_util.h>
|
||||||
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
|
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
|
||||||
#include <torch/library.h>
|
#include <torch/library.h>
|
||||||
|
@ -24,13 +25,13 @@
|
||||||
#include "ATen/MetaFunctions.h"
|
#include "ATen/MetaFunctions.h"
|
||||||
#include <torch/csrc/lazy/core/tensor_impl.h>
|
#include <torch/csrc/lazy/core/tensor_impl.h>
|
||||||
|
|
||||||
#include "../tensor_aten_ops.h"
|
|
||||||
#include "../utils/exception.h"
|
#include "../utils/exception.h"
|
||||||
#include "../utils/sys_utils.h"
|
#include "../utils/sys_utils.h"
|
||||||
#include "LazyNativeFunctions.h"
|
#include "LazyNativeFunctions.h"
|
||||||
#include "LazyShapeInference.h"
|
#include "LazyShapeInference.h"
|
||||||
|
|
||||||
namespace torch_lazy_tensors {
|
namespace torch {
|
||||||
|
namespace lazy {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -121,7 +122,7 @@ GetLtcDevice(const c10::optional<c10::Device>& device) {
|
||||||
|
|
||||||
// UNIMPLEMENTED_FUNCTION_ERROR();
|
// UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
// // return torch::lazy::CreateAtenFromLtcTensor(
|
// // return torch::lazy::CreateAtenFromLtcTensor(
|
||||||
// // lazy_tensor_aten_ops::bernoulli(self_tensor));
|
// // torch::lazy::bernoulli(self_tensor));
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// at::Tensor& LazyNativeFunctions::bernoulli_(
|
// at::Tensor& LazyNativeFunctions::bernoulli_(
|
||||||
|
@ -133,7 +134,7 @@ GetLtcDevice(const c10::optional<c10::Device>& device) {
|
||||||
// auto self_tensor = torch::lazy::TryGetLtcTensor(self);
|
// auto self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||||
|
|
||||||
// UNIMPLEMENTED_FUNCTION_ERROR();
|
// UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
// // lazy_tensor_aten_ops::bernoulli_(self_tensor, p);
|
// // torch::lazy::bernoulli_(self_tensor, p);
|
||||||
// // return self;
|
// // return self;
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
@ -208,7 +209,7 @@ at::Tensor LazyNativeFunctions::_copy_from(
|
||||||
dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false));
|
dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
lazy_tensor_aten_ops::copy_(dst_tensor, self_tensor);
|
torch::lazy::copy_(dst_tensor, self_tensor);
|
||||||
auto* impl =
|
auto* impl =
|
||||||
dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
|
dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
|
||||||
impl->set_tensor(dst_tensor);
|
impl->set_tensor(dst_tensor);
|
||||||
|
@ -260,15 +261,15 @@ at::Tensor LazyNativeFunctions::expand(
|
||||||
const at::Tensor& self, at::IntArrayRef size, bool implicit) {
|
const at::Tensor& self, at::IntArrayRef size, bool implicit) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
// return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::expand(
|
return torch::lazy::CreateAtenFromLtcTensor(
|
||||||
// torch::lazy::TryGetLtcTensor(self), size.vec()));
|
torch::lazy::expand(torch::lazy::TryGetLtcTensor(self), size.vec()));
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor&
|
at::Tensor&
|
||||||
LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) {
|
LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
|
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||||
lazy_tensor_aten_ops::fill_(self_tensor, value);
|
torch::lazy::fill_(self_tensor, value);
|
||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -280,110 +281,86 @@ LazyNativeFunctions::native_batch_norm(
|
||||||
const c10::optional<at::Tensor>& running_var, bool training,
|
const c10::optional<at::Tensor>& running_var, bool training,
|
||||||
double momentum, double eps) {
|
double momentum, double eps) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
torch::lazy::LazyTensorPtr input_tensor = torch::lazy::TryGetLtcTensor(input);
|
auto input_tensor = torch::lazy::TryGetLtcTensor(input);
|
||||||
const torch::lazy::BackendDevice& device = input_tensor->GetDevice();
|
const torch::lazy::BackendDevice& device = input_tensor->GetDevice();
|
||||||
torch::lazy::LazyTensorPtr running_mean_tensor =
|
auto running_mean_tensor = GetOrCreateLtcTensor(running_mean, device);
|
||||||
GetOrCreateLtcTensor(running_mean, device);
|
auto running_var_tensor = GetOrCreateLtcTensor(running_var, device);
|
||||||
torch::lazy::LazyTensorPtr running_var_tensor =
|
auto outputs = torch::lazy::native_batch_norm(
|
||||||
GetOrCreateLtcTensor(running_var, device);
|
torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device),
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
GetOrCreateLtcTensor(bias, device), running_mean_tensor,
|
||||||
// auto outputs = lazy_tensor_aten_ops::ts_native_batch_norm(
|
running_var_tensor, training, momentum, eps);
|
||||||
// torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight,
|
return std::make_tuple(
|
||||||
// device),
|
torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)),
|
||||||
// GetOrCreateLtcTensor(bias, device), running_mean_tensor,
|
torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)),
|
||||||
// running_var_tensor, training, momentum, eps);
|
torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs)));
|
||||||
// return
|
|
||||||
// std::make_tuple(torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)),
|
|
||||||
// torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)),
|
|
||||||
// torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// std::tuple<at::Tensor, at::Tensor, at::Tensor>
|
std::tuple<at::Tensor, at::Tensor, at::Tensor>
|
||||||
// LazyNativeFunctions::native_batch_norm_backward(
|
LazyNativeFunctions::native_batch_norm_backward(
|
||||||
// const at::Tensor& grad_out, const at::Tensor& input,
|
const at::Tensor& grad_out, const at::Tensor& input,
|
||||||
// const c10::optional<at::Tensor>& weight,
|
const c10::optional<at::Tensor>& weight,
|
||||||
// const c10::optional<at::Tensor>& running_mean,
|
const c10::optional<at::Tensor>& running_mean,
|
||||||
// const c10::optional<at::Tensor>& running_var,
|
const c10::optional<at::Tensor>& running_var,
|
||||||
// const c10::optional<at::Tensor>& save_mean,
|
const c10::optional<at::Tensor>& save_mean,
|
||||||
// const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
|
const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
|
||||||
// std::array<bool, 3> output_mask) {
|
std::array<bool, 3> output_mask) {
|
||||||
// TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
// torch::lazy::LazyTensor grad_out_tensor =
|
auto grad_out_tensor = torch::lazy::TryGetLtcTensor(grad_out);
|
||||||
// torch::lazy::TryGetLtcTensor(grad_out);
|
const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice();
|
||||||
// const torch::lazy::BackendDevice& device = grad_out_tensor.GetDevice();
|
torch::lazy::LazyTensorPtr null_tensor;
|
||||||
// torch::lazy::LazyTensor null_tensor;
|
bool running_stats = running_mean && running_mean->defined();
|
||||||
// bool running_stats = running_mean && running_mean->defined();
|
CHECK_EQ(running_var && running_var->defined(), running_stats);
|
||||||
// CHECK_EQ(running_var && running_var->defined(), running_stats);
|
auto gradients = torch::lazy::native_batch_norm_backward(
|
||||||
// UNIMPLEMENTED_FUNCTION_ERROR();
|
torch::lazy::TryGetLtcTensor(grad_out),
|
||||||
// // auto gradients = lazy_tensor_aten_ops::ts_native_batch_norm_backward(
|
torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device),
|
||||||
// // torch::lazy::TryGetLtcTensor(grad_out),
|
running_stats ? GetOrCreateLtcTensor(running_mean, device) : null_tensor,
|
||||||
// torch::lazy::TryGetLtcTensor(input),
|
running_stats ? GetOrCreateLtcTensor(running_var, device) : null_tensor,
|
||||||
// // GetOrCreateLtcTensor(weight, device),
|
GetOrCreateLtcTensor(save_mean, device),
|
||||||
// // running_stats ? GetOrCreateLtcTensor(running_mean, device)
|
GetOrCreateLtcTensor(save_invstd, device), train, eps, output_mask);
|
||||||
// // : null_tensor,
|
at::Tensor undefined;
|
||||||
// // running_stats ? GetOrCreateLtcTensor(running_var, device)
|
return std::make_tuple(
|
||||||
// // : null_tensor,
|
output_mask[0]
|
||||||
// // GetOrCreateLtcTensor(save_mean, device),
|
? torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients))
|
||||||
// // GetOrCreateLtcTensor(save_invstd, device), train, eps,
|
: undefined,
|
||||||
// // output_mask);
|
output_mask[1]
|
||||||
// // at::Tensor undefined;
|
? torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients))
|
||||||
// // return std::make_tuple(
|
: undefined,
|
||||||
// // output_mask[0] ?
|
output_mask[2]
|
||||||
// torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients))
|
? torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients))
|
||||||
// // : undefined,
|
: undefined);
|
||||||
// // output_mask[1] ?
|
}
|
||||||
// torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients))
|
|
||||||
// // : undefined,
|
|
||||||
// // output_mask[2] ?
|
|
||||||
// torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients))
|
|
||||||
// // : undefined);
|
|
||||||
// }
|
|
||||||
|
|
||||||
at::Tensor
|
at::Tensor
|
||||||
LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) {
|
LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
|
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
// return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::permute(
|
return torch::lazy::CreateAtenFromLtcTensor(
|
||||||
// self_tensor, torch::lazy::ToI64Vector(dims)));
|
torch::lazy::permute(self_tensor, torch::lazy::ToI64Vector(dims)));
|
||||||
}
|
|
||||||
|
|
||||||
at::Tensor
|
|
||||||
LazyNativeFunctions::repeat(const at::Tensor& self, at::IntArrayRef repeats) {
|
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
|
||||||
// return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::repeat(
|
|
||||||
// torch::lazy::TryGetLtcTensor(self),
|
|
||||||
// torch::lazy::ToI64Vector(repeats)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) {
|
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
return torch::lazy::CreateAtenFromLtcTensor(
|
||||||
// return torch::lazy::CreateAtenFromLtcTensor(
|
torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self)));
|
||||||
// lazy_tensor_aten_ops::squeeze(torch::lazy::TryGetLtcTensor(self)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self, int64_t dim) {
|
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self, int64_t dim) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
return torch::lazy::CreateAtenFromLtcTensor(
|
||||||
// return torch::lazy::CreateAtenFromLtcTensor(
|
torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self), dim));
|
||||||
// lazy_tensor_aten_ops::squeeze(torch::lazy::TryGetLtcTensor(self),
|
|
||||||
// dim));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
|
at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::transpose(
|
return torch::lazy::CreateAtenFromLtcTensor(
|
||||||
torch::lazy::TryGetLtcTensor(self), 0, 1));
|
torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), 0, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) {
|
at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
return torch::lazy::CreateAtenFromLtcTensor(
|
||||||
// return torch::lazy::CreateAtenFromLtcTensor(
|
torch::lazy::unsqueeze(torch::lazy::TryGetLtcTensor(self), dim));
|
||||||
// lazy_tensor_aten_ops::unsqueeze(torch::lazy::TryGetLtcTensor(self),
|
|
||||||
// dim));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor
|
at::Tensor
|
||||||
|
@ -391,9 +368,10 @@ LazyNativeFunctions::view(const at::Tensor& self, at::IntArrayRef size) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
|
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||||
return torch::lazy::CreateAtenFromLtcTensor(
|
return torch::lazy::CreateAtenFromLtcTensor(
|
||||||
lazy_tensor_aten_ops::view(self_tensor, torch::lazy::ToI64Vector(size)));
|
torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size)));
|
||||||
}
|
}
|
||||||
|
|
||||||
void InitializeAtenBindings() {}
|
void InitializeAtenBindings() {}
|
||||||
|
|
||||||
} // namespace torch_lazy_tensors
|
} // namespace lazy
|
||||||
|
} // namespace torch
|
|
@ -18,112 +18,9 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
namespace {
|
TorchMlirOpVector
|
||||||
|
TorchMlirNode::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||||
hash_t OperandHashes(
|
return {};
|
||||||
const OpList& operands, const hash_t& seed, const bool bakeInSizes) {
|
|
||||||
hash_t hash = seed;
|
|
||||||
for (auto& operand : operands) {
|
|
||||||
if (!operand) {
|
|
||||||
hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto operand_hash =
|
|
||||||
bakeInSizes ? operand.hash_with_sizes() : operand.hash_without_sizes();
|
|
||||||
hash = HashCombine(hash, operand_hash);
|
|
||||||
}
|
|
||||||
return hash;
|
|
||||||
}
|
|
||||||
|
|
||||||
hash_t GetOpHash(
|
|
||||||
OpKind op, const Shape& shape, hash_t hash_seed, const bool bakeInSizes) {
|
|
||||||
hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes));
|
|
||||||
return HashCombine(h, hash_seed);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
MlirNode::MlirNode(
|
|
||||||
OpKind op, OpList operands, std::vector<Shape>&& shapes, size_t num_outputs,
|
|
||||||
hash_t hash_seed)
|
|
||||||
: Node(
|
|
||||||
op, num_outputs,
|
|
||||||
/* node_hash */ HashCombine(op.hash(), hash_seed),
|
|
||||||
/* dag_hash */
|
|
||||||
[&](bool bakeInSizes) -> hash_t {
|
|
||||||
return OperandHashes(
|
|
||||||
operands, HashCombine(op.hash(), hash_seed), bakeInSizes);
|
|
||||||
}),
|
|
||||||
shapes_(std::move(shapes)) {
|
|
||||||
for (auto& operand : operands) {
|
|
||||||
// Ideally, optional operands should be filtered by the leaf node classes,
|
|
||||||
// but it's just much easier to do it here.
|
|
||||||
if (!operand) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
AddOperand(operand.node, operand.index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirNode::MlirNode(
|
|
||||||
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
|
|
||||||
size_t num_outputs, hash_t hash_seed)
|
|
||||||
: MlirNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
|
|
||||||
shapes_.push_back(GetOpShape(shape_fn));
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirNode::MlirNode(
|
|
||||||
OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
|
|
||||||
: MlirNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
|
|
||||||
|
|
||||||
void MlirNode::SetShapeDeferred(const std::function<Shape()>& shape_fn) {
|
|
||||||
shapes_.push_back(GetOpShape(shape_fn));
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirNode::MlirNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
|
|
||||||
: Node(op, num_outputs, [&](bool bakeInSizes) -> hash_t {
|
|
||||||
return GetOpHash(op, shape, hash_seed, bakeInSizes);
|
|
||||||
}) {
|
|
||||||
shapes_.push_back(std::move(shape));
|
|
||||||
}
|
|
||||||
|
|
||||||
using ShapeCache = Cache<hash_t, Shape, HashReducer>;
|
|
||||||
|
|
||||||
constexpr const int torch_lazy_shape_cache_size = 4096;
|
|
||||||
|
|
||||||
ShapeCache* GetShapeCache() {
|
|
||||||
static ShapeCache* cache = new ShapeCache(torch_lazy_shape_cache_size);
|
|
||||||
return cache;
|
|
||||||
}
|
|
||||||
|
|
||||||
Shape MlirNode::GetOpShape(const std::function<Shape()>& shape_fn) const {
|
|
||||||
ShapeCache* shape_cache = GetShapeCache();
|
|
||||||
auto shape = shape_cache->Get(hash());
|
|
||||||
if (shape == nullptr) {
|
|
||||||
shape = shape_cache->Add(hash(), std::make_shared<Shape>(shape_fn()));
|
|
||||||
}
|
|
||||||
return *shape;
|
|
||||||
}
|
|
||||||
|
|
||||||
c10::ArrayRef<Shape> MlirNode::shapes() const { return shapes_; }
|
|
||||||
|
|
||||||
const Shape& MlirNode::shape(size_t output_index) const {
|
|
||||||
return shapes_.at(output_index);
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::vector<Output>& MlirNode::operands() const {
|
|
||||||
return operands_as_outputs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
const Output& MlirNode::operand(size_t i) const {
|
|
||||||
return operands_as_outputs_.at(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
void MlirNode::AddOperand(NodePtr node, size_t index) {
|
|
||||||
CHECK_LT(index, node->num_outputs());
|
|
||||||
operands_.push_back(std::move(node));
|
|
||||||
operands_as_outputs_.emplace_back(operands_.back().get(), index);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
|
|
|
@ -7,76 +7,33 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// This file is adapted from pytorch/pytorch
|
// This file is adapted from pytorch/pytorch
|
||||||
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_node.h
|
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node.h
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <ATen/core/interned_strings.h>
|
#include <ATen/core/interned_strings.h>
|
||||||
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
#include <torch/csrc/lazy/backend/lowering_context.h>
|
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||||
#include <torch/csrc/lazy/core/ir.h>
|
#include <torch/csrc/lazy/core/ir.h>
|
||||||
#include <torch/csrc/lazy/core/shape.h>
|
#include <torch/csrc/lazy/core/shape.h>
|
||||||
|
|
||||||
|
#include "../utils/debug.h"
|
||||||
#include "../utils/exception.h"
|
#include "../utils/exception.h"
|
||||||
#include "aten_eager_fallback.h"
|
|
||||||
#include "mlir_lowering_context.h"
|
#include "mlir_lowering_context.h"
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
typedef std::vector<NodePtr> MlirOpVector;
|
typedef std::vector<torch::jit::Value*> TorchMlirOpVector;
|
||||||
typedef NodePtr MlirFunction;
|
typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
|
||||||
|
|
||||||
class TORCH_API MlirNode : public torch::lazy::Node {
|
|
||||||
|
|
||||||
|
class TORCH_API TorchMlirNode : public torch::lazy::Node {
|
||||||
public:
|
public:
|
||||||
MlirNode(
|
using torch::lazy::Node::Node;
|
||||||
OpKind op, OpList operands, std::vector<Shape>&& shapes,
|
|
||||||
size_t num_outputs = 1, hash_t hash_seed = kHashSeed);
|
|
||||||
|
|
||||||
// Same as the constructor above, but the shape is generated by a function,
|
virtual TorchMlirOpVector
|
||||||
// only if needed (shape cache miss).
|
Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const;
|
||||||
MlirNode(
|
|
||||||
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
|
|
||||||
size_t num_outputs = 1, hash_t hash_seed = kHashSeed);
|
|
||||||
|
|
||||||
// The shape is set later.
|
|
||||||
MlirNode(
|
|
||||||
OpKind op, OpList operands, size_t num_outputs = 1,
|
|
||||||
hash_t hash_seed = kHashSeed);
|
|
||||||
|
|
||||||
void SetShapeDeferred(const std::function<Shape()>& shape_fn);
|
|
||||||
|
|
||||||
// Contructor used to create leaf nodes.
|
|
||||||
MlirNode(
|
|
||||||
OpKind op, Shape shape, size_t num_outputs = 1,
|
|
||||||
hash_t hash_seed = kHashSeed);
|
|
||||||
|
|
||||||
Shape GetOpShape(const std::function<Shape()>& shape_fn) const;
|
|
||||||
|
|
||||||
// Retrieves the full shape of the IR Node.
|
|
||||||
c10::ArrayRef<Shape> shapes() const override;
|
|
||||||
|
|
||||||
// Retrieves the shape of the output at a given index.
|
|
||||||
const Shape& shape(size_t output_index = 0) const override;
|
|
||||||
|
|
||||||
const std::vector<Output>& operands() const override;
|
|
||||||
|
|
||||||
const Output& operand(size_t i) const override;
|
|
||||||
|
|
||||||
virtual MlirOpVector
|
|
||||||
Lower(MlirFunction function, MlirLoweringContext* loctx) const = 0;
|
|
||||||
|
|
||||||
private:
|
|
||||||
// Adds node's index output number as operand.
|
|
||||||
void AddOperand(NodePtr node, size_t index = 0);
|
|
||||||
|
|
||||||
std::vector<Shape> shapes_;
|
|
||||||
// A node holds a real reference to its operands.
|
|
||||||
std::vector<NodePtr> operands_;
|
|
||||||
// Outputs do not hold references on the nodes, and neither do the uses, since
|
|
||||||
// otherwise we get into circular reference counting.
|
|
||||||
std::vector<Output> operands_as_outputs_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
|
|
|
@ -1,242 +0,0 @@
|
||||||
//===- tensor_aten_ops.cpp ------------------------------------------------===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// This file is adapted from pytorch/pytorch
|
|
||||||
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/tensor_aten_ops.cpp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <functional>
|
|
||||||
|
|
||||||
#include <ATen/InferSize.h>
|
|
||||||
#include <c10/util/Optional.h>
|
|
||||||
#include <torch/csrc/autograd/variable.h>
|
|
||||||
#include <torch/csrc/lazy/core/helpers.h>
|
|
||||||
#include <torch/csrc/lazy/core/ir_util.h>
|
|
||||||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
|
||||||
#include <torch/csrc/lazy/core/metrics.h>
|
|
||||||
#include <torch/csrc/lazy/core/tensor.h>
|
|
||||||
#include <torch/csrc/lazy/core/util.h>
|
|
||||||
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
|
|
||||||
#include <torch/csrc/lazy/core/view_ops/permute.h>
|
|
||||||
#include <torch/csrc/lazy/core/view_ops/view.h>
|
|
||||||
#include <torch/csrc/lazy/ts_backend/ops/cast.h>
|
|
||||||
#include <torch/csrc/lazy/ts_backend/ops/expand.h>
|
|
||||||
|
|
||||||
#include "tensor_aten_ops.h"
|
|
||||||
|
|
||||||
namespace torch_lazy_tensors {
|
|
||||||
namespace lazy_tensor_aten_ops {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// to enable operator+-*/ for Value
|
|
||||||
using namespace torch::lazy;
|
|
||||||
|
|
||||||
torch::lazy::Value MaybeExpand(
|
|
||||||
const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) {
|
|
||||||
if (input.shape().sizes() == target_shape.sizes()) {
|
|
||||||
return input;
|
|
||||||
}
|
|
||||||
return torch::lazy::MakeNode<torch::lazy::Expand>(
|
|
||||||
input, target_shape.sizes().vec(),
|
|
||||||
/*is_scalar_expand=*/false);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int64_t> GetExpandDimensions(
|
|
||||||
const torch::lazy::Shape& shape, std::vector<int64_t> dimensions) {
|
|
||||||
CHECK_GE(dimensions.size(), shape.dim()) << shape;
|
|
||||||
int64_t base = dimensions.size() - shape.dim();
|
|
||||||
for (size_t i = 0; i < shape.dim(); ++i) {
|
|
||||||
if (dimensions[base + i] == -1) {
|
|
||||||
dimensions[base + i] = shape.size(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dimensions;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a 1-D shape for batch norm weight or bias based on the input shape.
|
|
||||||
torch::lazy::Shape
|
|
||||||
BatchNormFeaturesShape(const torch::lazy::LazyTensorPtr& input) {
|
|
||||||
CHECK(input);
|
|
||||||
auto input_shape = input->shape().Get();
|
|
||||||
return torch::lazy::Shape(input_shape.scalar_type(), input_shape.sizes()[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the IR for the given input or the provided default value broadcasted
|
|
||||||
// to the default shape, if the input is undefined.
|
|
||||||
torch::lazy::Value GetIrValueOrDefault(
|
|
||||||
const torch::lazy::LazyTensorPtr& input, const at::Scalar& default_value,
|
|
||||||
const torch::lazy::Shape& default_shape,
|
|
||||||
const torch::lazy::BackendDevice& device) {
|
|
||||||
return input ? input->GetIrValue()
|
|
||||||
: torch::lazy::LazyGraphExecutor::Get()
|
|
||||||
->GetIrValueForExpandedScalar(
|
|
||||||
default_value, default_shape, device);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assembles the kAsStrided ViewInfo for a view of `input_shape`.
// The view's shape reuses the scalar type of `input_shape` with `size` as its
// sizes; `stride` is stored in the AsStridedInfo, and its offset is only
// overwritten when `storage_offset` holds a value.
torch::lazy::ViewInfo CreateAsStridedViewInfo(
    const torch::lazy::Shape& input_shape, std::vector<int64_t> size,
    std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
  torch::lazy::AsStridedInfo strided;
  strided.stride = std::move(stride);
  if (storage_offset.has_value()) {
    strided.offset = storage_offset.value();
  }

  torch::lazy::Shape view_shape(input_shape.scalar_type(), size);
  return torch::lazy::ViewInfo(
      torch::lazy::ViewInfo::Type::kAsStrided, std::move(view_shape),
      input_shape, std::move(strided));
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
|
||||||
// ATen operators follow here, listed in alphabetical order.
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Creates a new view tensor over `input` described by an as-strided ViewInfo
// built from `size`, `stride` and the optional `storage_offset`.
torch::lazy::LazyTensorPtr as_strided(
    const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
    std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
  const auto source_shape = input->shape();
  torch::lazy::ViewInfo strided_view = CreateAsStridedViewInfo(
      source_shape.Get(), std::move(size), std::move(stride), storage_offset);
  return input->CreateViewTensor(std::move(strided_view));
}
|
|
||||||
|
|
||||||
// In-place as_strided on `input`.
// If the tensor already has a view attached, an as-strided sub-view is set on
// it. Otherwise its IR value is replaced by an AsStrided node, using offset 0
// when `storage_offset` is empty.
void as_strided_(
    torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
    std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
  const bool has_view = input->data()->view != nullptr;
  if (has_view) {
    const auto source_shape = input->shape();
    input->SetSubView(CreateAsStridedViewInfo(
        source_shape.Get(), std::move(size), std::move(stride),
        storage_offset));
    return;
  }

  input->SetIrValue(torch::lazy::MakeNode<torch::lazy::AsStrided>(
      input->GetIrValue(), std::move(size), std::move(stride),
      storage_offset.value_or(0)));
}
|
|
||||||
|
|
||||||
// Creates a new lazy tensor, on the device of `input`, whose IR is an Expand
// of `input` to `size`. Entries of `size` equal to -1 are first resolved
// against the input's shape by GetExpandDimensions.
torch::lazy::LazyTensorPtr
expand(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size) {
  const auto source_shape = input->shape();
  torch::lazy::Value expanded = torch::lazy::MakeNode<torch::lazy::Expand>(
      input->GetIrValue(),
      GetExpandDimensions(source_shape.Get(), std::move(size)),
      /*is_scalar_expand=*/false);
  return torch::lazy::LazyTensor::Create(
      std::move(expanded), input->GetDevice());
}
|
|
||||||
|
|
||||||
// Overwrites `input` in place with `value`: the scalar is expanded to the
// tensor's shape on the tensor's device and installed as its in-place IR
// value.
void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value) {
  auto* executor = torch::lazy::LazyGraphExecutor::Get();
  torch::lazy::Value filled = executor->GetIrValueForExpandedScalar(
      value, input->shape(), input->GetDevice());
  input->SetInPlaceIrValue(std::move(filled));
}
|
|
||||||
|
|
||||||
// Creates a view of `input` whose extent along `dim` is `length`, beginning at
// position `start`. `dim` and `start` are canonicalized against the input
// shape first. When the narrowed shape has as many elements as the input, the
// view is recorded as a reshape instead of a narrow.
torch::lazy::LazyTensorPtr narrow(
    const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
    int64_t length) {
  const auto input_shape = input->shape();
  const torch::lazy::Shape& source = input_shape.Get();
  dim = torch::lazy::GetCanonicalDimensionIndex(dim, source.dim());

  torch::lazy::Shape narrowed = source;
  narrowed.set_size(dim, length);

  const bool keeps_all_elements = source.numel() == narrowed.numel();
  const torch::lazy::ViewInfo::Type view_type =
      keeps_all_elements ? torch::lazy::ViewInfo::Type::kReshape
                         : torch::lazy::ViewInfo::Type::kNarrow;

  torch::lazy::ViewInfo view_info(view_type, std::move(narrowed), source);
  view_info.indices[dim] =
      torch::lazy::GetCanonicalPosition(source.sizes(), dim, start);
  return input->CreateViewTensor(std::move(view_info));
}
|
|
||||||
|
|
||||||
// Creates a kPermute view of `input`. `dims` is canonicalized against the
// rank of the input shape before being stored in the view.
torch::lazy::LazyTensorPtr
permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef<int64_t> dims) {
  const auto input_shape = input->shape();
  auto canonical_dims = torch::lazy::GetCanonicalDimensionIndices(
      dims, input_shape.Get().dim());
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kPermute, input_shape,
      std::move(canonical_dims));
  return input->CreateViewTensor(std::move(view_info));
}
|
|
||||||
|
|
||||||
// Writes the contents of `src` into `input`.
//
// Same device: the copy stays in IR form. The source IR value is taken
// directly, or wrapped in a Cast when the dtypes differ, and is then expanded
// to the shape of `input` if the sizes differ.
// Different devices: `src` is converted to a detached at::Tensor, expanded
// when its sizes do not match, and `input` is updated from that tensor.
void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
  const bool same_device = input->GetDevice() == src->GetDevice();
  if (!same_device) {
    const auto dst_shape = input->shape();
    at::Tensor materialized = src->ToTensor(/*detached=*/true);
    if (materialized.sizes() != dst_shape.Get().sizes()) {
      materialized = materialized.expand(dst_shape.Get().sizes().vec());
    }
    input->UpdateFromTensor(std::move(materialized), /*sync=*/false);
    return;
  }

  const auto make_source_value = [&]() -> torch::lazy::Value {
    if (input->dtype() == src->dtype()) {
      return src->GetIrValue();
    }
    return torch::lazy::MakeNode<torch::lazy::Cast>(
        src->GetIrValue(), input->dtype(), src->dtype());
  };
  // Build the source value before querying the destination shape, matching
  // the statement order of the step-by-step form.
  const torch::lazy::Value source_value = make_source_value();
  input->SetIrValue(MaybeExpand(source_value, input->shape()));
}
|
|
||||||
|
|
||||||
// Creates a kSelect view of `input` along `dim`.
// `dim`, `start` and `end` are canonicalized against the input shape; `end` is
// raised to `start` when it would fall before it, and `step` is capped at
// `end - start`.
torch::lazy::LazyTensorPtr slice(
    const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
    int64_t end, int64_t step) {
  const auto input_shape = input->shape();
  const torch::lazy::Shape& source = input_shape.Get();

  dim = torch::lazy::GetCanonicalDimensionIndex(dim, source.dim());
  start = torch::lazy::GetCanonicalPosition(source.sizes(), dim, start);
  end = torch::lazy::GetCanonicalPosition(source.sizes(), dim, end);
  // PyTorch allows tensor[-1:0] to return a 0-dim tensor.
  end = std::max(start, end);
  step = std::min(step, end - start);

  torch::lazy::SelectInfo select = {dim, start, end, step};
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kSelect, input_shape, std::move(select));
  return input->CreateViewTensor(std::move(view_info));
}
|
|
||||||
|
|
||||||
// Creates a kPermute view of `input` using the permutation that exchanges
// `dim0` and `dim1` for a tensor of the input's rank.
torch::lazy::LazyTensorPtr
transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) {
  const auto input_shape = input->shape();
  const auto rank = input_shape.Get().dim();
  auto swap_permutation = torch::lazy::MakeTransposePermutation(
      /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/rank);
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kPermute, input_shape, swap_permutation);
  return input->CreateViewTensor(std::move(view_info));
}
|
|
||||||
|
|
||||||
// In-place counterpart of transpose: builds the same kPermute ViewInfo for
// exchanging `dim0` and `dim1`, then applies it to the current view of `input`
// instead of creating a new tensor.
void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) {
  const auto input_shape = input->shape();
  const auto rank = input_shape.Get().dim();
  auto swap_permutation = torch::lazy::MakeTransposePermutation(
      /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/rank);
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kPermute, input_shape, swap_permutation);
  input->ModifyCurrentView(std::move(view_info));
}
|
|
||||||
|
|
||||||
// Creates a kReshape view of `input`. The target sizes come from
// at::infer_size applied to `output_size` and the input's element count; the
// target shape keeps the input's scalar type.
torch::lazy::LazyTensorPtr view(
    const torch::lazy::LazyTensorPtr& input,
    c10::ArrayRef<int64_t> output_size) {
  const torch::lazy::Shape source_shape = input->shape().Get();
  torch::lazy::Shape target_shape(
      source_shape.scalar_type(),
      at::infer_size(output_size, source_shape.numel()));
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kReshape, std::move(target_shape),
      source_shape);
  return input->CreateViewTensor(std::move(view_info));
}
|
|
||||||
|
|
||||||
} // namespace lazy_tensor_aten_ops
|
|
||||||
} // namespace torch_lazy_tensors
|
|
|
@ -1,79 +0,0 @@
|
||||||
//===- tensor_aten_ops.h --------------------------------------------------===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// This file is adapted from pytorch/pytorch
|
|
||||||
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/tensor_aten_ops.h
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <torch/csrc/lazy/core/tensor.h>
|
|
||||||
|
|
||||||
// Declarations of the lazy-tensor ATen helper operations. Functions whose name
// ends in an underscore modify their first argument; the others return a new
// LazyTensorPtr.
namespace torch_lazy_tensors {
|
|
||||||
namespace lazy_tensor_aten_ops {
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
|
||||||
// ATen operators follow here, listed in alphabetical order.
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Takes a slice from the input as R1 at the specified offset and reshapes it
|
|
||||||
// into the provided size.
|
|
||||||
// `storage_offset` is optional; when empty no offset is supplied by the caller.
torch::lazy::LazyTensorPtr as_strided(
|
|
||||||
const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
|
|
||||||
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset);
|
|
||||||
|
|
||||||
// In-place version of the method above.
|
|
||||||
void as_strided_(
|
|
||||||
torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
|
|
||||||
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset);
|
|
||||||
|
|
||||||
// Returns a new tensor that is the input expanded to `size`.
// NOTE(review): the accompanying .cpp appears to treat -1 entries in `size` as
// "keep the input's extent" — confirm against the implementation.
torch::lazy::LazyTensorPtr
|
|
||||||
expand(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size);
|
|
||||||
|
|
||||||
// Fills the input with the given value.
|
|
||||||
void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value);
|
|
||||||
|
|
||||||
// Returns a new tensor that is a narrowed view of the input in the given
|
|
||||||
// dimension.
|
|
||||||
torch::lazy::LazyTensorPtr narrow(
|
|
||||||
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
|
|
||||||
int64_t length);
|
|
||||||
|
|
||||||
// Permute the dimensions of this tensor according to the given permutation.
|
|
||||||
torch::lazy::LazyTensorPtr
|
|
||||||
permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef<int64_t> dims);
|
|
||||||
|
|
||||||
// Repeats the input tensor along each dimension by the given number of
|
|
||||||
// repeats.
|
|
||||||
// NOTE(review): no definition of this function is visible in the accompanying
// tensor_aten_ops.cpp — confirm it is defined elsewhere before calling it.
torch::lazy::LazyTensorPtr
|
|
||||||
repeat(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> repeats);
|
|
||||||
|
|
||||||
// Copies the contents of `src` into `input`.
void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src);
|
|
||||||
|
|
||||||
// Returns a view of the input selected along `dim` by `start`, `end` and
// `step`.
torch::lazy::LazyTensorPtr slice(
|
|
||||||
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
|
|
||||||
int64_t end, int64_t step);
|
|
||||||
|
|
||||||
// Declared as returning three tensors for the input.
// NOTE(review): no definition of this function is visible in the accompanying
// tensor_aten_ops.cpp, and this header does not include <tuple> itself —
// confirm both before relying on this declaration.
std::tuple<
|
|
||||||
torch::lazy::LazyTensorPtr, torch::lazy::LazyTensorPtr,
|
|
||||||
torch::lazy::LazyTensorPtr>
|
|
||||||
svd(const torch::lazy::LazyTensorPtr& input, bool some, bool compute_uv);
|
|
||||||
|
|
||||||
// Swap given dimensions of the input.
|
|
||||||
torch::lazy::LazyTensorPtr
|
|
||||||
transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1);
|
|
||||||
|
|
||||||
// In-place version of the method above.
|
|
||||||
void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1);
|
|
||||||
|
|
||||||
// Like reshape, but it returns a view into the original tensor.
|
|
||||||
torch::lazy::LazyTensorPtr view(
|
|
||||||
const torch::lazy::LazyTensorPtr& input,
|
|
||||||
c10::ArrayRef<int64_t> output_size);
|
|
||||||
|
|
||||||
} // namespace lazy_tensor_aten_ops
|
|
||||||
} // namespace torch_lazy_tensors
|
|
Loading…
Reference in New Issue