mirror of https://github.com/llvm/torch-mlir
Fix LTC Decoupling (#815)
* Initial changes
* Fix up native functions
* Further fix decoupling
* Remove unnecessary ops
* Formatting and copyright banners
* Add pytorch submodule
parent cca9fe126e, commit 1bde00c73d
@ -23,15 +23,10 @@ __pycache__
# Bazel
bazel-*

# Libraries
*.so
*.a

# Autogenerated files
/generated_native_functions.yaml
/generated_backend.hash
/python/torch_mlir/csrc/base_lazy_backend/LazyIr.h
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.cpp
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h
/python/torch_mlir/csrc/base_lazy_backend/GenLazyShapeInference.cpp
/python/torch_mlir/csrc/base_lazy_backend/RegisterLazy.cpp
/python/torch_mlir/csrc/base_lazy_backend/generated

# Example backend
examples/ltc_backend/ltc_backend/_EXAMPLE_MLIR_BACKEND.cpython-37m-x86_64-linux-gnu.so
@ -12,20 +12,20 @@ from textwrap import dedent
import yaml

TORCH_MLIR_DIR = Path(__file__).parent.parent.resolve()
TORCH_DIR = TORCH_MLIR_DIR.parent.joinpath("pytorch")
TORCH_DIR = TORCH_MLIR_DIR.joinpath("externals", "pytorch")

sys.path.append(str(TORCH_DIR.joinpath("tools")))
sys.path.append(str(TORCH_DIR))

# PyTorch's LTC backend autogen script
import codegen.dest.lazy_ir
import codegen.gen_lazy_tensor
from codegen.api.lazy import LazyIrSchema
from codegen.gen import get_grouped_native_functions, parse_native_yaml
from codegen.model import NativeFunctionsGroup
import torchgen.dest.lazy_ir
import torchgen.gen_lazy_tensor
from torchgen.api.lazy import LazyIrSchema
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
from torchgen.model import NativeFunctionsGroup


def isOptionalCType(arg):
    return str(type(arg)) == "<class 'tools.codegen.api.types.OptionalCType'>"
    return str(type(arg)) == "<class 'torchgen.api.types.OptionalCType'>"


def generate_native_functions(
@ -33,11 +33,11 @@ def generate_native_functions(
):
    print("Generating Native Functions Yaml")

    native_yaml_path = TORCH_DIR.joinpath(
        "aten", "src", "ATen", "native", "native_functions.yaml"
    )
    native_path = TORCH_DIR.joinpath("aten", "src", "ATen", "native")
    native_yaml_path = native_path.joinpath("native_functions.yaml")
    tags_yaml_path = native_path.joinpath("tags.yaml")

    parsed_yaml = parse_native_yaml(native_yaml_path)
    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
    native_functions = parsed_yaml.native_functions
    grouped_native_functions = get_grouped_native_functions(native_functions)
@ -57,6 +57,9 @@ def generate_native_functions(
    # primarily view ops
    supported = config.get("supported", [])

    # List of non-native ops to do IR codegen for
    non_native = config.get("non_native", [])

    if which("rg") is not None:  # use ripgrep if available as it's much faster
        cmd = ["rg", "-o", "-N", r"aten::[0-9a-zA-Z_\.]+"]
    else:
@ -105,6 +108,7 @@ def generate_native_functions(
            "cpp_namespace": "torch::lazy",
            "full_codegen": opnames,
            "supported": sorted(supported_ops),
            "non_native": non_native,
        },
        f,
        default_flow_style=False,
@ -123,13 +127,13 @@ def generate_native_functions(


@dataclass(frozen=True)
class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
class GenMlirLazyIr(torchgen.dest.GenLazyIR):

    def lowering_function(self, f):
        func = (
            f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        )
        schema = LazyIrSchema(func)
    def lowering_function(self, schema, declaration_only=True):
        signature = "TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override"

        if declaration_only:
            return f"{signature};"

        emplace_arguments = []
        for arg in schema.positional_args:
@ -149,7 +153,7 @@ class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
            [f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars])

        return f"""
    TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override {{
    {signature} {{
        PRINT_FUNCTION();
        std::vector<torch::jit::NamedValue> arguments;
        std::vector<torch::jit::NamedValue> kwarguments;
@ -159,7 +163,7 @@ class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
        {emplace_arguments_str}
        {emplace_kwarguments}
        torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
        CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
        CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});

        return {schema.aten_name}_out;
    }}
@ -178,13 +182,13 @@ def generate_backend(
    def gen_fallback_code(*args, **kwargs):
        return ""

    codegen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
    torchgen.dest.lazy_ir.gen_fallback_code = gen_fallback_code

    codegen.gen_lazy_tensor.run_gen_lazy_tensor(
    torchgen.gen_lazy_tensor.run_gen_lazy_tensor(
        backend_name="TorchMlir",
        aten_path=str(TORCH_DIR.joinpath("aten", "src", "ATen")),
        source_yaml=str(source_yaml),
        output_dir=str(backend_path),
        output_dir=str(backend_path.joinpath("generated")),
        dry_run=False,
        impl_path=str(backend_path.joinpath("mlir_native_functions.cpp")),
        node_base="torch::lazy::TorchMlirNode",
@ -192,7 +196,7 @@ def generate_backend(
        tensor_class="torch::lazy::LazyTensor",
        tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
        shape_inference_hdr=str(backend_path.joinpath("LazyShapeInference.h")),
        lazy_ir_cls=MlirLazyIr,
        lazy_ir_generator=GenMlirLazyIr,
    )

    # Remove lazy_tensor_core imports
@ -201,7 +205,7 @@ def generate_backend(
            "sed",
            "-i",
            "/lazy_tensor_core/d",
            str(backend_path.joinpath("LazyNativeFunctions.cpp")),
            str(backend_path.joinpath("generated", "LazyNativeFunctions.cpp")),
        ]
    )
@ -240,14 +244,14 @@ def generate_backend(
        - shape_inference_defs
    )
    if missing_defs:
        backend_path.joinpath("GenLazyShapeInference.cpp").write_text(
        backend_path.joinpath("generated", "GenLazyShapeInference.cpp").write_text(
            dedent(
                """
                // This file contains autogenerated Lazy Shape Inference placeholders
                // for ops that don't have a corresponding structured kernel or shape definition

                #include "LazyShapeInference.h"
                #include "../utils/exception.h"
                #include "../LazyShapeInference.h"
                #include "../../utils/exception.h"
                namespace torch {{
                namespace lazy {{
                {}
@ -38,8 +38,7 @@ supported:
- empty
- expand
- fill_
- native_batch_norm
# - native_batch_norm_backward
- native_batch_norm_backward
- permute
- squeeze
- t
@ -50,4 +49,39 @@ additional_ops:
# Additional ops to support that are not supported by Torch-MLIR explicitly
- _copy_from
- _copy_from_and_resize
- native_batch_norm_backward
# - native_batch_norm_backward

# List of non native ops that we only want to do IR node class generation for
non_native:
  - func: device_data(std::shared_ptr<BackendData> data) -> Tensor
    opkind: ltc_device_data
    cache_shape: false
  - func: scalar(at::Scalar value, at::ScalarType type) -> Tensor
    opkind: at::prim::Constant
    cache_shape: false
  - func: expand(Tensor input, std::vector<int64_t> size, bool is_scalar_expand) -> Tensor
  - func: view(Tensor input, std::vector<int64_t> output_size) -> Tensor
    cache_shape: false
  - func: cast(Tensor input, at::ScalarType dtype, optional<at::ScalarType> stype) -> Tensor
    opkind: ltc_cast
    cache_shape: false

  # View ops only required until proper functionalization pass is introduced into LTC
  - func: as_strided_view_update(Tensor target, Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
    opkind: ltc_as_strided_view_update
  - func: as_strided(Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
  - func: diagonal_view_update(Tensor target, Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
    opkind: ltc_diagonal_view_update
    cache_shape: false
  - func: diagonal(Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
  - func: narrow_view_update(Tensor input, Tensor source, std::vector<int64_t> base_indices) -> Tensor
    opkind: ltc_narrow_view_update
  - func: narrow(Tensor input, std::vector<int64_t> base_indices, std::vector<int64_t> sizes) -> Tensor
  - func: permute(Tensor input, std::vector<int64_t> dims) -> Tensor
  - func: resize(Tensor input, std::vector<int64_t> size) -> Tensor
  - func: select_view_update(Tensor target, Tensor source, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
    opkind: ltc_select_view_update
    cache_shape: false
  - func: select(Tensor input, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
  - func: squeeze(Tensor input, int dim) -> Tensor
  - func: unsqueeze(Tensor input, int dim) -> Tensor
Binary file not shown.
Binary file not shown.
@ -12,7 +12,7 @@
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>

#include <torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
#include <torch_mlir/csrc/utils/debug.h>
@ -0,0 +1 @@
Subproject commit 9f3d6a00a76c567d7c046eabc60ae7a578f7bbde
Binary file not shown.
@ -21,14 +21,14 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")

add_library(torch_mlir_ltc_backend SHARED
  base_lazy_backend/backend_impl.cpp
  base_lazy_backend/LazyNativeFunctions.cpp
  base_lazy_backend/generated/LazyNativeFunctions.cpp
  base_lazy_backend/generated/GenLazyShapeInference.cpp
  base_lazy_backend/generated/RegisterLazy.cpp
  base_lazy_backend/LazyShapeInference.cpp
  base_lazy_backend/GenLazyShapeInference.cpp
  base_lazy_backend/mlir_lowering_context.cpp
  base_lazy_backend/mlir_native_functions.cpp
  base_lazy_backend/mlir_node.cpp
  base_lazy_backend/mlir_node_lowering.cpp
  base_lazy_backend/RegisterLazy.cpp
)

add_dependencies(torch_mlir_ltc_backend
@ -13,5 +13,25 @@
namespace torch {
namespace lazy {

std::vector<Shape> compute_shape_native_batch_norm(
    const at::Tensor& input, const c10::optional<at::Tensor>& weight,
    const c10::optional<at::Tensor>& bias,
    const c10::optional<at::Tensor>& running_mean,
    const c10::optional<at::Tensor>& running_var, bool training,
    double momentum, double eps) {
  std::vector<Shape> shapes;
  shapes.reserve(3);
  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
  if (running_mean.has_value()) {
    shapes.emplace_back(
        running_mean.value().scalar_type(), running_mean.value().sizes().vec());
    if (running_var.has_value()) {
      shapes.emplace_back(
          running_var.value().scalar_type(), running_var.value().sizes().vec());
    }
  }
  return shapes;
}

} // namespace lazy
} // namespace torch
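A minimal usage sketch (not part of the commit) of the shape helper above; the tensor sizes are illustrative only and assume native_batch_norm's usual (output, save_mean, save_invstd) results.

// Expect three shapes: {8, 3, 224, 224} for the output and {3} for each saved stat.
at::Tensor input = at::rand({8, 3, 224, 224});
at::Tensor stat = at::rand({3});
auto shapes = torch::lazy::compute_shape_native_batch_norm(
    input, /*weight=*/stat, /*bias=*/stat, /*running_mean=*/stat,
    /*running_var=*/stat, /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5);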
@ -23,6 +23,7 @@ namespace lazy {
|
|||
// clang-format off
|
||||
|
||||
TORCH_API std::vector<Shape> compute_shape___and__(const at::Tensor & self, const at::Tensor & other);
|
||||
TORCH_API std::vector<Shape> compute_shape__reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride);
|
||||
TORCH_API std::vector<Shape> compute_shape__shape_as_tensor(const at::Tensor & self);
|
||||
TORCH_API std::vector<Shape> compute_shape__unsafe_view(const at::Tensor & self, at::IntArrayRef size);
|
||||
TORCH_API std::vector<Shape> compute_shape_abs(const at::Tensor & self);
|
||||
|
@ -37,6 +38,8 @@ TORCH_API std::vector<Shape> compute_shape_bincount(const at::Tensor & self, con
|
|||
TORCH_API std::vector<Shape> compute_shape_broadcast_to(const at::Tensor & self, at::IntArrayRef size);
|
||||
TORCH_API std::vector<Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
|
||||
TORCH_API std::vector<Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
|
||||
TORCH_API std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
|
||||
TORCH_API std::vector<Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
|
||||
TORCH_API std::vector<Shape> compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
|
||||
TORCH_API std::vector<Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_div_(at::Tensor & self, const at::Scalar & other);
|
||||
|
@ -63,6 +66,7 @@ TORCH_API std::vector<Shape> compute_shape_mean(const at::Tensor & self, c10::op
|
|||
TORCH_API std::vector<Shape> compute_shape_mul(const at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_mul_(at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
|
||||
TORCH_API std::vector<Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
|
||||
TORCH_API std::vector<Shape> compute_shape_new_ones(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
|
||||
TORCH_API std::vector<Shape> compute_shape_new_zeros(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
|
||||
TORCH_API std::vector<Shape> compute_shape_rand_like(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
|
||||
|
@ -82,6 +86,9 @@ TORCH_API std::vector<Shape> compute_shape_sum(const at::Tensor & self, c10::opt
|
|||
TORCH_API std::vector<Shape> compute_shape_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1);
|
||||
TORCH_API std::vector<Shape> compute_shape_type_as(const at::Tensor & self, const at::Tensor & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_var(const at::Tensor & self, bool unbiased);
|
||||
TORCH_API std::vector<Shape> compute_shape_zero_(at::Tensor & self);
|
||||
|
||||
TORCH_API std::vector<Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
|
||||
|
||||
// clang-format on
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "../utils/debug.h"
|
||||
#include "../utils/exception.h"
|
||||
#include "backend_impl.h"
|
||||
#include "ir_builder.h"
|
||||
#include "mlir_lowering_context.h"
|
||||
|
||||
namespace torch {
|
||||
|
@ -72,6 +73,15 @@ TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const {
|
|||
* */
|
||||
void TorchMlirBackendImpl::PrepareToExit() const {}
|
||||
|
||||
/**
|
||||
* IR Tracing
|
||||
* */
|
||||
|
||||
const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const {
|
||||
static const IrBuilder* builder = new TorchMlirIrBuilder();
|
||||
return builder;
|
||||
}
|
||||
|
||||
/**
|
||||
* Data Transfer
|
||||
* */
|
||||
|
@ -95,6 +105,16 @@ BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder(
|
|||
return std::make_shared<TorchMlirBackendData>(device, shape);
|
||||
}
|
||||
|
||||
BackendDataPtr
|
||||
TorchMlirBackendImpl::GetComputationDataFromNode(Node* node) const {
|
||||
PRINT_FUNCTION();
|
||||
auto* device_data_node = dynamic_cast<DeviceData*>(node);
|
||||
if (!device_data_node) {
|
||||
return nullptr;
|
||||
}
|
||||
return device_data_node->data;
|
||||
}
|
||||
|
||||
at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
|
||||
const BackendDataPtr data,
|
||||
c10::optional<at::ScalarType> logical_scalar_type) const {
|
||||
|
|
|
@ -65,6 +65,12 @@ public:
|
|||
* */
|
||||
virtual void PrepareToExit() const override;
|
||||
|
||||
/**
|
||||
* IR Tracing
|
||||
* */
|
||||
|
||||
const IrBuilder* GetIrBuilder() const override;
|
||||
|
||||
/**
|
||||
* Configuration
|
||||
* */
|
||||
|
@ -84,6 +90,10 @@ public:
|
|||
virtual BackendDataPtr CreateDataPlaceholder(
|
||||
const BackendDevice& device, const Shape& shape) const override;
|
||||
|
||||
// Gets backend data if the node is a device data node. Otherwise returns
|
||||
// nullptr.
|
||||
virtual BackendDataPtr GetComputationDataFromNode(Node*) const override;
|
||||
|
||||
virtual at::Tensor MakeTensorFromComputationData(
|
||||
const BackendDataPtr data,
|
||||
c10::optional<at::ScalarType> logical_scalar_type) const override;
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
//===- dynamic_ir.cpp -----------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file is adapted from pytorch/pytorch
|
||||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/dynamic_ir.cpp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed)
|
||||
: TorchMlirNode(
|
||||
op, operands, /*num_outputs=*/1,
|
||||
/* hash_seed */ HashCombine(op.hash(), hash_seed)) {}
|
||||
|
||||
std::string DimensionNode::ToString() const { return "DimensionNode"; }
|
||||
|
||||
SizeNode::SizeNode(Value input, size_t dim)
|
||||
: DimensionNode(
|
||||
OpKind{c10::Symbol::fromQualString("aten::size")}, {input},
|
||||
MHash(dim)),
|
||||
dim_(dim){};
|
||||
|
||||
int64_t SizeNode::getStaticValue() const {
|
||||
return dynamic_cast<const TorchMlirNode*>(operand(0).node)
|
||||
->shape(0)
|
||||
.size(dim_);
|
||||
}
|
||||
|
||||
std::string SizeNode::ToString() const { return "SizeNode"; }
|
||||
|
||||
SizeAdd::SizeAdd(Value a, Value b)
|
||||
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){};
|
||||
|
||||
int64_t SizeAdd::getStaticValue() const {
|
||||
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() +
|
||||
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue();
|
||||
}
|
||||
|
||||
std::string SizeAdd::ToString() const { return "SizeAdd"; }
|
||||
|
||||
SizeMul::SizeMul(Value a, Value b)
|
||||
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){};
|
||||
|
||||
int64_t SizeMul::getStaticValue() const {
|
||||
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() *
|
||||
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue();
|
||||
}
|
||||
|
||||
std::string SizeMul::ToString() const { return "SizeMul"; }
|
||||
|
||||
SizeDiv::SizeDiv(Value a, Value b)
|
||||
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}){};
|
||||
|
||||
int64_t SizeDiv::getStaticValue() const {
|
||||
TORCH_CHECK(
|
||||
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue() !=
|
||||
0,
|
||||
"Can't divide a dimension by zero");
|
||||
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() /
|
||||
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue();
|
||||
}
|
||||
|
||||
std::string SizeDiv::ToString() const { return "SizeDiv"; }
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,99 @@
|
|||
//===- dynamic_ir.h -------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file is adapted from pytorch/pytorch
|
||||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/dynamic_ir.h
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/symbol.h>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mlir_node.h"
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Flags.h>
|
||||
#include <torch/csrc/lazy/core/hash.h>
|
||||
#include <torch/csrc/lazy/core/ir.h>
|
||||
#include <torch/csrc/lazy/core/ir_metadata.h>
|
||||
|
||||
C10_DECLARE_bool(ltc_enable_dynamic_shapes);
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
/**
|
||||
* The goal of "dynamic" Nodes is to patch a hole in our tracing.
|
||||
* Previously, if a user called `sizes` on a Tensor, it would leak out
|
||||
* of our tracing system, as `sizes` returns a torch.Size or an int. To
|
||||
* prevent this from happening, we introduce DimensionNode, a new type
|
||||
* of Node that abstracts the operation of getting the dimensions of a
|
||||
* Tensor.
|
||||
*
|
||||
* Consider the following example:
|
||||
* ```
|
||||
* numel = x.shape()[0] * x.shape()[1]
|
||||
* ```
|
||||
*
|
||||
* Here, `x.shape()[i]` will be a SizeNode (subclass of DimensionNode),
|
||||
* and the multiplication of the two SizeNodes will be represented by
|
||||
* a SizeMul (also a subclass of DimensionNode). Through this, we can
|
||||
* prevent `numel` from being represented as a Python int and thus
|
||||
* burned into the Graph.
|
||||
*/
|
||||
|
||||
class TORCH_API DimensionNode : public lazy::TorchMlirNode {
|
||||
public:
|
||||
DimensionNode(OpKind op, OpList operands, hash_t hash_seed = kHashSeed);
|
||||
bool isDynamic() { return false; }
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
virtual int64_t getStaticValue() const = 0;
|
||||
};
|
||||
|
||||
// Represents the result of calling `size` on a Tensor
|
||||
class TORCH_API SizeNode : public DimensionNode {
|
||||
public:
|
||||
SizeNode(Value input, size_t dim);
|
||||
int64_t getStaticValue() const override;
|
||||
std::string ToString() const override;
|
||||
size_t dim_ = 0;
|
||||
};
|
||||
|
||||
class TORCH_API SizeAdd : public DimensionNode {
|
||||
public:
|
||||
SizeAdd(Value a, Value b);
|
||||
int64_t getStaticValue() const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
class TORCH_API SizeMul : public DimensionNode {
|
||||
public:
|
||||
SizeMul(Value a, Value b);
|
||||
int64_t getStaticValue() const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
class TORCH_API SizeDiv : public DimensionNode {
|
||||
public:
|
||||
SizeDiv(Value a, Value b);
|
||||
int64_t getStaticValue() const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
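A hedged sketch (not in this commit) of how the `numel = x.shape()[0] * x.shape()[1]` example from the comment above would be traced with these nodes; `x_value` is a hypothetical torch::lazy::Value for the tensor `x`, and MakeNode comes from torch/csrc/lazy/core/ir_builder.h.

// x.shape()[0] and x.shape()[1] become SizeNodes instead of leaking out as Python ints.
torch::lazy::NodePtr dim0 = torch::lazy::MakeNode<torch::lazy::SizeNode>(x_value, /*dim=*/0);
torch::lazy::NodePtr dim1 = torch::lazy::MakeNode<torch::lazy::SizeNode>(x_value, /*dim=*/1);
// Their product stays in the lazy graph as a SizeMul node rather than a burned-in constant.
torch::lazy::NodePtr numel = torch::lazy::MakeNode<torch::lazy::SizeMul>(
    torch::lazy::Value(dim0), torch::lazy::Value(dim1));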
|
|
@ -0,0 +1,65 @@
|
|||
//===- ir_builder.h -------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file is adapted from pytorch/pytorch
|
||||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ir_builder.h
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
|
||||
#include <torch/csrc/lazy/core/ir.h>
|
||||
#include <torch/csrc/lazy/core/ir_builder.h>
|
||||
#include <torch/csrc/lazy/core/shape_inference.h>
|
||||
|
||||
#include "dynamic_ir.h"
|
||||
#include "generated/LazyNonNativeIr.h"
|
||||
#include "mlir_node.h"
|
||||
#include "ops/generic.h"
|
||||
|
||||
// This file contains the TorchMlir IrBuilder
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
// clang-format off
|
||||
|
||||
struct TorchMlirIrBuilder : IrBuilder {
|
||||
NodePtr MakeDeviceData(const std::shared_ptr<BackendData>& data) const override { return MakeNode<DeviceData>(data); }
|
||||
NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) const override { return MakeNode<Scalar>(value, type); }
|
||||
NodePtr MakeExpand(const Value& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand) const override { return MakeNode<Expand>(input0, size, is_scalar_expand); }
|
||||
NodePtr MakeView(const Value& input0, const std::vector<int64_t>& output_size) const override { return MakeNode<View>(input0, output_size); }
|
||||
NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype = c10::nullopt) const override { return MakeNode<Cast>(input0, dtype, stype); }
|
||||
NodePtr MakeTensorList(const OpList& inputs) const override { return MakeNode<TensorList>(inputs); }
|
||||
NodePtr MakeGeneric(const OpKind& op, const OpList& operands, const Shape& shape, const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) const override { return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed); }
|
||||
|
||||
// view ops
|
||||
NodePtr MakeAsStridedViewUpdate(const Value& input0, const Value& input1, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset) const override { return MakeNode<AsStridedViewUpdate>(input0, input1, size, stride, storage_offset); }
|
||||
NodePtr MakeAsStrided(const Value& input0, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset) const override { return MakeNode<AsStrided>(input0, size, stride, storage_offset); }
|
||||
NodePtr MakeDiagonalViewUpdate(const Value& input0, const Value& input1, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) const override { return MakeNode<DiagonalViewUpdate>(input0, input1, offset, dim1, dim2); }
|
||||
NodePtr MakeDiagonal(const Value& input0, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) const override { return MakeNode<Diagonal>(input0, offset, dim1, dim2); }
|
||||
NodePtr MakeNarrowViewUpdate(const Value& input0, const Value& input1, const std::vector<int64_t>& base_indices) const override { return MakeNode<NarrowViewUpdate>(input0, input1, base_indices); }
|
||||
NodePtr MakeNarrow(const Value& input0, const std::vector<int64_t>& base_indices, const std::vector<int64_t>& sizes) const override { return MakeNode<Narrow>(input0, base_indices, sizes); }
|
||||
NodePtr MakePermute(const Value& input0, const std::vector<int64_t>& dims) const override { return MakeNode<Permute>(input0, dims); }
|
||||
NodePtr MakeResize(const Value& input0, const std::vector<int64_t>& size) const override { return MakeNode<Resize>(input0, size); }
|
||||
NodePtr MakeSelectViewUpdate(const Value& input0, const Value& input1, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) const override { return MakeNode<SelectViewUpdate>(input0, input1, dim, start, end, stride); }
|
||||
NodePtr MakeSelect(const Value& input0, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) const override { return MakeNode<Select>(input0, dim, start, end, stride); }
|
||||
NodePtr MakeSqueeze(const Value& input0, const int& dim) const override { return MakeNode<Squeeze>(input0, dim); }
|
||||
NodePtr MakeUnsqueeze(const Value& input0, const int& dim) const override { return MakeNode<Unsqueeze>(input0, dim); }
|
||||
|
||||
// dynamic ir nodes
|
||||
NodePtr MakeSizeNode(const Value& input, size_t dim) const override { return MakeNode<SizeNode>(input, dim); }
|
||||
NodePtr MakeSizeAdd(const Value& a, const Value& b) const override { return MakeNode<SizeAdd>(a, b); }
|
||||
NodePtr MakeSizeMul(const Value& a, const Value& b) const override { return MakeNode<SizeMul>(a, b); }
|
||||
NodePtr MakeSizeDiv(const Value& a, const Value& b) const override { return MakeNode<SizeDiv>(a, b); }
|
||||
};
|
||||
|
||||
// clang-format on
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
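A hedged sketch (not in this commit) of how this builder is reached in practice, assuming the generic helpers in torch/csrc/lazy/core/ir_builder.h route node creation through the builder returned by TorchMlirBackendImpl::GetIrBuilder() (see backend_impl.cpp above). `input_value` is a hypothetical torch::lazy::Value.

// Dispatches to TorchMlirIrBuilder::MakeExpand and yields an Expand IR node.
torch::lazy::NodePtr expand =
    torch::lazy::MakeExpand(input_value, {2, 3, 4}, /*is_scalar_expand=*/false);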
|
|
@ -148,12 +148,13 @@ void TorchMlirLoweringContext::AssignOutputOp(
|
|||
const Output& output, torch::jit::Value* op) {
|
||||
PRINT_FUNCTION();
|
||||
|
||||
auto torch_mlir_node =
|
||||
NodeCast<TorchMlirNode>(output.node, output.node->op());
|
||||
if (!torch_mlir_node->getPythonStacktrace().empty()) {
|
||||
op->node()->s_(
|
||||
c10::Symbol::attr("source"), torch_mlir_node->getPythonStacktrace());
|
||||
}
|
||||
// TODO (antoniojkim): Do we need this?
|
||||
// auto torch_mlir_node =
|
||||
// NodeCast<TorchMlirNode>(output.node, output.node->op());
|
||||
// if (!torch_mlir_node->getPythonStacktrace().empty()) {
|
||||
// op->node()->s_(
|
||||
// c10::Symbol::attr("source"), torch_mlir_node->getPythonStacktrace());
|
||||
// }
|
||||
emitted_outputs_[output] = std::move(op);
|
||||
}
|
||||
|
||||
|
@ -301,7 +302,7 @@ unsigned TorchMlirComputation::num_results() const { return num_results_; }
|
|||
|
||||
MlirOperation TorchMlirComputation::func_op() const { return func_op_; }
|
||||
|
||||
std::string TorchMlirComputation::to_string() const {
|
||||
const std::string TorchMlirComputation::to_string() const {
|
||||
// Since we use the C-MLIR API, we need to use a callback to print.
|
||||
MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) {
|
||||
// user_data is a void ptr to some data structure of our choice -- in this
|
||||
|
|
|
@ -149,7 +149,7 @@ public:
|
|||
|
||||
MlirOperation func_op() const;
|
||||
|
||||
std::string to_string() const;
|
||||
const std::string to_string() const;
|
||||
|
||||
private:
|
||||
std::vector<std::string> parameter_names_;
|
||||
|
|
|
@ -10,16 +10,18 @@
|
|||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <ATen/InferSize.h>
|
||||
#include <ATen/Operators.h>
|
||||
#include <ATen/native/BinaryOps.h>
|
||||
#include <ATen/native/CPUFallback.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/result_type.h>
|
||||
#include <torch/csrc/lazy/core/helpers.h>
|
||||
#include <torch/csrc/lazy/core/ir_builder.h>
|
||||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
||||
#include <torch/csrc/lazy/core/metrics.h>
|
||||
#include <torch/csrc/lazy/core/tensor_aten_ops.h>
|
||||
#include <torch/csrc/lazy/core/ops/utils.h>
|
||||
#include <torch/csrc/lazy/core/tensor_util.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "ATen/MetaFunctions.h"
|
||||
|
@ -27,8 +29,8 @@
|
|||
|
||||
#include "../utils/exception.h"
|
||||
#include "../utils/sys_utils.h"
|
||||
#include "LazyNativeFunctions.h"
|
||||
#include "LazyShapeInference.h"
|
||||
#include "generated/LazyNativeFunctions.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
@ -110,6 +112,60 @@ GetLtcDevice(const c10::optional<c10::Device>& device) {
|
|||
return torch::lazy::atenDeviceToBackendDevice(*device);
|
||||
}
|
||||
|
||||
torch::lazy::Value MaybeExpand(
|
||||
const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) {
|
||||
if (input.shape().sizes() == target_shape.sizes()) {
|
||||
return input;
|
||||
}
|
||||
return torch::lazy::MakeExpand(
|
||||
input, target_shape.sizes().vec(),
|
||||
/*is_scalar_expand=*/false);
|
||||
}
|
||||
|
||||
std::vector<int64_t> GetExpandDimensions(
|
||||
const torch::lazy::Shape& shape, std::vector<int64_t> dimensions) {
|
||||
CHECK_GE(dimensions.size(), shape.dim()) << shape;
|
||||
int64_t base = dimensions.size() - shape.dim();
|
||||
for (size_t i = 0; i < shape.dim(); ++i) {
|
||||
if (dimensions[base + i] == -1) {
|
||||
dimensions[base + i] = shape.size(i);
|
||||
}
|
||||
}
|
||||
return dimensions;
|
||||
}
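// Hedged worked example (not part of the commit): GetExpandDimensions({3, 4}, {2, -1, -1})
// aligns the trailing entries with the input shape and replaces each -1 with the
// corresponding input size, yielding {2, 3, 4}; LazyNativeFunctions::expand below relies
// on this to turn -1 placeholders into concrete sizes before building the Expand node.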
|
||||
|
||||
void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
|
||||
if (input->GetDevice() == src->GetDevice()) {
|
||||
torch::lazy::Value copy_value;
|
||||
if (input->dtype() == src->dtype()) {
|
||||
copy_value = src->GetIrValue();
|
||||
} else {
|
||||
copy_value = torch::lazy::MakeCast(
|
||||
src->GetIrValue(), input->dtype(), src->dtype());
|
||||
}
|
||||
input->SetIrValue(MaybeExpand(copy_value, input->shape()));
|
||||
} else {
|
||||
auto input_shape = input->shape();
|
||||
at::Tensor src_tensor = src->ToTensor(/*detached=*/true);
|
||||
if (src_tensor.sizes() != input_shape.Get().sizes()) {
|
||||
src_tensor = src_tensor.expand(input_shape.Get().sizes().vec());
|
||||
}
|
||||
input->UpdateFromTensor(std::move(src_tensor), /*sync=*/false);
|
||||
}
|
||||
}
|
||||
|
||||
torch::lazy::LazyTensorPtr create_view(
|
||||
const torch::lazy::LazyTensorPtr& input,
|
||||
c10::ArrayRef<int64_t> output_size) {
|
||||
auto input_shape = input->shape().Get();
|
||||
torch::lazy::Shape shape = torch::lazy::Shape(
|
||||
input_shape.scalar_type(),
|
||||
at::infer_size(output_size, input_shape.numel()));
|
||||
torch::lazy::ViewInfo view_info(
|
||||
torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape);
|
||||
return input->CreateViewTensor(std::move(view_info));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// at::Tensor LazyNativeFunctions::bernoulli(
|
||||
|
@ -209,7 +265,7 @@ at::Tensor LazyNativeFunctions::_copy_from(
|
|||
dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false));
|
||||
}
|
||||
} else {
|
||||
torch::lazy::copy_(dst_tensor, self_tensor);
|
||||
copy_(dst_tensor, self_tensor);
|
||||
auto* impl =
|
||||
dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
|
||||
impl->set_tensor(dst_tensor);
|
||||
|
@ -260,115 +316,106 @@ at::Tensor LazyNativeFunctions::empty(
|
|||
at::Tensor LazyNativeFunctions::expand(
|
||||
const at::Tensor& self, at::IntArrayRef size, bool implicit) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
return torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::expand(torch::lazy::TryGetLtcTensor(self), size.vec()));
|
||||
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
|
||||
auto input_shape = self_tensor->shape();
|
||||
auto output = torch::lazy::LazyTensor::Create(
|
||||
torch::lazy::MakeExpand(
|
||||
self_tensor->GetIrValue(),
|
||||
GetExpandDimensions(input_shape.Get(), std::move(size.vec())),
|
||||
/*is_scalar_expand=*/false),
|
||||
self_tensor->GetDevice());
|
||||
output->SetStorage(self_tensor->Storage());
|
||||
return torch::lazy::CreateAtenFromLtcTensor(output);
|
||||
}
|
||||
|
||||
at::Tensor&
|
||||
LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
torch::lazy::fill_(self_tensor, value);
|
||||
|
||||
torch::lazy::Value constant =
|
||||
torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(
|
||||
value, self_tensor->shape(), self_tensor->GetDevice());
|
||||
self_tensor->SetInPlaceIrValue(std::move(constant));
|
||||
return self;
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor>
|
||||
LazyNativeFunctions::native_batch_norm(
|
||||
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
|
||||
const c10::optional<at::Tensor>& bias,
|
||||
const c10::optional<at::Tensor>& running_mean,
|
||||
const c10::optional<at::Tensor>& running_var, bool training,
|
||||
double momentum, double eps) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
auto input_tensor = torch::lazy::TryGetLtcTensor(input);
|
||||
const torch::lazy::BackendDevice& device = input_tensor->GetDevice();
|
||||
auto running_mean_tensor = GetOrCreateLtcTensor(running_mean, device);
|
||||
auto running_var_tensor = GetOrCreateLtcTensor(running_var, device);
|
||||
auto outputs = torch::lazy::native_batch_norm(
|
||||
torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device),
|
||||
GetOrCreateLtcTensor(bias, device), running_mean_tensor,
|
||||
running_var_tensor, training, momentum, eps);
|
||||
return std::make_tuple(
|
||||
torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)),
|
||||
torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)),
|
||||
torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs)));
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor>
|
||||
LazyNativeFunctions::native_batch_norm_backward(
|
||||
const at::Tensor& grad_out, const at::Tensor& input,
|
||||
const c10::optional<at::Tensor>& weight,
|
||||
const c10::optional<at::Tensor>& running_mean,
|
||||
const c10::optional<at::Tensor>& running_var,
|
||||
const c10::optional<at::Tensor>& save_mean,
|
||||
const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
|
||||
std::array<bool, 3> output_mask) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
auto grad_out_tensor = torch::lazy::TryGetLtcTensor(grad_out);
|
||||
const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice();
|
||||
torch::lazy::LazyTensorPtr null_tensor;
|
||||
bool running_stats = running_mean && running_mean->defined();
|
||||
CHECK_EQ(running_var && running_var->defined(), running_stats);
|
||||
auto gradients = torch::lazy::native_batch_norm_backward(
|
||||
torch::lazy::TryGetLtcTensor(grad_out),
|
||||
torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device),
|
||||
running_stats ? GetOrCreateLtcTensor(running_mean, device) : null_tensor,
|
||||
running_stats ? GetOrCreateLtcTensor(running_var, device) : null_tensor,
|
||||
GetOrCreateLtcTensor(save_mean, device),
|
||||
GetOrCreateLtcTensor(save_invstd, device), train, eps, output_mask);
|
||||
at::Tensor undefined;
|
||||
return std::make_tuple(
|
||||
output_mask[0]
|
||||
? torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients))
|
||||
: undefined,
|
||||
output_mask[1]
|
||||
? torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients))
|
||||
: undefined,
|
||||
output_mask[2]
|
||||
? torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients))
|
||||
: undefined);
|
||||
}
|
||||
|
||||
at::Tensor
|
||||
LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
|
||||
auto input_shape = self_tensor->shape();
|
||||
torch::lazy::ViewInfo view_info(
|
||||
torch::lazy::ViewInfo::Type::kPermute, input_shape,
|
||||
torch::lazy::GetCanonicalDimensionIndices(
|
||||
torch::lazy::ToI64Vector(dims), input_shape.Get().dim()));
|
||||
return torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::permute(self_tensor, torch::lazy::ToI64Vector(dims)));
|
||||
self_tensor->CreateViewTensor(std::move(view_info)));
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
return torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self)));
|
||||
return squeeze(self, -1);
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self, int64_t dim) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
|
||||
auto input_shape = self_tensor->shape();
|
||||
int64_t squeeze_dim = -1;
|
||||
if (dim != -1) {
|
||||
squeeze_dim =
|
||||
torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
|
||||
}
|
||||
auto output_dimensions =
|
||||
BuildSqueezedDimensions(input_shape.Get().sizes(), squeeze_dim);
|
||||
|
||||
return torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self), dim));
|
||||
create_view(self_tensor, output_dimensions));
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
|
||||
auto input_shape = self_tensor->shape();
|
||||
auto permute_dims = torch::lazy::MakeTransposePermutation(
|
||||
/*dim0=*/0, /*dim1=*/1, /*rank=*/input_shape.Get().dim());
|
||||
torch::lazy::ViewInfo view_info(
|
||||
torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims);
|
||||
|
||||
return torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), 0, 1));
|
||||
self_tensor->CreateViewTensor(std::move(view_info)));
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
|
||||
auto input_shape = self_tensor->shape();
|
||||
int64_t squeeze_dim =
|
||||
torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim() + 1);
|
||||
auto dimensions =
|
||||
BuildUnsqueezedDimensions(input_shape.Get().sizes(), squeeze_dim);
|
||||
return torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::unsqueeze(torch::lazy::TryGetLtcTensor(self), dim));
|
||||
create_view(self_tensor, dimensions));
|
||||
}
|
||||
|
||||
at::Tensor
|
||||
LazyNativeFunctions::view(const at::Tensor& self, at::IntArrayRef size) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
|
||||
auto input_shape = self_tensor->shape().Get();
|
||||
torch::lazy::Shape shape = torch::lazy::Shape(
|
||||
input_shape.scalar_type(),
|
||||
at::infer_size(torch::lazy::ToI64Vector(size), input_shape.numel()));
|
||||
torch::lazy::ViewInfo view_info(
|
||||
torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape);
|
||||
return torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size)));
|
||||
self_tensor->CreateViewTensor(std::move(view_info)));
|
||||
}
|
||||
|
||||
void InitializeAtenBindings() {}
|
||||
|
|
|
@ -10,18 +10,92 @@
|
|||
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_node.cpp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <torch/csrc/lazy/core/cache.h>
|
||||
|
||||
#include "../utils/exception.h"
|
||||
#include "mlir_node.h"
|
||||
#include "../utils/exception.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
namespace {
|
||||
|
||||
hash_t OperandHashes(
|
||||
const OpList& operands, const c10::ArrayRef<Shape>& shapes,
|
||||
const hash_t& seed, bool bakeInSizes) {
|
||||
hash_t hash = seed;
|
||||
for (auto& operand : operands) {
|
||||
if (!operand) {
|
||||
hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
|
||||
continue;
|
||||
}
|
||||
auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash();
|
||||
hash = HashCombine(hash, operand_hash);
|
||||
}
|
||||
for (auto& shape : shapes) {
|
||||
hash = HashCombine(hash, shape.hash(bakeInSizes));
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TorchMlirNode::TorchMlirNode(
|
||||
OpKind op, OpList operands, std::vector<Shape>&& shapes, size_t num_outputs,
|
||||
hash_t hash_seed)
|
||||
: Node(op, operands, std::move(shapes), num_outputs) {
|
||||
hash_seed = HashCombine(op.hash(), hash_seed);
|
||||
shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true);
|
||||
dag_hash_ =
|
||||
(enableDynamicShape()
|
||||
? OperandHashes(operands, this->shapes(), hash_seed, false)
|
||||
: shape_hash_);
|
||||
}
|
||||
|
||||
TorchMlirNode::TorchMlirNode(
|
||||
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
|
||||
size_t num_outputs, hash_t hash_seed)
|
||||
: TorchMlirNode(
|
||||
op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
|
||||
addComputedShape(shape_fn);
|
||||
}
|
||||
|
||||
TorchMlirNode::TorchMlirNode(
|
||||
OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
|
||||
: TorchMlirNode(
|
||||
op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
|
||||
|
||||
TorchMlirNode::TorchMlirNode(
|
||||
OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
|
||||
: TorchMlirNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}
|
||||
|
||||
hash_t TorchMlirNode::hash() const { return dag_hash_; }
|
||||
|
||||
hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
|
||||
|
||||
TorchMlirOpVector TorchMlirNode::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
TensorList::TensorList(OpList values)
|
||||
: TorchMlirNode(
|
||||
/*op=*/tensor_list_opkind,
|
||||
/*operands=*/values,
|
||||
/*shapes=*/std::vector<Shape>(),
|
||||
/*num_outputs=*/1,
|
||||
/*hash_seed=*/kHashSeed) {}
|
||||
|
||||
torch::lazy::TorchMlirOpVector TensorList::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
std::vector<torch::jit::Value*> tensor_list;
|
||||
CHECK(!operands().empty());
|
||||
for (const torch::lazy::Output& operand : operands()) {
|
||||
tensor_list.emplace_back(loctx->GetOutputOp(operand));
|
||||
}
|
||||
auto graph = function->graph();
|
||||
auto listnode =
|
||||
graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list));
|
||||
return {listnode->output()};
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
|
|
@ -27,10 +27,63 @@ namespace lazy {
|
|||
|
||||
class TORCH_API TorchMlirNode : public torch::lazy::Node {
|
||||
public:
|
||||
using torch::lazy::Node::Node;
|
||||
TorchMlirNode(
|
||||
OpKind op, OpList operands, std::vector<Shape>&& shapes,
|
||||
size_t num_outputs, hash_t hash_seed = kHashSeed);
|
||||
|
||||
TorchMlirNode(
|
||||
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
|
||||
size_t num_outputs, hash_t hash_seed = kHashSeed);
|
||||
|
||||
TorchMlirNode(
|
||||
OpKind op, OpList operands, size_t num_outputs,
|
||||
hash_t hash_seed = kHashSeed);
|
||||
|
||||
TorchMlirNode(
|
||||
OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed);
|
||||
|
||||
hash_t hash() const override;
|
||||
|
||||
hash_t shapeHash() const override;
|
||||
|
||||
virtual TorchMlirOpVector
|
||||
Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const;
|
||||
|
||||
private:
|
||||
// The hash of the dag WITH size info. Used for shape caching
|
||||
hash_t shape_hash_;
|
||||
// The hash of the dag used to look up the compiled graph by a hash
|
||||
// in this case, we will use the dag hash WITHOUT size info if dynamic shape
|
||||
// is enabled and use the dag hash WITH size info otherwise.
|
||||
hash_t dag_hash_;
|
||||
};
|
||||
|
||||
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
|
||||
// import otherwise
|
||||
const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
|
||||
|
||||
// TensorList represents an at::TensorList which is a vector[Tensor] but is also
|
||||
// a first-class IValue and can be fed as a single input to a TS program. It is
|
||||
// much easier to handle TensorLists in Lazy Tensor code if they are represented
|
||||
// as a single Node so there can be more than one TensorList and more than one
|
||||
// Tensor side-by-side as operands to an op.
|
||||
//
|
||||
// Note: shape is undefined for TensorList. We assert in some places that
|
||||
// #shapes matches #outputs and this stems from
|
||||
// the fact that currently all IR nodes represent tensors (there is no
|
||||
// type system for this IR). Because of this, TensorList is a bit of a
|
||||
// hack.
|
||||
//
|
||||
// TODO(whc) once Shape() API is moved to Node base, also make it virtual, and
|
||||
// then implement it as NotImplemented for TensorList, also fixing the assertion
|
||||
// that would fail.
|
||||
struct TORCH_API TensorList : public TorchMlirNode {
|
||||
TensorList() = delete;
|
||||
TensorList(OpList values);
|
||||
|
||||
torch::lazy::TorchMlirOpVector Lower(
|
||||
TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const override;
|
||||
};
|
||||
|
||||
} // namespace lazy
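A hedged sketch (not in this commit) of building a TensorList from two lazy IR values; `a_val` and `b_val` are hypothetical torch::lazy::Values, and MakeNode comes from torch/csrc/lazy/core/ir_builder.h.

torch::lazy::NodePtr list_node =
    torch::lazy::MakeNode<torch::lazy::TensorList>(torch::lazy::OpList{a_val, b_val});
// When lowered, TensorList::Lower (mlir_node.cpp above) emits a single list value in the
// jit::Graph, so the whole vector of tensors travels as one operand.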
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir_node_lowering.h"
|
||||
#include "generated/LazyNonNativeIr.h"
|
||||
#include "mlir_lowering_context.h"
|
||||
#include "mlir_node.h"
|
||||
|
||||
|
@ -21,26 +22,11 @@
|
|||
#include <torch/csrc/lazy/backend/backend_interface.h>
|
||||
#include <torch/csrc/lazy/core/helpers.h>
|
||||
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
|
||||
#include <torch/csrc/lazy/core/ir_builder.h>
|
||||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
||||
#include <torch/csrc/lazy/core/ops/utils.h>
|
||||
#include <torch/csrc/lazy/core/permutation_util.h>
|
||||
|
||||
#include <torch/csrc/lazy/core/internal_ops/cast.h>
|
||||
#include <torch/csrc/lazy/core/internal_ops/device_data.h>
|
||||
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
|
||||
#include <torch/csrc/lazy/core/ops/batch_norm_ops.h>
|
||||
#include <torch/csrc/lazy/core/ops/expand.h>
|
||||
#include <torch/csrc/lazy/core/ops/scalar.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/as_strided_view_update.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/narrow.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/narrow_view_update.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/permute.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/select.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/select_view_update.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/squeeze.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/unsqueeze.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/view.h>
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
|
@ -195,16 +181,6 @@ public:
|
|||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
return LowerBuiltin(node, arguments);
|
||||
}
|
||||
if (node->op().op == at::aten::native_batch_norm) {
|
||||
return LowerBatchNorm(
|
||||
torch::lazy::NodeCast<torch::lazy::NativeBatchNormForward>(
|
||||
node, torch::lazy::OpKind(at::aten::native_batch_norm)));
|
||||
}
|
||||
if (node->op().op == at::aten::native_batch_norm_backward) {
|
||||
return LowerBatchNormBackward(
|
||||
torch::lazy::NodeCast<torch::lazy::NativeBatchNormBackward>(
|
||||
node, torch::lazy::OpKind(at::aten::native_batch_norm_backward)));
|
||||
}
|
||||
if (node->op().op == at::aten::expand) {
|
||||
return LowerExpand(torch::lazy::NodeCast<torch::lazy::Expand>(
|
||||
node, torch::lazy::OpKind(at::aten::expand)));
|
||||
|
@ -237,14 +213,14 @@ public:
|
|||
const torch::lazy::DeviceData* device_data_node =
|
||||
torch::lazy::NodeCast<torch::lazy::DeviceData>(
|
||||
node, *torch::lazy::ltc_device_data);
|
||||
auto infoptr = device_data_node->data()->info();
|
||||
auto infoptr = device_data_node->data->info();
|
||||
auto deviceDataInfoPtr =
|
||||
(torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
|
||||
if (GRAPH_DUMP_ENABLED) {
|
||||
LOG(ERROR) << "Lowering device data node, tensor id "
|
||||
<< deviceDataInfoPtr->tensor_id << std::endl;
|
||||
}
|
||||
return {loctx()->GetParameter(device_data_node->data())};
|
||||
return {loctx()->GetParameter(device_data_node->data)};
|
||||
}
|
||||
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
|
@ -278,9 +254,9 @@ public:
|
|||
TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
arguments.emplace_back(node->size());
|
||||
arguments.emplace_back(node->stride());
|
||||
arguments.emplace_back(node->storage_offset());
|
||||
arguments.emplace_back(node->size);
|
||||
arguments.emplace_back(node->stride);
|
||||
arguments.emplace_back(node->storage_offset);
|
||||
TorchMlirOpVector as_strided_out = LowerBuiltin(node, arguments);
|
||||
CHECK_EQ(as_strided_out.size(), 1);
|
||||
return {GenerateClone(as_strided_out.front())};
|
||||
|
@ -297,8 +273,8 @@ public:
|
|||
dest_arguments.emplace_back(destination);
|
||||
dest_arguments.emplace_back(
|
||||
std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
|
||||
dest_arguments.emplace_back(node->stride());
|
||||
dest_arguments.emplace_back(node->storage_offset());
|
||||
dest_arguments.emplace_back(node->stride);
|
||||
dest_arguments.emplace_back(node->storage_offset);
|
||||
TorchMlirOpVector as_strided_out =
|
||||
LowerBuiltin(at::aten::as_strided, node->shapes(), dest_arguments);
|
||||
CHECK_EQ(as_strided_out.size(), 1);
|
||||
|
@ -307,52 +283,19 @@ public:
    return {destination};
  }

- TorchMlirOpVector
- LowerBatchNorm(const torch::lazy::NativeBatchNormForward* node) {
-   std::vector<torch::jit::NamedValue> arguments;
-   for (size_t i = 0; i < 5; ++i) {
-     arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
-   }
-   arguments.emplace_back(node->training());
-   arguments.emplace_back(node->momentum());
-   arguments.emplace_back(node->eps());
-   return LowerBuiltin(node, arguments);
- }
-
- TorchMlirOpVector
- LowerBatchNormBackward(const torch::lazy::NativeBatchNormBackward* node) {
-   std::vector<torch::jit::NamedValue> arguments;
-   for (size_t i = 0; i < 3; ++i) {
-     arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
-   }
-   const auto& operands = node->operands();
-   c10::optional<at::Tensor> null_arg;
-   if (operands.size() == 5) {
-     arguments.emplace_back(null_arg);
-     arguments.emplace_back(null_arg);
-   }
-   for (size_t i = 3; i < operands.size(); ++i) {
-     arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
-   }
-   arguments.emplace_back(node->training());
-   arguments.emplace_back(node->eps());
-   arguments.emplace_back(node->output_mask());
-   return LowerBuiltin(node, arguments);
- }
-
  TorchMlirOpVector LowerCast(const torch::lazy::Cast* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-   arguments.emplace_back(node->dtype());
+   arguments.emplace_back(node->dtype);
    return LowerBuiltin(at::aten::to, node->shapes(), arguments);
  }

  TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-   arguments.emplace_back(node->size());
+   arguments.emplace_back(node->size);
    auto expand_out = LowerBuiltin(node, arguments);
-   if (node->is_scalar_expand()) {
+   if (node->is_scalar_expand) {
      // The aten::expand operations sets all strides to 0 when the original
      // of rank 0. This leads to false positives when checking for internal
      // memory overlap, because at::has_internal_overlap returns
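The comment truncated at the end of this hunk concerns rank-0 inputs to LowerExpand: expanding a scalar tensor produces a result whose strides are all zero, which at::has_internal_overlap can flag as overlapping memory. The snippet below is an assumed illustration of that ATen behavior, not code from this commit:

// Illustration only: expanding a rank-0 tensor yields zero strides.
#include <ATen/ATen.h>

void show_scalar_expand_strides() {
  at::Tensor scalar = at::scalar_tensor(1.0f);  // rank-0 tensor
  at::Tensor expanded = scalar.expand({4});     // shape {4}, strides {0}
  // expanded.strides() reports {0}; a zero stride is what can trigger the
  // internal-overlap false positive the comment above guards against.
}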
@ -366,8 +309,8 @@ public:
  TorchMlirOpVector LowerNarrow(const torch::lazy::Narrow* node) {
    const torch::lazy::Output& input = node->operand(0);
    torch::jit::Value* base = loctx()->GetOutputOp(input);
-   const auto& base_indices = node->base_indices();
-   const auto& sizes = node->sizes();
+   const auto& base_indices = node->base_indices;
+   const auto& sizes = node->sizes;
    const torch::lazy::Shape& input_shape = input.shape();
    CHECK_EQ(sizes.size(), base_indices.size());
    CHECK_EQ(input_shape.dim(), base_indices.size());
@ -383,12 +326,12 @@ public:
  TorchMlirOpVector LowerPermute(const torch::lazy::Permute* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-   arguments.push_back(node->dims());
+   arguments.push_back(node->dims);
    return LowerBuiltin(node, arguments);
  }

  TorchMlirOpVector LowerScalar(const torch::lazy::Scalar* node) {
-   const at::Scalar& value = node->value();
+   const at::Scalar& value = node->value;
    const torch::lazy::Shape& shape = node->shape();
    auto options =
        at::TensorOptions()
@ -399,20 +342,19 @@ public:
  }

  TorchMlirOpVector LowerSelect(const torch::lazy::Select* node) {
-   int64_t step = torch::lazy::Select::GetStride(
-       node->start(), node->end(), node->stride());
+   int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride);
    torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0));
    return {GenerateSlice(
-       /*base=*/base, /*dim=*/node->dim(),
-       /*start=*/node->start(), /*end=*/node->end(),
+       /*base=*/base, /*dim=*/node->dim,
+       /*start=*/node->start, /*end=*/node->end,
        /*step=*/step)};
  }

  TorchMlirOpVector LowerSqueeze(const torch::lazy::Squeeze* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-   if (node->dim() != -1) {
-     arguments.push_back(node->dim());
+   if (node->dim != -1) {
+     arguments.push_back(node->dim);
    }
    return LowerBuiltin(node, arguments);
  }
@ -421,11 +363,10 @@ public:
  LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) {
    torch::jit::Value* dest =
        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
-   int64_t step = torch::lazy::Select::GetStride(
-       node->start(), node->end(), node->stride());
+   int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride);
    torch::jit::Value* selected = GenerateSlice(
-       /*base=*/dest, /*dim=*/node->dim(), /*start=*/node->start(),
-       /*end=*/node->end(), /*step=*/step);
+       /*base=*/dest, /*dim=*/node->dim, /*start=*/node->start,
+       /*end=*/node->end, /*step=*/step);
    GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1)));
    return {dest};
  }
@ -434,7 +375,7 @@ public:
  LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) {
    torch::jit::Value* dest =
        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
-   const auto& base_indices = node->base_indices();
+   const auto& base_indices = node->base_indices;
    const torch::lazy::Output& source_argument = node->operand(1);
    const torch::lazy::Shape& source_shape = source_argument.shape();
    CHECK_EQ(source_shape.dim(), base_indices.size());
@ -453,14 +394,14 @@ public:
  TorchMlirOpVector LowerUnsqueeze(const torch::lazy::Unsqueeze* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-   arguments.push_back(node->dim());
+   arguments.push_back(node->dim);
    return LowerBuiltin(node, arguments);
  }

  TorchMlirOpVector LowerView(const torch::lazy::View* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-   arguments.push_back(node->output_size());
+   arguments.push_back(node->output_size);
    return LowerBuiltin(at::aten::reshape, node->shapes(), arguments);
  }
@ -0,0 +1,28 @@
//===- generic.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/generic.cpp
//===----------------------------------------------------------------------===//

#include "generic.h"

namespace torch {
namespace lazy {

Generic::Generic(
    OpKind op,
    OpList operands,
    Shape shape,
    size_t num_outputs,
    hash_t hash_seed)
    : TorchMlirNode(op, operands, {std::move(shape)}, num_outputs, hash_seed),
      hash_seed_(hash_seed) {}

} // namespace lazy
} // namespace torch
@ -0,0 +1,39 @@
//===- generic.h ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/generic.h
//===----------------------------------------------------------------------===//

#pragma once

#include "../mlir_node.h"

namespace torch {
namespace lazy {

// Generic IR Node implementation for nodes which can simply be described by a
// specific OpKind and a lowering function. IR nodes carrying metadata should
// not use this class (and have the metadata captured by the LowerFn); they
// should instead create a dedicated IR node. Doing the former would limit IR
// introspection.
class TORCH_API Generic : public TorchMlirNode {
public:
  Generic(
      OpKind op,
      OpList operands,
      Shape shape,
      size_t num_outputs = 1,
      hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9));

private:
  hash_t hash_seed_;
};

} // namespace lazy
} // namespace torch
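Since generic.h only declares the node, a brief usage sketch may help. The helper below is hypothetical (the function name, the relu OpKind, and the reliance on the torch::lazy::MakeNode helper are assumptions, not code from this commit); it shows how a backend could wrap an op that needs no per-op metadata beyond its OpKind and operands:

// Hypothetical usage sketch: build a Generic node for an op with no extra
// metadata, identified only by its OpKind, operands, and output shape.
torch::lazy::NodePtr MakeGenericRelu(
    const torch::lazy::Value& input, const torch::lazy::Shape& shape) {
  return torch::lazy::MakeNode<torch::lazy::Generic>(
      torch::lazy::OpKind(at::aten::relu),
      torch::lazy::OpList{input},
      shape,
      /*num_outputs=*/1);
}

Ops that do carry attributes (sizes, strides, dtypes, and so on) should get dedicated node classes instead, as the header comment notes, so that the attributes stay visible to IR introspection rather than being buried in a lowering closure.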