Fix Torch-MLIR LTC Backend based off latest PyTorch master (#723)

* Changes as a result of the LTC TS backend decoupling

* Fix bugs in BackendImpl and codegen

* Fix based on latest PyTorch master
Jae Hoon (Antonio) Kim 2022-04-13 15:42:02 -04:00 committed by Henry Tu
parent c3b20e444c
commit 65cf1465ef
16 changed files with 218 additions and 737 deletions

.gitignore vendored

@ -26,7 +26,7 @@ bazel-*
# Autogenerated files
/generated_native_functions.yaml
/generated_backend.hash
/python/torch_mlir/csrc/backend/LazyLazyIr.h
/python/torch_mlir/csrc/backend/LazyIr.h
/python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp
/python/torch_mlir/csrc/backend/LazyNativeFunctions.h
/python/torch_mlir/csrc/backend/GenLazyShapeInference.cpp


@ -24,6 +24,10 @@ from codegen.gen import get_grouped_native_functions, parse_native_yaml
from codegen.model import NativeFunctionsGroup
def isOptionalCType(arg):
return str(type(arg)) == "<class 'tools.codegen.api.types.OptionalCType'>"
def generate_native_functions(
config_path: Path, torch_ops_file: Path, out_file: Path
):
@ -98,7 +102,7 @@ def generate_native_functions(
yaml.dump(
{
"backend": "Lazy",
"cpp_namespace": "torch_lazy_tensors",
"cpp_namespace": "torch::lazy",
"full_codegen": opnames,
"supported": sorted(supported_ops),
},
@ -120,21 +124,46 @@ def generate_native_functions(
@dataclass(frozen=True)
class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
lowering_function_type: str = "torch::lazy::MlirFunction"
lowering_context_type: str = "torch::lazy::MlirLoweringContext*"
lowering_return_type: str = "torch::lazy::MlirOpVector"
def lowering_body(self, f):
def lowering_function(self, f):
func = (
f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
)
schema = LazyIrSchema(func)
emplace_arguments = []
for arg in schema.positional_args:
if arg.is_lazy_value:
if isOptionalCType(arg.lazy_type):
emplace_arguments.append(f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr")
continue
emplace_arguments.append('loctx->GetOutputOp(operand(i++))')
continue
emplace_arguments.append(f'"{arg.name}", {arg.name}')
emplace_arguments_str = "\n ".join(
[f"arguments.emplace_back({a});" for a in emplace_arguments])
emplace_kwarg_values = [f'"{t.name}", loctx->GetOutputOp(operand(i++))' for t in schema.keyword_values]
emplace_kwarg_scalars = [f'"{t.name}", {t.name}' for t in schema.keyword_scalars]
emplace_kwarguments = "\n ".join(
[f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars])
return f"""
UNIMPLEMENTED_ERROR(
"'{func}' lowering not yet implemented"
);
""".rstrip()
TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override {{
PRINT_FUNCTION();
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve({len(emplace_arguments)});
kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
size_t i = 0;
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, arguments, kwarguments);
CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
return {schema.aten_name}_out;
}}
""".strip()
def generate_backend(
@ -151,14 +180,14 @@ def generate_backend(
codegen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
codegen.gen_lazy_tensor.run(
codegen.gen_lazy_tensor.run_gen_lazy_tensor(
backend_name="TorchMlir",
aten_path=str(TORCH_DIR.joinpath("aten", "src", "ATen")),
source_yaml=str(source_yaml),
output_dir=str(backend_path),
dry_run=False,
impl_path=str(backend_path.joinpath("aten_ltc_mlir_type.cpp")),
gen_ts_lowerings=False,
node_base="torch::lazy::MlirNode",
impl_path=str(backend_path.joinpath("mlir_native_functions.cpp")),
node_base="torch::lazy::TorchMlirNode",
node_base_hdr=str(backend_path.joinpath("mlir_node.h")),
tensor_class="torch::lazy::LazyTensor",
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
@ -298,7 +327,6 @@ def main(args):
new_hash = m.hexdigest().strip()
if args.force or new_hash != prev_hash:
hash_file.write_text(new_hash)
parsed_yaml, grouped_native_functions = generate_native_functions(
config_path, torch_ops_file, native_functions
)
@ -310,6 +338,8 @@ def main(args):
grouped_native_functions,
)
hash_file.write_text(new_hash)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
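
As an aside, the reworked `lowering_function` template above now emits a concrete `Lower` method instead of the old `UNIMPLEMENTED_ERROR` stub. For a single-operand op with one return value, the generated C++ comes out roughly like the sketch below; the node class name `ReluNode` is hypothetical, and only the body mirrors the template.

```cpp
// Approximate output of the codegen template above for a unary op.
// `ReluNode` is an illustrative name, not something this commit defines.
TorchMlirOpVector ReluNode::Lower(
    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
  PRINT_FUNCTION();
  std::vector<torch::jit::NamedValue> arguments;
  std::vector<torch::jit::NamedValue> kwarguments;
  arguments.reserve(1);
  kwarguments.reserve(0);
  size_t i = 0;
  arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
  torch::lazy::TorchMlirOpVector relu_out = torch::lazy::LowerTorchMlirBuiltin(
      function, op().op, arguments, kwarguments);
  CHECK_EQ(relu_out.size(), 1);
  return relu_out;
}
```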


@ -37,10 +37,9 @@ supported:
- empty
- expand
- fill_
# - native_batch_norm_backward
- native_batch_norm
# - native_batch_norm_backward
- permute
- repeat
- squeeze
- t
- unsqueeze
@ -50,3 +49,4 @@ additional_ops:
# Additional ops to support that are not supported by Torch-MLIR explicitly
- _copy_from
- _copy_from_and_resize
- native_batch_norm_backward


@ -20,16 +20,14 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(torch_mlir_ltc_backend SHARED
backend/aten_eager_fallback.cpp
backend/aten_ltc_mlir_type.cpp
backend/backend_impl.cpp
backend/LazyNativeFunctions.cpp
backend/LazyShapeInference.cpp
backend/GenLazyShapeInference.cpp
backend/mlir_lowering_context.cpp
backend/mlir_native_functions.cpp
backend/mlir_node.cpp
backend/RegisterLazy.cpp
tensor_aten_ops.cpp
)
target_link_libraries(torch_mlir_ltc_backend


@ -69,6 +69,7 @@ TORCH_API std::vector<Shape> compute_shape_new_zeros(const at::Tensor & self, at
TORCH_API std::vector<Shape> compute_shape_rand_like(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<Shape> compute_shape_relu(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_relu_(at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
TORCH_API std::vector<Shape> compute_shape_reshape(const at::Tensor & self, at::IntArrayRef shape);
TORCH_API std::vector<Shape> compute_shape_rsub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_select(const at::Tensor & self, int64_t dim, int64_t index);
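
The commit only declares `compute_shape_repeat` here; a hedged sketch of what such a shape-inference routine could look like (an assumption based on aten::repeat semantics, not the commit's actual implementation) follows.

```cpp
// Hedged sketch of a possible compute_shape_repeat, inferred from the
// declaration above; the real implementation may differ.
std::vector<Shape> compute_shape_repeat(
    const at::Tensor& self, at::IntArrayRef repeats) {
  // aten::repeat requires repeats.size() >= self.dim(); leading entries add
  // new dimensions, trailing entries scale the existing ones.
  std::vector<int64_t> out_sizes(repeats.begin(), repeats.end());
  int64_t offset = static_cast<int64_t>(repeats.size()) - self.dim();
  for (int64_t i = 0; i < self.dim(); ++i) {
    out_sizes[offset + i] *= self.size(i);
  }
  return {Shape(self.scalar_type(), out_sizes)};
}
```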


@ -1,58 +0,0 @@
//===- aten_eager_fallback.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_eager_fallback.cpp
//===----------------------------------------------------------------------===//
#include <iostream>
#include <unordered_map>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/metrics.h>
#include "../utils/exception.h"
#include "aten_eager_fallback.h"
namespace torch_lazy_tensors {
static std::unordered_map<std::string, ::torch::lazy::Counter*>
_eager_fallback_counters;
bool force_eager_fallback(c10::Symbol op) {
static char* force_str = std::getenv("LTC_FORCE_FALLBACK");
if (force_str != nullptr) {
static auto force_sym = c10::Symbol::fromQualString(std::string(force_str));
if (op == force_sym) {
std::cout << "MLIR force_eager_fallback(" << force_str << "): true"
<< std::endl;
return true;
}
}
std::cout << "MLIR force_eager_fallback(" << op.toQualString() << "): false"
<< std::endl;
return false;
}
void ltc_eager_fallback(
const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto name = c10::toString(op.operator_name());
UNSUPPORTED_ERROR(
"MLIR ltc_eager_fallback is not supported for op: " << name);
}
std::function<void(void)> register_mlir_ltc_eager_fallback;
TORCH_LIBRARY_IMPL(_, Lazy, m) {
register_mlir_ltc_eager_fallback = [&]() {
m.fallback(
torch::CppFunction::makeFromBoxedFunction<&ltc_eager_fallback>());
};
}
} // namespace torch_lazy_tensors


@ -1,27 +0,0 @@
//===- aten_eager_fallback.h ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// Facilitates eager fallback behaviour
//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_eager_fallback.h
//===----------------------------------------------------------------------===//
#pragma once
#include <ATen/native/CPUFallback.h>
namespace torch_lazy_tensors {
bool force_eager_fallback(c10::Symbol op);
void ltc_eager_fallback(
const c10::OperatorHandle& op, torch::jit::Stack* stack);
extern TORCH_API std::function<void(void)> register_mlir_ltc_eager_fallback;
} // namespace torch_lazy_tensors


@ -7,7 +7,7 @@
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
//===----------------------------------------------------------------------===//
#include <torch/csrc/lazy/backend/backend_data.h>
@ -23,77 +23,79 @@
namespace torch {
namespace lazy {
MlirBackendData::MlirBackendData(BackendDevice device, Shape shape)
: BackendData(device, shape) {
TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape)
: BackendData(device, shape),
info_(std::make_unique<TorchMlirBackendData::Info>()) {
PRINT_FUNCTION();
auto info = std::make_shared<MlirBackendData::Info>();
SetInfo(info);
}
MlirBackendData::MlirBackendData(const at::Scalar& scalar, BackendDevice device)
: BackendData(device, Shape(scalar.type(), {})) {
TorchMlirBackendData::TorchMlirBackendData(
const at::Scalar& scalar, BackendDevice device)
: BackendData(device, Shape(scalar.type(), {})),
info_(std::make_unique<TorchMlirBackendData::Info>(scalar)) {
PRINT_FUNCTION();
auto info = std::make_shared<MlirBackendData::Info>(scalar);
SetInfo(info);
}
MlirBackendData::MlirBackendData(
TorchMlirBackendData::TorchMlirBackendData(
const at::Tensor& tensor, BackendDevice device, Shape shape)
: BackendData(device, shape) {
: BackendData(device, shape),
info_(std::make_unique<TorchMlirBackendData::Info>(tensor)) {
PRINT_FUNCTION();
auto info = std::make_shared<MlirBackendData::Info>(tensor);
SetInfo(info);
}
BackendData::Handle MlirBackendData::GetHandle() {
BackendData::Handle TorchMlirBackendData::GetHandle() {
return reinterpret_cast<int64_t>(this);
}
void MlirBackendData::Assign(const BackendData& data) {
MlirBackendData::Info* info =
dynamic_cast<MlirBackendData::Info*>(data.info());
void TorchMlirBackendData::Assign(const BackendData& data) {
TorchMlirBackendData::Info* info =
dynamic_cast<TorchMlirBackendData::Info*>(data.info());
TORCH_CHECK(
info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info.");
auto new_info = std::make_shared<MlirBackendData::Info>(*info);
SetInfo(new_info);
info,
"Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
info_ = std::make_unique<TorchMlirBackendData::Info>(*info);
}
bool MlirBackendData::HasValue() const { return bool(info()); }
bool TorchMlirBackendData::HasValue() const { return bool(info_); }
TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const {
return info_.get();
}
/**
* Initialization/Teardown
* */
void MlirBackendImpl::PrepareToExit() const {}
void TorchMlirBackendImpl::PrepareToExit() const {}
/**
* Data Transfer
* */
BackendDataPtr MlirBackendImpl::MakeComputationDataFromTensor(
BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromTensor(
const at::Tensor& tensor, const Shape& shape,
const BackendDevice& device) const {
PRINT_FUNCTION();
return std::make_shared<MlirBackendData>(tensor, device, shape);
return std::make_shared<TorchMlirBackendData>(tensor, device, shape);
}
BackendDataPtr MlirBackendImpl::MakeComputationDataFromScalar(
BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromScalar(
const at::Scalar& scalar, const BackendDevice& device) const {
PRINT_FUNCTION();
return std::make_shared<MlirBackendData>(scalar, device);
return std::make_shared<TorchMlirBackendData>(scalar, device);
}
BackendDataPtr MlirBackendImpl::CreateDataPlaceholder(
BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder(
const BackendDevice& device, const Shape& shape) const {
PRINT_FUNCTION();
return std::make_shared<MlirBackendData>(device, shape);
return std::make_shared<TorchMlirBackendData>(device, shape);
}
at::Tensor MlirBackendImpl::MakeTensorFromComputationData(
at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const {
PRINT_FUNCTION();
MlirBackendData::Info* info =
dynamic_cast<MlirBackendData::Info*>(data->info());
TorchMlirBackendData::Info* info =
dynamic_cast<TorchMlirBackendData::Info*>(data->info());
TORCH_CHECK(
info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info.");
info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
return info->tensor;
}
@ -101,20 +103,20 @@ at::Tensor MlirBackendImpl::MakeTensorFromComputationData(
* Lowering, Compilation, Execution
* */
std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
const std::string& name, BackendDevice device,
c10::ArrayRef<Node*> post_order, Util::EmissionMap emit_status) const {
PRINT_FUNCTION();
return std::make_unique<MlirLoweringContext>(
return std::make_unique<TorchMlirLoweringContext>(
name, std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status));
}
std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
const std::string& name, BackendDevice device) const {
PRINT_FUNCTION();
return std::make_unique<MlirLoweringContext>(
return std::make_unique<TorchMlirLoweringContext>(
name, std::forward<BackendDevice>(device));
}
@ -129,13 +131,13 @@ std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
// Specify which aten device should be used for eager fallback
// may change depending on current 'Default' DeviceType
at::DeviceType MlirBackendImpl::EagerFallbackDeviceType() const {
at::DeviceType TorchMlirBackendImpl::EagerFallbackDeviceType() const {
PRINT_FUNCTION();
return at::DeviceType::CPU;
}
// Query all available backend devices
std::vector<BackendDevice> MlirBackendImpl::GetBackendDevices() const {
std::vector<BackendDevice> TorchMlirBackendImpl::GetBackendDevices() const {
PRINT_FUNCTION();
return {
GetBackendDevice(c10::Device(c10::kLazy, 0)),
@ -148,7 +150,7 @@ std::vector<BackendDevice> MlirBackendImpl::GetBackendDevices() const {
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
// through a mode, in which case these APIs should still work, but should be
// identity mappings.
BackendDevice MlirBackendImpl::GetBackendDevice(c10::Device device) const {
BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const {
PRINT_FUNCTION();
return BackendDevice(GetDefaultDeviceType(), device.index());
}


@ -10,7 +10,7 @@
// using the Torch-MLIR ATen dialect
//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.h
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.h
//===----------------------------------------------------------------------===//
#pragma once
@ -23,7 +23,7 @@
namespace torch {
namespace lazy {
class TORCH_API MlirBackendData : public BackendData {
class TORCH_API TorchMlirBackendData : public BackendData {
public:
struct Info : public BackendData::Info {
at::Tensor tensor;
@ -39,20 +39,25 @@ public:
Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {}
};
MlirBackendData(BackendDevice device, Shape shape);
MlirBackendData(const at::Scalar& scalar, BackendDevice device);
MlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape);
TorchMlirBackendData(BackendDevice device, Shape shape);
TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device);
TorchMlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape);
virtual BackendData::Handle GetHandle() override;
virtual void Assign(const BackendData& data) override;
virtual bool HasValue() const override;
TorchMlirBackendData::Info* mlir_info() const;
private:
std::unique_ptr<TorchMlirBackendData::Info> info_;
};
class TORCH_API MlirBackendImpl : public BackendImplInterface {
class TORCH_API TorchMlirBackendImpl : public BackendImplInterface {
public:
virtual ~MlirBackendImpl() = default;
virtual ~TorchMlirBackendImpl() = default;
/**
* Initialization/Teardown
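
Since the `Info` payload is now owned directly by `TorchMlirBackendData` and exposed through the new `mlir_info()` accessor, a caller can unwrap the underlying aten tensor along these lines. This is a minimal sketch, not part of the commit; the helper name `UnwrapBackendData` is illustrative, and it assumes tensor-backed data.

```cpp
// Minimal sketch (not from the commit): recover the aten tensor held by a
// TorchMlirBackendData via the new mlir_info() accessor. Placeholder- or
// scalar-backed data holds no tensor, so this assumes tensor-backed data.
at::Tensor UnwrapBackendData(const torch::lazy::BackendDataPtr& data) {
  auto* mlir_data =
      dynamic_cast<torch::lazy::TorchMlirBackendData*>(data.get());
  TORCH_CHECK(mlir_data, "Expected TorchMlirBackendData.");
  torch::lazy::TorchMlirBackendData::Info* info = mlir_data->mlir_info();
  TORCH_CHECK(info, "TorchMlirBackendData has no Info attached.");
  return info->tensor;
}
```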


@ -7,22 +7,23 @@
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/ts_lowering_context.cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
//===----------------------------------------------------------------------===//
#include <iostream>
#include "../utils/debug.h"
#include "../utils/exception.h"
#include "mlir_lowering_context.h"
namespace torch {
namespace lazy {
MlirLoweringContext::MlirLoweringContext(
TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device)
: LoweringContext(name, std::forward<BackendDevice>(device)) {}
MlirLoweringContext::MlirLoweringContext(
TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
: LoweringContext(
@ -30,29 +31,34 @@ MlirLoweringContext::MlirLoweringContext(
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)) {}
int MlirComputation::parameters_size() const { UNIMPLEMENTED_FUNCTION_ERROR(); }
int TorchMlirComputation::parameters_size() const { UNIMPLEMENTED_FUNCTION_ERROR(); }
const std::vector<torch::lazy::Shape>&
MlirComputation::parameter_shapes() const {
TorchMlirComputation::parameter_shapes() const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
const std::vector<std::string>& MlirComputation::parameter_names() const {
const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
const torch::lazy::Shape& MlirComputation::result_shape() const {
const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
std::string TorchMlirComputation::to_string() const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
// Get the shape of the result tuple component, given by index.
torch::lazy::Shape MlirLoweringContext::GetResultShape(size_t index) const {
torch::lazy::Shape TorchMlirLoweringContext::GetResultShape(size_t index) const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
// Adds the given output as a component of the result tuple and returns its
// assigned position within the tuple.
size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) {
size_t TorchMlirLoweringContext::AddResult(const torch::lazy::Output& output) {
PRINT_FUNCTION();
const torch::lazy::Node* node;
auto it = emitted_outputs_.find(output);
if (it == emitted_outputs_.end()) {
@ -75,7 +81,7 @@ size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) {
// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
void MlirLoweringContext::AddParameter(
void TorchMlirLoweringContext::AddParameter(
const torch::lazy::Output& output, size_t index,
const torch::lazy::Shape& shape, const std::string& name) {
UNIMPLEMENTED_FUNCTION_ERROR();
@ -83,10 +89,18 @@ void MlirLoweringContext::AddParameter(
// Build the computation capturing all the operations created with the
// embedded builder (returned by the builder() API).
ComputationPtr MlirLoweringContext::Build() {
ComputationPtr TorchMlirLoweringContext::Build() {
PRINT_FUNCTION()
for (const torch::lazy::Node* output : result_tuple_) {
}
return std::make_shared<MlirComputation>();
return std::make_shared<TorchMlirComputation>();
}
// Retrieves the lowered operation for an output. If the requested output is
// not available yet, the graph behind the output's Node is lowered, and the
// corresponding MLIR operation returned.
torch::jit::Value* GetOutputOp(const Output& output) {
UNIMPLEMENTED_FUNCTION_ERROR();
}
} // namespace lazy


@ -7,7 +7,7 @@
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_lowering_context.h
// https://github.com/pytorch/pytorch/blob/torch/csrc/lazy/ts_backend/ts_lowering_context.h
//===----------------------------------------------------------------------===//
#pragma once
@ -19,7 +19,7 @@
namespace torch {
namespace lazy {
class TORCH_API MlirComputation : public torch::lazy::Computation {
class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
public:
int parameters_size() const override;
@ -31,11 +31,11 @@ public:
virtual const torch::lazy::Shape& result_shape() const override;
};
class TORCH_API MlirLoweringContext : public torch::lazy::LoweringContext {
class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
public:
MlirLoweringContext(
TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device);
MlirLoweringContext(
TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status);
@ -58,6 +58,11 @@ public:
// embedded builder (returned by the builder() API).
virtual torch::lazy::ComputationPtr Build() override;
// Retrieves the lowered operation for an output. If the requested output is
// not available yet, the graph behind the output's Node is lowered, and the
// corresponding MLIR operation returned.
torch::jit::Value* GetOutputOp(const Output& output);
private:
std::vector<const torch::lazy::Node*> result_tuple_;
torch::lazy::OutputMap<const torch::lazy::Node*> emitted_outputs_;


@ -7,7 +7,7 @@
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_ltc_ts_type.cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
//===----------------------------------------------------------------------===//
#include <ATen/Operators.h>
@ -17,6 +17,7 @@
#include <ATen/ops/result_type.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor_aten_ops.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
#include <torch/library.h>
@ -24,13 +25,13 @@
#include "ATen/MetaFunctions.h"
#include <torch/csrc/lazy/core/tensor_impl.h>
#include "../tensor_aten_ops.h"
#include "../utils/exception.h"
#include "../utils/sys_utils.h"
#include "LazyNativeFunctions.h"
#include "LazyShapeInference.h"
namespace torch_lazy_tensors {
namespace torch {
namespace lazy {
namespace {
@ -121,7 +122,7 @@ GetLtcDevice(const c10::optional<c10::Device>& device) {
// UNIMPLEMENTED_FUNCTION_ERROR();
// // return torch::lazy::CreateAtenFromLtcTensor(
// // lazy_tensor_aten_ops::bernoulli(self_tensor));
// // torch::lazy::bernoulli(self_tensor));
// }
// at::Tensor& LazyNativeFunctions::bernoulli_(
@ -133,7 +134,7 @@ GetLtcDevice(const c10::optional<c10::Device>& device) {
// auto self_tensor = torch::lazy::TryGetLtcTensor(self);
// UNIMPLEMENTED_FUNCTION_ERROR();
// // lazy_tensor_aten_ops::bernoulli_(self_tensor, p);
// // torch::lazy::bernoulli_(self_tensor, p);
// // return self;
// }
@ -208,7 +209,7 @@ at::Tensor LazyNativeFunctions::_copy_from(
dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false));
}
} else {
lazy_tensor_aten_ops::copy_(dst_tensor, self_tensor);
torch::lazy::copy_(dst_tensor, self_tensor);
auto* impl =
dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
impl->set_tensor(dst_tensor);
@ -260,15 +261,15 @@ at::Tensor LazyNativeFunctions::expand(
const at::Tensor& self, at::IntArrayRef size, bool implicit) {
TORCH_LAZY_FN_COUNTER("lazy::");
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::expand(
// torch::lazy::TryGetLtcTensor(self), size.vec()));
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::expand(torch::lazy::TryGetLtcTensor(self), size.vec()));
}
at::Tensor&
LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
lazy_tensor_aten_ops::fill_(self_tensor, value);
torch::lazy::fill_(self_tensor, value);
return self;
}
@ -280,110 +281,86 @@ LazyNativeFunctions::native_batch_norm(
const c10::optional<at::Tensor>& running_var, bool training,
double momentum, double eps) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr input_tensor = torch::lazy::TryGetLtcTensor(input);
auto input_tensor = torch::lazy::TryGetLtcTensor(input);
const torch::lazy::BackendDevice& device = input_tensor->GetDevice();
torch::lazy::LazyTensorPtr running_mean_tensor =
GetOrCreateLtcTensor(running_mean, device);
torch::lazy::LazyTensorPtr running_var_tensor =
GetOrCreateLtcTensor(running_var, device);
UNIMPLEMENTED_FUNCTION_ERROR();
// auto outputs = lazy_tensor_aten_ops::ts_native_batch_norm(
// torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight,
// device),
// GetOrCreateLtcTensor(bias, device), running_mean_tensor,
// running_var_tensor, training, momentum, eps);
// return
// std::make_tuple(torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)),
// torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)),
// torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs)));
auto running_mean_tensor = GetOrCreateLtcTensor(running_mean, device);
auto running_var_tensor = GetOrCreateLtcTensor(running_var, device);
auto outputs = torch::lazy::native_batch_norm(
torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device),
GetOrCreateLtcTensor(bias, device), running_mean_tensor,
running_var_tensor, training, momentum, eps);
return std::make_tuple(
torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)),
torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)),
torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs)));
}
// std::tuple<at::Tensor, at::Tensor, at::Tensor>
// LazyNativeFunctions::native_batch_norm_backward(
// const at::Tensor& grad_out, const at::Tensor& input,
// const c10::optional<at::Tensor>& weight,
// const c10::optional<at::Tensor>& running_mean,
// const c10::optional<at::Tensor>& running_var,
// const c10::optional<at::Tensor>& save_mean,
// const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
// std::array<bool, 3> output_mask) {
// TORCH_LAZY_FN_COUNTER("lazy::");
// torch::lazy::LazyTensor grad_out_tensor =
// torch::lazy::TryGetLtcTensor(grad_out);
// const torch::lazy::BackendDevice& device = grad_out_tensor.GetDevice();
// torch::lazy::LazyTensor null_tensor;
// bool running_stats = running_mean && running_mean->defined();
// CHECK_EQ(running_var && running_var->defined(), running_stats);
// UNIMPLEMENTED_FUNCTION_ERROR();
// // auto gradients = lazy_tensor_aten_ops::ts_native_batch_norm_backward(
// // torch::lazy::TryGetLtcTensor(grad_out),
// torch::lazy::TryGetLtcTensor(input),
// // GetOrCreateLtcTensor(weight, device),
// // running_stats ? GetOrCreateLtcTensor(running_mean, device)
// // : null_tensor,
// // running_stats ? GetOrCreateLtcTensor(running_var, device)
// // : null_tensor,
// // GetOrCreateLtcTensor(save_mean, device),
// // GetOrCreateLtcTensor(save_invstd, device), train, eps,
// // output_mask);
// // at::Tensor undefined;
// // return std::make_tuple(
// // output_mask[0] ?
// torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients))
// // : undefined,
// // output_mask[1] ?
// torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients))
// // : undefined,
// // output_mask[2] ?
// torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients))
// // : undefined);
// }
std::tuple<at::Tensor, at::Tensor, at::Tensor>
LazyNativeFunctions::native_batch_norm_backward(
const at::Tensor& grad_out, const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var,
const c10::optional<at::Tensor>& save_mean,
const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
std::array<bool, 3> output_mask) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto grad_out_tensor = torch::lazy::TryGetLtcTensor(grad_out);
const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice();
torch::lazy::LazyTensorPtr null_tensor;
bool running_stats = running_mean && running_mean->defined();
CHECK_EQ(running_var && running_var->defined(), running_stats);
auto gradients = torch::lazy::native_batch_norm_backward(
torch::lazy::TryGetLtcTensor(grad_out),
torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device),
running_stats ? GetOrCreateLtcTensor(running_mean, device) : null_tensor,
running_stats ? GetOrCreateLtcTensor(running_var, device) : null_tensor,
GetOrCreateLtcTensor(save_mean, device),
GetOrCreateLtcTensor(save_invstd, device), train, eps, output_mask);
at::Tensor undefined;
return std::make_tuple(
output_mask[0]
? torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients))
: undefined,
output_mask[1]
? torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients))
: undefined,
output_mask[2]
? torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients))
: undefined);
}
at::Tensor
LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::permute(
// self_tensor, torch::lazy::ToI64Vector(dims)));
}
at::Tensor
LazyNativeFunctions::repeat(const at::Tensor& self, at::IntArrayRef repeats) {
TORCH_LAZY_FN_COUNTER("lazy::");
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::repeat(
// torch::lazy::TryGetLtcTensor(self),
// torch::lazy::ToI64Vector(repeats)));
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::permute(self_tensor, torch::lazy::ToI64Vector(dims)));
}
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(
// lazy_tensor_aten_ops::squeeze(torch::lazy::TryGetLtcTensor(self)));
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self)));
}
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(
// lazy_tensor_aten_ops::squeeze(torch::lazy::TryGetLtcTensor(self),
// dim));
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self), dim));
}
at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::transpose(
torch::lazy::TryGetLtcTensor(self), 0, 1));
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), 0, 1));
}
at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
UNIMPLEMENTED_FUNCTION_ERROR();
// return torch::lazy::CreateAtenFromLtcTensor(
// lazy_tensor_aten_ops::unsqueeze(torch::lazy::TryGetLtcTensor(self),
// dim));
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::unsqueeze(torch::lazy::TryGetLtcTensor(self), dim));
}
at::Tensor
@ -391,9 +368,10 @@ LazyNativeFunctions::view(const at::Tensor& self, at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
return torch::lazy::CreateAtenFromLtcTensor(
lazy_tensor_aten_ops::view(self_tensor, torch::lazy::ToI64Vector(size)));
torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size)));
}
void InitializeAtenBindings() {}
} // namespace torch_lazy_tensors
} // namespace lazy
} // namespace torch


@ -18,112 +18,9 @@
namespace torch {
namespace lazy {
namespace {
hash_t OperandHashes(
const OpList& operands, const hash_t& seed, const bool bakeInSizes) {
hash_t hash = seed;
for (auto& operand : operands) {
if (!operand) {
hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
continue;
}
auto operand_hash =
bakeInSizes ? operand.hash_with_sizes() : operand.hash_without_sizes();
hash = HashCombine(hash, operand_hash);
}
return hash;
}
hash_t GetOpHash(
OpKind op, const Shape& shape, hash_t hash_seed, const bool bakeInSizes) {
hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes));
return HashCombine(h, hash_seed);
}
} // namespace
MlirNode::MlirNode(
OpKind op, OpList operands, std::vector<Shape>&& shapes, size_t num_outputs,
hash_t hash_seed)
: Node(
op, num_outputs,
/* node_hash */ HashCombine(op.hash(), hash_seed),
/* dag_hash */
[&](bool bakeInSizes) -> hash_t {
return OperandHashes(
operands, HashCombine(op.hash(), hash_seed), bakeInSizes);
}),
shapes_(std::move(shapes)) {
for (auto& operand : operands) {
// Ideally, optional operands should be filtered by the leaf node classes,
// but it's just much easier to do it here.
if (!operand) {
continue;
}
AddOperand(operand.node, operand.index);
}
}
MlirNode::MlirNode(
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
size_t num_outputs, hash_t hash_seed)
: MlirNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
shapes_.push_back(GetOpShape(shape_fn));
}
MlirNode::MlirNode(
OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
: MlirNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
void MlirNode::SetShapeDeferred(const std::function<Shape()>& shape_fn) {
shapes_.push_back(GetOpShape(shape_fn));
}
MlirNode::MlirNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
: Node(op, num_outputs, [&](bool bakeInSizes) -> hash_t {
return GetOpHash(op, shape, hash_seed, bakeInSizes);
}) {
shapes_.push_back(std::move(shape));
}
using ShapeCache = Cache<hash_t, Shape, HashReducer>;
constexpr const int torch_lazy_shape_cache_size = 4096;
ShapeCache* GetShapeCache() {
static ShapeCache* cache = new ShapeCache(torch_lazy_shape_cache_size);
return cache;
}
Shape MlirNode::GetOpShape(const std::function<Shape()>& shape_fn) const {
ShapeCache* shape_cache = GetShapeCache();
auto shape = shape_cache->Get(hash());
if (shape == nullptr) {
shape = shape_cache->Add(hash(), std::make_shared<Shape>(shape_fn()));
}
return *shape;
}
c10::ArrayRef<Shape> MlirNode::shapes() const { return shapes_; }
const Shape& MlirNode::shape(size_t output_index) const {
return shapes_.at(output_index);
}
const std::vector<Output>& MlirNode::operands() const {
return operands_as_outputs_;
}
const Output& MlirNode::operand(size_t i) const {
return operands_as_outputs_.at(i);
}
void MlirNode::AddOperand(NodePtr node, size_t index) {
CHECK_LT(index, node->num_outputs());
operands_.push_back(std::move(node));
operands_as_outputs_.emplace_back(operands_.back().get(), index);
TorchMlirOpVector
TorchMlirNode::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
return {};
}
} // namespace lazy


@ -7,76 +7,33 @@
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_node.h
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node.h
//===----------------------------------------------------------------------===//
#pragma once
#include <ATen/core/interned_strings.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/shape.h>
#include "../utils/debug.h"
#include "../utils/exception.h"
#include "aten_eager_fallback.h"
#include "mlir_lowering_context.h"
namespace torch {
namespace lazy {
typedef std::vector<NodePtr> MlirOpVector;
typedef NodePtr MlirFunction;
class TORCH_API MlirNode : public torch::lazy::Node {
typedef std::vector<torch::jit::Value*> TorchMlirOpVector;
typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
class TORCH_API TorchMlirNode : public torch::lazy::Node {
public:
MlirNode(
OpKind op, OpList operands, std::vector<Shape>&& shapes,
size_t num_outputs = 1, hash_t hash_seed = kHashSeed);
using torch::lazy::Node::Node;
// Same as the constructor above, but the shape is generated by a function,
// only if needed (shape cache miss).
MlirNode(
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
size_t num_outputs = 1, hash_t hash_seed = kHashSeed);
// The shape is set later.
MlirNode(
OpKind op, OpList operands, size_t num_outputs = 1,
hash_t hash_seed = kHashSeed);
void SetShapeDeferred(const std::function<Shape()>& shape_fn);
// Constructor used to create leaf nodes.
MlirNode(
OpKind op, Shape shape, size_t num_outputs = 1,
hash_t hash_seed = kHashSeed);
Shape GetOpShape(const std::function<Shape()>& shape_fn) const;
// Retrieves the full shape of the IR Node.
c10::ArrayRef<Shape> shapes() const override;
// Retrieves the shape of the output at a given index.
const Shape& shape(size_t output_index = 0) const override;
const std::vector<Output>& operands() const override;
const Output& operand(size_t i) const override;
virtual MlirOpVector
Lower(MlirFunction function, MlirLoweringContext* loctx) const = 0;
private:
// Adds node's index output number as operand.
void AddOperand(NodePtr node, size_t index = 0);
std::vector<Shape> shapes_;
// A node holds a real reference to its operands.
std::vector<NodePtr> operands_;
// Outputs do not hold references on the nodes, and neither do the uses, since
// otherwise we get into circular reference counting.
std::vector<Output> operands_as_outputs_;
virtual TorchMlirOpVector
Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const;
};
} // namespace lazy
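
With `TorchMlirNode` now reusing `torch::lazy::Node`'s constructors and providing a default `Lower`, a hand-written node only needs to override `Lower`. Below is a hedged sketch of such a subclass; `MyReluNode` is an illustrative name, and `LowerTorchMlirBuiltin` is used as in the generated code earlier in this commit.

```cpp
// Hedged sketch of a hand-written node built on TorchMlirNode; MyReluNode is
// an illustrative name and is not defined by this commit.
class MyReluNode : public torch::lazy::TorchMlirNode {
 public:
  // Inherit the torch::lazy::Node constructors, as TorchMlirNode now does.
  using TorchMlirNode::TorchMlirNode;

  torch::lazy::TorchMlirOpVector Lower(
      torch::lazy::TorchMlirFunction function,
      torch::lazy::TorchMlirLoweringContext* loctx) const override {
    std::vector<torch::jit::NamedValue> arguments;
    std::vector<torch::jit::NamedValue> kwarguments;
    arguments.emplace_back(loctx->GetOutputOp(operand(0)));
    return torch::lazy::LowerTorchMlirBuiltin(
        function, op().op, arguments, kwarguments);
  }
};
```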


@ -1,242 +0,0 @@
//===- tensor_aten_ops.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/tensor_aten_ops.cpp
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <functional>
#include <ATen/InferSize.h>
#include <c10/util/Optional.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/util.h>
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
#include <torch/csrc/lazy/core/view_ops/permute.h>
#include <torch/csrc/lazy/core/view_ops/view.h>
#include <torch/csrc/lazy/ts_backend/ops/cast.h>
#include <torch/csrc/lazy/ts_backend/ops/expand.h>
#include "tensor_aten_ops.h"
namespace torch_lazy_tensors {
namespace lazy_tensor_aten_ops {
namespace {
// to enable operator+-*/ for Value
using namespace torch::lazy;
torch::lazy::Value MaybeExpand(
const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) {
if (input.shape().sizes() == target_shape.sizes()) {
return input;
}
return torch::lazy::MakeNode<torch::lazy::Expand>(
input, target_shape.sizes().vec(),
/*is_scalar_expand=*/false);
}
std::vector<int64_t> GetExpandDimensions(
const torch::lazy::Shape& shape, std::vector<int64_t> dimensions) {
CHECK_GE(dimensions.size(), shape.dim()) << shape;
int64_t base = dimensions.size() - shape.dim();
for (size_t i = 0; i < shape.dim(); ++i) {
if (dimensions[base + i] == -1) {
dimensions[base + i] = shape.size(i);
}
}
return dimensions;
}
// Returns a 1-D shape for batch norm weight or bias based on the input shape.
torch::lazy::Shape
BatchNormFeaturesShape(const torch::lazy::LazyTensorPtr& input) {
CHECK(input);
auto input_shape = input->shape().Get();
return torch::lazy::Shape(input_shape.scalar_type(), input_shape.sizes()[1]);
}
// Returns the IR for the given input or the provided default value broadcasted
// to the default shape, if the input is undefined.
torch::lazy::Value GetIrValueOrDefault(
const torch::lazy::LazyTensorPtr& input, const at::Scalar& default_value,
const torch::lazy::Shape& default_shape,
const torch::lazy::BackendDevice& device) {
return input ? input->GetIrValue()
: torch::lazy::LazyGraphExecutor::Get()
->GetIrValueForExpandedScalar(
default_value, default_shape, device);
}
torch::lazy::ViewInfo CreateAsStridedViewInfo(
const torch::lazy::Shape& input_shape, std::vector<int64_t> size,
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
torch::lazy::Shape result_shape =
torch::lazy::Shape(input_shape.scalar_type(), size);
torch::lazy::AsStridedInfo as_strided_info;
as_strided_info.stride = std::move(stride);
if (storage_offset) {
as_strided_info.offset = *storage_offset;
}
return torch::lazy::ViewInfo(
torch::lazy::ViewInfo::Type::kAsStrided, std::move(result_shape),
input_shape, std::move(as_strided_info));
}
} // namespace
//////////////////////////////////////////////////////////////////////////////
// ATEN operators follows here, listed in alphabetical order.
//////////////////////////////////////////////////////////////////////////////
torch::lazy::LazyTensorPtr as_strided(
const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
auto input_shape = input->shape();
return input->CreateViewTensor(CreateAsStridedViewInfo(
input_shape, std::move(size), std::move(stride), storage_offset));
}
void as_strided_(
torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
if (input->data()->view == nullptr) {
input->SetIrValue(torch::lazy::MakeNode<torch::lazy::AsStrided>(
input->GetIrValue(), std::move(size), std::move(stride),
storage_offset.value_or(0)));
} else {
auto input_shape = input->shape();
input->SetSubView(CreateAsStridedViewInfo(
input_shape, std::move(size), std::move(stride), storage_offset));
}
}
torch::lazy::LazyTensorPtr
expand(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size) {
auto input_shape = input->shape();
return torch::lazy::LazyTensor::Create(
torch::lazy::MakeNode<torch::lazy::Expand>(
input->GetIrValue(),
GetExpandDimensions(input_shape.Get(), std::move(size)),
/*is_scalar_expand=*/false),
input->GetDevice());
}
void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value) {
torch::lazy::Value constant =
torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(
value, input->shape(), input->GetDevice());
input->SetInPlaceIrValue(std::move(constant));
}
torch::lazy::LazyTensorPtr narrow(
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
int64_t length) {
auto input_shape = input->shape();
dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
torch::lazy::Shape narrow_shape = input_shape;
narrow_shape.set_size(dim, length);
torch::lazy::ViewInfo::Type view_type =
(input_shape.Get().numel() == narrow_shape.numel())
? torch::lazy::ViewInfo::Type::kReshape
: torch::lazy::ViewInfo::Type::kNarrow;
torch::lazy::ViewInfo view_info(
view_type, std::move(narrow_shape), input_shape);
view_info.indices[dim] =
torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start);
return input->CreateViewTensor(std::move(view_info));
}
torch::lazy::LazyTensorPtr
permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef<int64_t> dims) {
auto input_shape = input->shape();
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kPermute, input_shape,
torch::lazy::GetCanonicalDimensionIndices(dims, input_shape.Get().dim()));
return input->CreateViewTensor(std::move(view_info));
}
void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
if (input->GetDevice() == src->GetDevice()) {
torch::lazy::Value copy_value;
if (input->dtype() == src->dtype()) {
copy_value = src->GetIrValue();
} else {
copy_value = torch::lazy::MakeNode<torch::lazy::Cast>(
src->GetIrValue(), input->dtype(), src->dtype());
}
input->SetIrValue(MaybeExpand(copy_value, input->shape()));
} else {
auto input_shape = input->shape();
at::Tensor src_tensor = src->ToTensor(/*detached=*/true);
if (src_tensor.sizes() != input_shape.Get().sizes()) {
src_tensor = src_tensor.expand(input_shape.Get().sizes().vec());
}
input->UpdateFromTensor(std::move(src_tensor), /*sync=*/false);
}
}
torch::lazy::LazyTensorPtr slice(
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
int64_t end, int64_t step) {
auto input_shape = input->shape();
dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
start =
torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start);
end = torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, end);
// PyTorch allows tensor[-1:0] to return a 0-dim tensor.
if (start > end) {
end = start;
}
step = std::min(step, end - start);
torch::lazy::SelectInfo select = {dim, start, end, step};
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kSelect, input_shape, std::move(select));
return input->CreateViewTensor(std::move(view_info));
}
torch::lazy::LazyTensorPtr
transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) {
auto input_shape = input->shape();
auto permute_dims = torch::lazy::MakeTransposePermutation(
/*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim());
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims);
return input->CreateViewTensor(std::move(view_info));
}
void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) {
auto input_shape = input->shape();
auto permute_dims = torch::lazy::MakeTransposePermutation(
/*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim());
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims);
return input->ModifyCurrentView(std::move(view_info));
}
torch::lazy::LazyTensorPtr view(
const torch::lazy::LazyTensorPtr& input,
c10::ArrayRef<int64_t> output_size) {
auto input_shape = input->shape().Get();
torch::lazy::Shape shape = torch::lazy::Shape(
input_shape.scalar_type(),
at::infer_size(output_size, input_shape.numel()));
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape);
return input->CreateViewTensor(std::move(view_info));
}
} // namespace lazy_tensor_aten_ops
} // namespace torch_lazy_tensors


@ -1,79 +0,0 @@
//===- tensor_aten_ops.h --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/tensor_aten_ops.h
//===----------------------------------------------------------------------===//
#pragma once
#include <torch/csrc/lazy/core/tensor.h>
namespace torch_lazy_tensors {
namespace lazy_tensor_aten_ops {
//////////////////////////////////////////////////////////////////////////////
// ATEN operators follows here, listed in alphabetical order.
//////////////////////////////////////////////////////////////////////////////
// Takes a slice from the input as R1 at the specified offset and reshapes it
// into the provided size.
torch::lazy::LazyTensorPtr as_strided(
const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset);
// In-place version of the method above.
void as_strided_(
torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size,
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset);
torch::lazy::LazyTensorPtr
expand(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> size);
// Fills the input with the given value.
void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value);
// Returns a new tensor that is a narrowed view of the input in the given
// dimension.
torch::lazy::LazyTensorPtr narrow(
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
int64_t length);
// Permute the dimensions of this tensor according to the given permutation.
torch::lazy::LazyTensorPtr
permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef<int64_t> dims);
// Repeats the input tensor along each dimension by the given number of
// repeats.
torch::lazy::LazyTensorPtr
repeat(const torch::lazy::LazyTensorPtr& input, std::vector<int64_t> repeats);
void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src);
torch::lazy::LazyTensorPtr slice(
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
int64_t end, int64_t step);
std::tuple<
torch::lazy::LazyTensorPtr, torch::lazy::LazyTensorPtr,
torch::lazy::LazyTensorPtr>
svd(const torch::lazy::LazyTensorPtr& input, bool some, bool compute_uv);
// Swap given dimensions of the input.
torch::lazy::LazyTensorPtr
transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1);
// In-place version of the method above.
void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1);
// Like reshape, but it returns a view into the original tensor.
torch::lazy::LazyTensorPtr view(
const torch::lazy::LazyTensorPtr& input,
c10::ArrayRef<int64_t> output_size);
} // namespace lazy_tensor_aten_ops
} // namespace torch_lazy_tensors