Fix LTC Decoupling (#815)

* Initial changes

* Fix up native functions

* Further fix decoupling

* Remove unnecessary ops

* Formatting and copyright banners:

* Add pytorch submodule
Jae Hoon (Antonio) Kim 2022-05-03 09:35:44 -04:00 committed by Henry Tu
parent cca9fe126e
commit 1bde00c73d
25 changed files with 729 additions and 217 deletions

.gitignore

@@ -23,15 +23,10 @@ __pycache__
# Bazel
bazel-*
# Libraries
*.so
*.a
# Autogenerated files
/generated_native_functions.yaml
/generated_backend.hash
/python/torch_mlir/csrc/base_lazy_backend/LazyIr.h
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.cpp
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h
/python/torch_mlir/csrc/base_lazy_backend/GenLazyShapeInference.cpp
/python/torch_mlir/csrc/base_lazy_backend/RegisterLazy.cpp
/python/torch_mlir/csrc/base_lazy_backend/generated
# Example backend
examples/ltc_backend/ltc_backend/_EXAMPLE_MLIR_BACKEND.cpython-37m-x86_64-linux-gnu.so


@@ -12,20 +12,20 @@ from textwrap import dedent
import yaml
TORCH_MLIR_DIR = Path(__file__).parent.parent.resolve()
TORCH_DIR = TORCH_MLIR_DIR.parent.joinpath("pytorch")
TORCH_DIR = TORCH_MLIR_DIR.joinpath("externals", "pytorch")
sys.path.append(str(TORCH_DIR.joinpath("tools")))
sys.path.append(str(TORCH_DIR))
# PyTorch's LTC backend autogen script
import codegen.dest.lazy_ir
import codegen.gen_lazy_tensor
from codegen.api.lazy import LazyIrSchema
from codegen.gen import get_grouped_native_functions, parse_native_yaml
from codegen.model import NativeFunctionsGroup
import torchgen.dest.lazy_ir
import torchgen.gen_lazy_tensor
from torchgen.api.lazy import LazyIrSchema
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
from torchgen.model import NativeFunctionsGroup
def isOptionalCType(arg):
return str(type(arg)) == "<class 'tools.codegen.api.types.OptionalCType'>"
return str(type(arg)) == "<class 'torchgen.api.types.OptionalCType'>"
def generate_native_functions(
@@ -33,11 +33,11 @@ def generate_native_functions(
):
print("Generating Native Functions Yaml")
native_yaml_path = TORCH_DIR.joinpath(
"aten", "src", "ATen", "native", "native_functions.yaml"
)
native_path = TORCH_DIR.joinpath("aten", "src", "ATen", "native")
native_yaml_path = native_path.joinpath("native_functions.yaml")
tags_yaml_path = native_path.joinpath("tags.yaml")
parsed_yaml = parse_native_yaml(native_yaml_path)
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
native_functions = parsed_yaml.native_functions
grouped_native_functions = get_grouped_native_functions(native_functions)
@@ -57,6 +57,9 @@ def generate_native_functions(
# primarily view ops
supported = config.get("supported", [])
# List of non-native ops to do IR codegen for
non_native = config.get("non_native", [])
if which("rg") is not None: # use ripgrep if available as its much faster
cmd = ["rg", "-o", "-N", r"aten::[0-9a-zA-Z_\.]+"]
else:
@@ -105,6 +108,7 @@ def generate_native_functions(
"cpp_namespace": "torch::lazy",
"full_codegen": opnames,
"supported": sorted(supported_ops),
"non_native": non_native,
},
f,
default_flow_style=False,
@@ -123,13 +127,13 @@ def generate_native_functions(
@dataclass(frozen=True)
class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
class GenMlirLazyIr(torchgen.dest.GenLazyIR):
def lowering_function(self, f):
func = (
f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
)
schema = LazyIrSchema(func)
def lowering_function(self, schema, declaration_only=True):
signature = "TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override"
if declaration_only:
return f"{signature};"
emplace_arguments = []
for arg in schema.positional_args:
@@ -149,7 +153,7 @@ class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
[f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars])
return f"""
TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override {{
{signature} {{
PRINT_FUNCTION();
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
@@ -159,7 +163,7 @@ class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
return {schema.aten_name}_out;
}}
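Note (illustrative, not part of this commit): for a hypothetical single-tensor op such as aten::relu, the Lower method emitted by the template above would look roughly like the sketch below. The argument-emplacement line is reconstructed from the surrounding codegen logic and is an assumption, not copied generated output.

    TorchMlirOpVector Lower(TorchMlirFunction function,
                            TorchMlirLoweringContext* loctx) const override {
      PRINT_FUNCTION();
      std::vector<torch::jit::NamedValue> arguments;
      std::vector<torch::jit::NamedValue> kwarguments;
      // Positional tensor operand (emplacement form assumed).
      arguments.emplace_back(loctx->GetOutputOp(operand(0)));
      torch::lazy::TorchMlirOpVector relu_out = torch::lazy::LowerTorchMlirBuiltin(
          function, op().op, shapes(), arguments, kwarguments);
      CHECK_EQ(relu_out.size(), 1);
      return relu_out;
    }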
@@ -178,13 +182,13 @@ def generate_backend(
def gen_fallback_code(*args, **kwargs):
return ""
codegen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
torchgen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
codegen.gen_lazy_tensor.run_gen_lazy_tensor(
torchgen.gen_lazy_tensor.run_gen_lazy_tensor(
backend_name="TorchMlir",
aten_path=str(TORCH_DIR.joinpath("aten", "src", "ATen")),
source_yaml=str(source_yaml),
output_dir=str(backend_path),
output_dir=str(backend_path.joinpath("generated")),
dry_run=False,
impl_path=str(backend_path.joinpath("mlir_native_functions.cpp")),
node_base="torch::lazy::TorchMlirNode",
@@ -192,7 +196,7 @@ def generate_backend(
tensor_class="torch::lazy::LazyTensor",
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
shape_inference_hdr=str(backend_path.joinpath("LazyShapeInference.h")),
lazy_ir_cls=MlirLazyIr,
lazy_ir_generator=GenMlirLazyIr,
)
# Remove lazy_tensor_core imports
@@ -201,7 +205,7 @@ def generate_backend(
"sed",
"-i",
"/lazy_tensor_core/d",
str(backend_path.joinpath("LazyNativeFunctions.cpp")),
str(backend_path.joinpath("generated", "LazyNativeFunctions.cpp")),
]
)
@@ -240,14 +244,14 @@ def generate_backend(
- shape_inference_defs
)
if missing_defs:
backend_path.joinpath("GenLazyShapeInference.cpp").write_text(
backend_path.joinpath("generated", "GenLazyShapeInference.cpp").write_text(
dedent(
"""
// This file contains autogenerated Lazy Shape Inference placeholders
// for ops that don't have a corresponding structured kernel or shape definition
#include "LazyShapeInference.h"
#include "../utils/exception.h"
#include "../LazyShapeInference.h"
#include "../../utils/exception.h"
namespace torch {{
namespace lazy {{
{}


@@ -38,8 +38,7 @@ supported:
- empty
- expand
- fill_
- native_batch_norm
# - native_batch_norm_backward
- native_batch_norm_backward
- permute
- squeeze
- t
@@ -50,4 +49,39 @@ additional_ops:
# Additional ops to support that are not supported by Torch-MLIR explicitly
- _copy_from
- _copy_from_and_resize
- native_batch_norm_backward
# - native_batch_norm_backward
# List of non native ops that we only want to do IR node class generation for
non_native:
- func: device_data(std::shared_ptr<BackendData> data) -> Tensor
opkind: ltc_device_data
cache_shape: false
- func: scalar(at::Scalar value, at::ScalarType type) -> Tensor
opkind: at::prim::Constant
cache_shape: false
- func: expand(Tensor input, std::vector<int64_t> size, bool is_scalar_expand) -> Tensor
- func: view(Tensor input, std::vector<int64_t> output_size) -> Tensor
cache_shape: false
- func: cast(Tensor input, at::ScalarType dtype, optional<at::ScalarType> stype) -> Tensor
opkind: ltc_cast
cache_shape: false
# View ops only required until proper functionalization pass is introduced into LTC
- func: as_strided_view_update(Tensor target, Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
opkind: ltc_as_strided_view_update
- func: as_strided(Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
- func: diagonal_view_update(Tensor target, Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
opkind: ltc_diagonal_view_update
cache_shape: false
- func: diagonal(Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
- func: narrow_view_update(Tensor input, Tensor source, std::vector<int64_t> base_indices) -> Tensor
opkind: ltc_narrow_view_update
- func: narrow(Tensor input, std::vector<int64_t> base_indices, std::vector<int64_t> sizes) -> Tensor
- func: permute(Tensor input, std::vector<int64_t> dims) -> Tensor
- func: resize(Tensor input, std::vector<int64_t> size) -> Tensor
- func: select_view_update(Tensor target, Tensor source, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
opkind: ltc_select_view_update
cache_shape: false
- func: select(Tensor input, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
- func: squeeze(Tensor input, int dim) -> Tensor
- func: unsqueeze(Tensor input, int dim) -> Tensor


@@ -12,7 +12,7 @@
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
#include <torch_mlir/csrc/utils/debug.h>

externals/pytorch

@@ -0,0 +1 @@
Subproject commit 9f3d6a00a76c567d7c046eabc60ae7a578f7bbde


@@ -21,14 +21,14 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(torch_mlir_ltc_backend SHARED
base_lazy_backend/backend_impl.cpp
base_lazy_backend/LazyNativeFunctions.cpp
base_lazy_backend/generated/LazyNativeFunctions.cpp
base_lazy_backend/generated/GenLazyShapeInference.cpp
base_lazy_backend/generated/RegisterLazy.cpp
base_lazy_backend/LazyShapeInference.cpp
base_lazy_backend/GenLazyShapeInference.cpp
base_lazy_backend/mlir_lowering_context.cpp
base_lazy_backend/mlir_native_functions.cpp
base_lazy_backend/mlir_node.cpp
base_lazy_backend/mlir_node_lowering.cpp
base_lazy_backend/RegisterLazy.cpp
)
add_dependencies(torch_mlir_ltc_backend


@@ -13,5 +13,25 @@
namespace torch {
namespace lazy {
std::vector<Shape> compute_shape_native_batch_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var, bool training,
double momentum, double eps) {
std::vector<Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
if (running_mean.has_value()) {
shapes.emplace_back(
running_mean.value().scalar_type(), running_mean.value().sizes().vec());
if (running_var.has_value()) {
shapes.emplace_back(
running_var.value().scalar_type(), running_var.value().sizes().vec());
}
}
return shapes;
}
} // namespace lazy
} // namespace torch
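Note (illustrative, not part of this commit): a quick way to read the shape function above is with concrete sizes; for a float input of shape [2, 3, 4, 4] and defined running statistics of shape [3], it returns three shapes. Headers omitted; values are made up for the sketch.

    // Sketch: exercising compute_shape_native_batch_norm with made-up sizes.
    at::Tensor input = at::rand({2, 3, 4, 4});
    at::Tensor running_mean = at::zeros({3});
    at::Tensor running_var = at::ones({3});
    std::vector<torch::lazy::Shape> shapes =
        torch::lazy::compute_shape_native_batch_norm(
            input, /*weight=*/c10::nullopt, /*bias=*/c10::nullopt,
            running_mean, running_var, /*training=*/true,
            /*momentum=*/0.1, /*eps=*/1e-5);
    // shapes[0]: Float [2, 3, 4, 4]; shapes[1]: Float [3]; shapes[2]: Float [3]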


@@ -23,6 +23,7 @@ namespace lazy {
// clang-format off
TORCH_API std::vector<Shape> compute_shape___and__(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape__reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride);
TORCH_API std::vector<Shape> compute_shape__shape_as_tensor(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape__unsafe_view(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_abs(const at::Tensor & self);
@@ -37,6 +38,8 @@ TORCH_API std::vector<Shape> compute_shape_bincount(const at::Tensor & self, con
TORCH_API std::vector<Shape> compute_shape_broadcast_to(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
TORCH_API std::vector<Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_div_(at::Tensor & self, const at::Scalar & other);
@@ -63,6 +66,7 @@ TORCH_API std::vector<Shape> compute_shape_mean(const at::Tensor & self, c10::op
TORCH_API std::vector<Shape> compute_shape_mul(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_mul_(at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
TORCH_API std::vector<Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_new_ones(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_new_zeros(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_rand_like(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
@@ -82,6 +86,9 @@ TORCH_API std::vector<Shape> compute_shape_sum(const at::Tensor & self, c10::opt
TORCH_API std::vector<Shape> compute_shape_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1);
TORCH_API std::vector<Shape> compute_shape_type_as(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_var(const at::Tensor & self, bool unbiased);
TORCH_API std::vector<Shape> compute_shape_zero_(at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
// clang-format on


@@ -18,6 +18,7 @@
#include "../utils/debug.h"
#include "../utils/exception.h"
#include "backend_impl.h"
#include "ir_builder.h"
#include "mlir_lowering_context.h"
namespace torch {
@@ -72,6 +73,15 @@ TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const {
* */
void TorchMlirBackendImpl::PrepareToExit() const {}
/**
* IR Tracing
* */
const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const {
static const IrBuilder* builder = new TorchMlirIrBuilder();
return builder;
}
/**
* Data Transfer
* */
@@ -95,6 +105,16 @@ BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder(
return std::make_shared<TorchMlirBackendData>(device, shape);
}
BackendDataPtr
TorchMlirBackendImpl::GetComputationDataFromNode(Node* node) const {
PRINT_FUNCTION();
auto* device_data_node = dynamic_cast<DeviceData*>(node);
if (!device_data_node) {
return nullptr;
}
return device_data_node->data;
}
at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const {


@@ -65,6 +65,12 @@ public:
* */
virtual void PrepareToExit() const override;
/**
* IR Tracing
* */
const IrBuilder* GetIrBuilder() const override;
/**
* Configuration
* */
@@ -84,6 +90,10 @@ public:
virtual BackendDataPtr CreateDataPlaceholder(
const BackendDevice& device, const Shape& shape) const override;
// Gets backend data if the node is a device data node. Otherwise returns
// nullptr.
virtual BackendDataPtr GetComputationDataFromNode(Node*) const override;
virtual at::Tensor MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const override;


@@ -0,0 +1,74 @@
//===- dynamic_ir.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/dynamic_ir.cpp
//===----------------------------------------------------------------------===//
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
namespace torch {
namespace lazy {
DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed)
: TorchMlirNode(
op, operands, /*num_outputs=*/1,
/* hash_seed */ HashCombine(op.hash(), hash_seed)) {}
std::string DimensionNode::ToString() const { return "DimensionNode"; }
SizeNode::SizeNode(Value input, size_t dim)
: DimensionNode(
OpKind{c10::Symbol::fromQualString("aten::size")}, {input},
MHash(dim)),
dim_(dim){};
int64_t SizeNode::getStaticValue() const {
return dynamic_cast<const TorchMlirNode*>(operand(0).node)
->shape(0)
.size(dim_);
}
std::string SizeNode::ToString() const { return "SizeNode"; }
SizeAdd::SizeAdd(Value a, Value b)
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){};
int64_t SizeAdd::getStaticValue() const {
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() +
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue();
}
std::string SizeAdd::ToString() const { return "SizeAdd"; }
SizeMul::SizeMul(Value a, Value b)
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){};
int64_t SizeMul::getStaticValue() const {
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() *
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue();
}
std::string SizeMul::ToString() const { return "SizeMul"; }
SizeDiv::SizeDiv(Value a, Value b)
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}){};
int64_t SizeDiv::getStaticValue() const {
TORCH_CHECK(
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue() !=
0,
"Can't divide a dimension by zero");
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() /
dynamic_cast<const DimensionNode*>(operand(1).node)->getStaticValue();
}
std::string SizeDiv::ToString() const { return "SizeDiv"; }
} // namespace lazy
} // namespace torch


@@ -0,0 +1,99 @@
//===- dynamic_ir.h -------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/dynamic_ir.h
//===----------------------------------------------------------------------===//
#pragma once
#include <ATen/core/symbol.h>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "mlir_node.h"
#include <c10/core/ScalarType.h>
#include <c10/util/Flags.h>
#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
C10_DECLARE_bool(ltc_enable_dynamic_shapes);
namespace torch {
namespace lazy {
/**
* The goal of "dynamic" Nodes is to patch a hole in our tracing.
* Previously, if a user called `sizes` on a Tensor, it would leak out
* of our tracing system, as `sizes` returns a torch.Size or an int. To
* prevent this from happening, we introduce DimensionNode, a new type
* of Node that abstracts the operation of getting the dimensions of a
* Tensor.
*
* Consider the following example:
* ```
* numel = x.shape()[0] * x.shape()[1]
* ```
*
* Here, `x.shape()[i]` will be a SizeNode (subclass of DimensionNode),
* and the multiplication of the two SizeNodes will be represented by
* a SizeMul (also a subclass of DimensionNode). Through this, we can
* prevent `numel` from being represented as a Python int and thus
* burned into the Graph.
*/
class TORCH_API DimensionNode : public lazy::TorchMlirNode {
public:
DimensionNode(OpKind op, OpList operands, hash_t hash_seed = kHashSeed);
bool isDynamic() { return false; }
std::string ToString() const override;
virtual int64_t getStaticValue() const = 0;
};
// Represents the result of calling `size` on a Tensor
class TORCH_API SizeNode : public DimensionNode {
public:
SizeNode(Value input, size_t dim);
int64_t getStaticValue() const override;
std::string ToString() const override;
size_t dim_ = 0;
};
class TORCH_API SizeAdd : public DimensionNode {
public:
SizeAdd(Value a, Value b);
int64_t getStaticValue() const override;
std::string ToString() const override;
};
class TORCH_API SizeMul : public DimensionNode {
public:
SizeMul(Value a, Value b);
int64_t getStaticValue() const override;
std::string ToString() const override;
};
class TORCH_API SizeDiv : public DimensionNode {
public:
SizeDiv(Value a, Value b);
int64_t getStaticValue() const override;
std::string ToString() const override;
};
} // namespace lazy
} // namespace torch
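Note (illustrative, not part of this commit): a minimal sketch of how these dimension nodes compose, mirroring the `numel = x.shape()[0] * x.shape()[1]` example in the comment above; MakeNode comes from torch/csrc/lazy/core/ir_builder.h.

    #include <torch/csrc/lazy/core/ir_builder.h>
    #include "dynamic_ir.h"

    // Build a DimensionNode graph for numel = size(x, 0) * size(x, 1).
    torch::lazy::NodePtr MakeStaticNumel2d(const torch::lazy::Value& x) {
      using namespace torch::lazy;
      NodePtr d0 = MakeNode<SizeNode>(x, /*dim=*/0);
      NodePtr d1 = MakeNode<SizeNode>(x, /*dim=*/1);
      // SizeMul::getStaticValue() multiplies the two static dimension sizes.
      return MakeNode<SizeMul>(Value(d0), Value(d1));
    }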


@@ -0,0 +1,65 @@
//===- ir_builder.h -------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ir_builder.h
//===----------------------------------------------------------------------===//
#pragma once
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include "dynamic_ir.h"
#include "generated/LazyNonNativeIr.h"
#include "mlir_node.h"
#include "ops/generic.h"
// This file contains the TorchMlir IrBuilder
namespace torch {
namespace lazy {
// clang-format off
struct TorchMlirIrBuilder : IrBuilder {
NodePtr MakeDeviceData(const std::shared_ptr<BackendData>& data) const override { return MakeNode<DeviceData>(data); }
NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) const override { return MakeNode<Scalar>(value, type); }
NodePtr MakeExpand(const Value& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand) const override { return MakeNode<Expand>(input0, size, is_scalar_expand); }
NodePtr MakeView(const Value& input0, const std::vector<int64_t>& output_size) const override { return MakeNode<View>(input0, output_size); }
NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype = c10::nullopt) const override { return MakeNode<Cast>(input0, dtype, stype); }
NodePtr MakeTensorList(const OpList& inputs) const override { return MakeNode<TensorList>(inputs); }
NodePtr MakeGeneric(const OpKind& op, const OpList& operands, const Shape& shape, const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) const override { return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed); }
// view ops
NodePtr MakeAsStridedViewUpdate(const Value& input0, const Value& input1, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset) const override { return MakeNode<AsStridedViewUpdate>(input0, input1, size, stride, storage_offset); }
NodePtr MakeAsStrided(const Value& input0, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset) const override { return MakeNode<AsStrided>(input0, size, stride, storage_offset); }
NodePtr MakeDiagonalViewUpdate(const Value& input0, const Value& input1, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) const override { return MakeNode<DiagonalViewUpdate>(input0, input1, offset, dim1, dim2); }
NodePtr MakeDiagonal(const Value& input0, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) const override { return MakeNode<Diagonal>(input0, offset, dim1, dim2); }
NodePtr MakeNarrowViewUpdate(const Value& input0, const Value& input1, const std::vector<int64_t>& base_indices) const override { return MakeNode<NarrowViewUpdate>(input0, input1, base_indices); }
NodePtr MakeNarrow(const Value& input0, const std::vector<int64_t>& base_indices, const std::vector<int64_t>& sizes) const override { return MakeNode<Narrow>(input0, base_indices, sizes); }
NodePtr MakePermute(const Value& input0, const std::vector<int64_t>& dims) const override { return MakeNode<Permute>(input0, dims); }
NodePtr MakeResize(const Value& input0, const std::vector<int64_t>& size) const override { return MakeNode<Resize>(input0, size); }
NodePtr MakeSelectViewUpdate(const Value& input0, const Value& input1, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) const override { return MakeNode<SelectViewUpdate>(input0, input1, dim, start, end, stride); }
NodePtr MakeSelect(const Value& input0, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) const override { return MakeNode<Select>(input0, dim, start, end, stride); }
NodePtr MakeSqueeze(const Value& input0, const int& dim) const override { return MakeNode<Squeeze>(input0, dim); }
NodePtr MakeUnsqueeze(const Value& input0, const int& dim) const override { return MakeNode<Unsqueeze>(input0, dim); }
// dynamic ir nodes
NodePtr MakeSizeNode(const Value& input, size_t dim) const override { return MakeNode<SizeNode>(input, dim); }
NodePtr MakeSizeAdd(const Value& a, const Value& b) const override { return MakeNode<SizeAdd>(a, b); }
NodePtr MakeSizeMul(const Value& a, const Value& b) const override { return MakeNode<SizeMul>(a, b); }
NodePtr MakeSizeDiv(const Value& a, const Value& b) const override { return MakeNode<SizeDiv>(a, b); }
};
// clang-format on
} // namespace lazy
} // namespace torch
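Note (illustrative, not part of this commit): because TorchMlirBackendImpl::GetIrBuilder() returns a TorchMlirIrBuilder, the builder methods above are what torch::lazy core invokes when it needs backend nodes; the same calls can also be made directly, as in this sketch.

    // Sketch: constructing nodes through the builder interface.
    torch::lazy::TorchMlirIrBuilder builder;
    torch::lazy::NodePtr scalar =
        builder.MakeScalar(at::Scalar(1.0), at::ScalarType::Float);
    torch::lazy::NodePtr expanded = builder.MakeExpand(
        torch::lazy::Value(scalar), /*size=*/{2, 3}, /*is_scalar_expand=*/true);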


@@ -148,12 +148,13 @@ void TorchMlirLoweringContext::AssignOutputOp(
const Output& output, torch::jit::Value* op) {
PRINT_FUNCTION();
auto torch_mlir_node =
NodeCast<TorchMlirNode>(output.node, output.node->op());
if (!torch_mlir_node->getPythonStacktrace().empty()) {
op->node()->s_(
c10::Symbol::attr("source"), torch_mlir_node->getPythonStacktrace());
}
// TODO (antoniojkim): Do we need this?
// auto torch_mlir_node =
// NodeCast<TorchMlirNode>(output.node, output.node->op());
// if (!torch_mlir_node->getPythonStacktrace().empty()) {
// op->node()->s_(
// c10::Symbol::attr("source"), torch_mlir_node->getPythonStacktrace());
// }
emitted_outputs_[output] = std::move(op);
}
@@ -301,7 +302,7 @@ unsigned TorchMlirComputation::num_results() const { return num_results_; }
MlirOperation TorchMlirComputation::func_op() const { return func_op_; }
std::string TorchMlirComputation::to_string() const {
const std::string TorchMlirComputation::to_string() const {
// Since we use the C-MLIR API, we need to use a callback to print.
MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) {
// user_data is a void ptr to some data structure of our choice -- in this


@@ -149,7 +149,7 @@ public:
MlirOperation func_op() const;
std::string to_string() const;
const std::string to_string() const;
private:
std::vector<std::string> parameter_names_;


@@ -10,16 +10,18 @@
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
//===----------------------------------------------------------------------===//
#include <ATen/InferSize.h>
#include <ATen/Operators.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/result_type.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor_aten_ops.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
#include <torch/library.h>
#include "ATen/MetaFunctions.h"
@@ -27,8 +29,8 @@
#include "../utils/exception.h"
#include "../utils/sys_utils.h"
#include "LazyNativeFunctions.h"
#include "LazyShapeInference.h"
#include "generated/LazyNativeFunctions.h"
namespace torch {
namespace lazy {
@@ -110,6 +112,60 @@ GetLtcDevice(const c10::optional<c10::Device>& device) {
return torch::lazy::atenDeviceToBackendDevice(*device);
}
torch::lazy::Value MaybeExpand(
const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) {
if (input.shape().sizes() == target_shape.sizes()) {
return input;
}
return torch::lazy::MakeExpand(
input, target_shape.sizes().vec(),
/*is_scalar_expand=*/false);
}
std::vector<int64_t> GetExpandDimensions(
const torch::lazy::Shape& shape, std::vector<int64_t> dimensions) {
CHECK_GE(dimensions.size(), shape.dim()) << shape;
int64_t base = dimensions.size() - shape.dim();
for (size_t i = 0; i < shape.dim(); ++i) {
if (dimensions[base + i] == -1) {
dimensions[base + i] = shape.size(i);
}
}
return dimensions;
}
void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
if (input->GetDevice() == src->GetDevice()) {
torch::lazy::Value copy_value;
if (input->dtype() == src->dtype()) {
copy_value = src->GetIrValue();
} else {
copy_value = torch::lazy::MakeCast(
src->GetIrValue(), input->dtype(), src->dtype());
}
input->SetIrValue(MaybeExpand(copy_value, input->shape()));
} else {
auto input_shape = input->shape();
at::Tensor src_tensor = src->ToTensor(/*detached=*/true);
if (src_tensor.sizes() != input_shape.Get().sizes()) {
src_tensor = src_tensor.expand(input_shape.Get().sizes().vec());
}
input->UpdateFromTensor(std::move(src_tensor), /*sync=*/false);
}
}
torch::lazy::LazyTensorPtr create_view(
const torch::lazy::LazyTensorPtr& input,
c10::ArrayRef<int64_t> output_size) {
auto input_shape = input->shape().Get();
torch::lazy::Shape shape = torch::lazy::Shape(
input_shape.scalar_type(),
at::infer_size(output_size, input_shape.numel()));
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape);
return input->CreateViewTensor(std::move(view_info));
}
} // namespace
// at::Tensor LazyNativeFunctions::bernoulli(
@@ -209,7 +265,7 @@ at::Tensor LazyNativeFunctions::_copy_from(
dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false));
}
} else {
torch::lazy::copy_(dst_tensor, self_tensor);
copy_(dst_tensor, self_tensor);
auto* impl =
dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
impl->set_tensor(dst_tensor);
@@ -260,115 +316,106 @@ at::Tensor LazyNativeFunctions::empty(
at::Tensor LazyNativeFunctions::expand(
const at::Tensor& self, at::IntArrayRef size, bool implicit) {
TORCH_LAZY_FN_COUNTER("lazy::");
UNIMPLEMENTED_FUNCTION_ERROR();
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::expand(torch::lazy::TryGetLtcTensor(self), size.vec()));
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
auto input_shape = self_tensor->shape();
auto output = torch::lazy::LazyTensor::Create(
torch::lazy::MakeExpand(
self_tensor->GetIrValue(),
GetExpandDimensions(input_shape.Get(), std::move(size.vec())),
/*is_scalar_expand=*/false),
self_tensor->GetDevice());
output->SetStorage(self_tensor->Storage());
return torch::lazy::CreateAtenFromLtcTensor(output);
}
at::Tensor&
LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
torch::lazy::fill_(self_tensor, value);
torch::lazy::Value constant =
torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(
value, self_tensor->shape(), self_tensor->GetDevice());
self_tensor->SetInPlaceIrValue(std::move(constant));
return self;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
LazyNativeFunctions::native_batch_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var, bool training,
double momentum, double eps) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto input_tensor = torch::lazy::TryGetLtcTensor(input);
const torch::lazy::BackendDevice& device = input_tensor->GetDevice();
auto running_mean_tensor = GetOrCreateLtcTensor(running_mean, device);
auto running_var_tensor = GetOrCreateLtcTensor(running_var, device);
auto outputs = torch::lazy::native_batch_norm(
torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device),
GetOrCreateLtcTensor(bias, device), running_mean_tensor,
running_var_tensor, training, momentum, eps);
return std::make_tuple(
torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)),
torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)),
torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs)));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
LazyNativeFunctions::native_batch_norm_backward(
const at::Tensor& grad_out, const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var,
const c10::optional<at::Tensor>& save_mean,
const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
std::array<bool, 3> output_mask) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto grad_out_tensor = torch::lazy::TryGetLtcTensor(grad_out);
const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice();
torch::lazy::LazyTensorPtr null_tensor;
bool running_stats = running_mean && running_mean->defined();
CHECK_EQ(running_var && running_var->defined(), running_stats);
auto gradients = torch::lazy::native_batch_norm_backward(
torch::lazy::TryGetLtcTensor(grad_out),
torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device),
running_stats ? GetOrCreateLtcTensor(running_mean, device) : null_tensor,
running_stats ? GetOrCreateLtcTensor(running_var, device) : null_tensor,
GetOrCreateLtcTensor(save_mean, device),
GetOrCreateLtcTensor(save_invstd, device), train, eps, output_mask);
at::Tensor undefined;
return std::make_tuple(
output_mask[0]
? torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients))
: undefined,
output_mask[1]
? torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients))
: undefined,
output_mask[2]
? torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients))
: undefined);
}
at::Tensor
LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
UNIMPLEMENTED_FUNCTION_ERROR();
auto input_shape = self_tensor->shape();
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kPermute, input_shape,
torch::lazy::GetCanonicalDimensionIndices(
torch::lazy::ToI64Vector(dims), input_shape.Get().dim()));
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::permute(self_tensor, torch::lazy::ToI64Vector(dims)));
self_tensor->CreateViewTensor(std::move(view_info)));
}
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self)));
return squeeze(self, -1);
}
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
auto input_shape = self_tensor->shape();
int64_t squeeze_dim = -1;
if (dim != -1) {
squeeze_dim =
torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
}
auto output_dimensions =
BuildSqueezedDimensions(input_shape.Get().sizes(), squeeze_dim);
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self), dim));
create_view(self_tensor, output_dimensions));
}
at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
auto input_shape = self_tensor->shape();
auto permute_dims = torch::lazy::MakeTransposePermutation(
/*dim0=*/0, /*dim1=*/1, /*rank=*/input_shape.Get().dim());
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims);
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), 0, 1));
self_tensor->CreateViewTensor(std::move(view_info)));
}
at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
auto input_shape = self_tensor->shape();
int64_t squeeze_dim =
torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim() + 1);
auto dimensions =
BuildUnsqueezedDimensions(input_shape.Get().sizes(), squeeze_dim);
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::unsqueeze(torch::lazy::TryGetLtcTensor(self), dim));
create_view(self_tensor, dimensions));
}
at::Tensor
LazyNativeFunctions::view(const at::Tensor& self, at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
auto input_shape = self_tensor->shape().Get();
torch::lazy::Shape shape = torch::lazy::Shape(
input_shape.scalar_type(),
at::infer_size(torch::lazy::ToI64Vector(size), input_shape.numel()));
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape);
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size)));
self_tensor->CreateViewTensor(std::move(view_info)));
}
void InitializeAtenBindings() {}


@@ -10,18 +10,92 @@
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_node.cpp
//===----------------------------------------------------------------------===//
#include <torch/csrc/lazy/core/cache.h>
#include "../utils/exception.h"
#include "mlir_node.h"
#include "../utils/exception.h"
namespace torch {
namespace lazy {
namespace {
hash_t OperandHashes(
const OpList& operands, const c10::ArrayRef<Shape>& shapes,
const hash_t& seed, bool bakeInSizes) {
hash_t hash = seed;
for (auto& operand : operands) {
if (!operand) {
hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
continue;
}
auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash();
hash = HashCombine(hash, operand_hash);
}
for (auto& shape : shapes) {
hash = HashCombine(hash, shape.hash(bakeInSizes));
}
return hash;
}
} // namespace
TorchMlirNode::TorchMlirNode(
OpKind op, OpList operands, std::vector<Shape>&& shapes, size_t num_outputs,
hash_t hash_seed)
: Node(op, operands, std::move(shapes), num_outputs) {
hash_seed = HashCombine(op.hash(), hash_seed);
shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true);
dag_hash_ =
(enableDynamicShape()
? OperandHashes(operands, this->shapes(), hash_seed, false)
: shape_hash_);
}
TorchMlirNode::TorchMlirNode(
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
size_t num_outputs, hash_t hash_seed)
: TorchMlirNode(
op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
addComputedShape(shape_fn);
}
TorchMlirNode::TorchMlirNode(
OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
: TorchMlirNode(
op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
TorchMlirNode::TorchMlirNode(
OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
: TorchMlirNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}
hash_t TorchMlirNode::hash() const { return dag_hash_; }
hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
TorchMlirOpVector TorchMlirNode::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
return {};
}
TensorList::TensorList(OpList values)
: TorchMlirNode(
/*op=*/tensor_list_opkind,
/*operands=*/values,
/*shapes=*/std::vector<Shape>(),
/*num_outputs=*/1,
/*hash_seed=*/kHashSeed) {}
torch::lazy::TorchMlirOpVector TensorList::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::Value*> tensor_list;
CHECK(!operands().empty());
for (const torch::lazy::Output& operand : operands()) {
tensor_list.emplace_back(loctx->GetOutputOp(operand));
}
auto graph = function->graph();
auto listnode =
graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list));
return {listnode->output()};
}
} // namespace lazy
} // namespace torch


@@ -27,10 +27,63 @@ namespace lazy {
class TORCH_API TorchMlirNode : public torch::lazy::Node {
public:
using torch::lazy::Node::Node;
TorchMlirNode(
OpKind op, OpList operands, std::vector<Shape>&& shapes,
size_t num_outputs, hash_t hash_seed = kHashSeed);
TorchMlirNode(
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
size_t num_outputs, hash_t hash_seed = kHashSeed);
TorchMlirNode(
OpKind op, OpList operands, size_t num_outputs,
hash_t hash_seed = kHashSeed);
TorchMlirNode(
OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed);
hash_t hash() const override;
hash_t shapeHash() const override;
virtual TorchMlirOpVector
Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const;
private:
// The hash of the dag WITH size info. Used for shape caching
hash_t shape_hash_;
// The hash of the dag used to look up the compiled graph by a hash
// in this case, we will use the dag hash WITHOUT size info if dynamic shape
// is enabled and use the dag hash WITH size info otherwise.
hash_t dag_hash_;
};
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
// import otherwise
const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
// TensorList represents an at::TensorList which is a vector[Tensor] but is also
// a first-class IValue and can be fed as a single input to a TS program. It is
// much easier to handle TensorLists in Lazy Tensor code if they are represented
// as a single Node so there can be more than one TensorList and more than one
// Tensor side-by-side as operands to an op.
//
// Note: shape is undefined for TensorList. We assert in some places that
// #shapes matches #outputs and this stems from
// the fact that currently all IR nodes represent tensors (there is no
// type system for this IR). Because of this, TensorList is a bit of a
// hack.
//
// TODO(whc) once Shape() API is moved to Node base, also make it virtual, and
// then implement it as NotImplemented for TensorList, also fixing the assertion
// that would fail.
struct TORCH_API TensorList : public TorchMlirNode {
TensorList() = delete;
TensorList(OpList values);
torch::lazy::TorchMlirOpVector Lower(
TorchMlirFunction function,
TorchMlirLoweringContext* loctx) const override;
};
} // namespace lazy
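Note (illustrative, not part of this commit): a hypothetical backend node following the same pattern as TensorList above; it uses the (op, operands, num_outputs) constructor and lowers through LowerTorchMlirBuiltin. The op choice and the shape handling are assumptions made for the sketch only.

    struct MyUnaryOp : public torch::lazy::TorchMlirNode {
      explicit MyUnaryOp(const torch::lazy::Value& input)
          : TorchMlirNode(torch::lazy::OpKind(at::aten::relu), {input},
                          /*num_outputs=*/1) {}

      torch::lazy::TorchMlirOpVector Lower(
          torch::lazy::TorchMlirFunction function,
          torch::lazy::TorchMlirLoweringContext* loctx) const override {
        std::vector<torch::jit::NamedValue> arguments;
        std::vector<torch::jit::NamedValue> kwarguments;
        // Single positional tensor operand.
        arguments.emplace_back(loctx->GetOutputOp(operand(0)));
        return torch::lazy::LowerTorchMlirBuiltin(
            function, op().op, shapes(), arguments, kwarguments);
      }
    };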


@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir_node_lowering.h"
#include "generated/LazyNonNativeIr.h"
#include "mlir_lowering_context.h"
#include "mlir_node.h"
@@ -21,26 +22,11 @@
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/permutation_util.h>
#include <torch/csrc/lazy/core/internal_ops/cast.h>
#include <torch/csrc/lazy/core/internal_ops/device_data.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ops/batch_norm_ops.h>
#include <torch/csrc/lazy/core/ops/expand.h>
#include <torch/csrc/lazy/core/ops/scalar.h>
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
#include <torch/csrc/lazy/core/view_ops/as_strided_view_update.h>
#include <torch/csrc/lazy/core/view_ops/narrow.h>
#include <torch/csrc/lazy/core/view_ops/narrow_view_update.h>
#include <torch/csrc/lazy/core/view_ops/permute.h>
#include <torch/csrc/lazy/core/view_ops/select.h>
#include <torch/csrc/lazy/core/view_ops/select_view_update.h>
#include <torch/csrc/lazy/core/view_ops/squeeze.h>
#include <torch/csrc/lazy/core/view_ops/unsqueeze.h>
#include <torch/csrc/lazy/core/view_ops/view.h>
namespace torch {
namespace lazy {
@@ -195,16 +181,6 @@ public:
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
return LowerBuiltin(node, arguments);
}
if (node->op().op == at::aten::native_batch_norm) {
return LowerBatchNorm(
torch::lazy::NodeCast<torch::lazy::NativeBatchNormForward>(
node, torch::lazy::OpKind(at::aten::native_batch_norm)));
}
if (node->op().op == at::aten::native_batch_norm_backward) {
return LowerBatchNormBackward(
torch::lazy::NodeCast<torch::lazy::NativeBatchNormBackward>(
node, torch::lazy::OpKind(at::aten::native_batch_norm_backward)));
}
if (node->op().op == at::aten::expand) {
return LowerExpand(torch::lazy::NodeCast<torch::lazy::Expand>(
node, torch::lazy::OpKind(at::aten::expand)));
@@ -237,14 +213,14 @@ public:
const torch::lazy::DeviceData* device_data_node =
torch::lazy::NodeCast<torch::lazy::DeviceData>(
node, *torch::lazy::ltc_device_data);
auto infoptr = device_data_node->data()->info();
auto infoptr = device_data_node->data->info();
auto deviceDataInfoPtr =
(torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
if (GRAPH_DUMP_ENABLED) {
LOG(ERROR) << "Lowering device data node, tensor id "
<< deviceDataInfoPtr->tensor_id << std::endl;
}
return {loctx()->GetParameter(device_data_node->data())};
return {loctx()->GetParameter(device_data_node->data)};
}
std::vector<torch::jit::NamedValue> arguments;
@@ -278,9 +254,9 @@ public:
TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->size());
arguments.emplace_back(node->stride());
arguments.emplace_back(node->storage_offset());
arguments.emplace_back(node->size);
arguments.emplace_back(node->stride);
arguments.emplace_back(node->storage_offset);
TorchMlirOpVector as_strided_out = LowerBuiltin(node, arguments);
CHECK_EQ(as_strided_out.size(), 1);
return {GenerateClone(as_strided_out.front())};
@@ -297,8 +273,8 @@ public:
dest_arguments.emplace_back(destination);
dest_arguments.emplace_back(
std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
dest_arguments.emplace_back(node->stride());
dest_arguments.emplace_back(node->storage_offset());
dest_arguments.emplace_back(node->stride);
dest_arguments.emplace_back(node->storage_offset);
TorchMlirOpVector as_strided_out =
LowerBuiltin(at::aten::as_strided, node->shapes(), dest_arguments);
CHECK_EQ(as_strided_out.size(), 1);
@@ -307,52 +283,19 @@ public:
return {destination};
}
TorchMlirOpVector
LowerBatchNorm(const torch::lazy::NativeBatchNormForward* node) {
std::vector<torch::jit::NamedValue> arguments;
for (size_t i = 0; i < 5; ++i) {
arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
}
arguments.emplace_back(node->training());
arguments.emplace_back(node->momentum());
arguments.emplace_back(node->eps());
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector
LowerBatchNormBackward(const torch::lazy::NativeBatchNormBackward* node) {
std::vector<torch::jit::NamedValue> arguments;
for (size_t i = 0; i < 3; ++i) {
arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
}
const auto& operands = node->operands();
c10::optional<at::Tensor> null_arg;
if (operands.size() == 5) {
arguments.emplace_back(null_arg);
arguments.emplace_back(null_arg);
}
for (size_t i = 3; i < operands.size(); ++i) {
arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
}
arguments.emplace_back(node->training());
arguments.emplace_back(node->eps());
arguments.emplace_back(node->output_mask());
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector LowerCast(const torch::lazy::Cast* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->dtype());
arguments.emplace_back(node->dtype);
return LowerBuiltin(at::aten::to, node->shapes(), arguments);
}
TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->size());
arguments.emplace_back(node->size);
auto expand_out = LowerBuiltin(node, arguments);
if (node->is_scalar_expand()) {
if (node->is_scalar_expand) {
// The aten::expand operations sets all strides to 0 when the original
// of rank 0. This leads to false positives when checking for internal
// memory overlap, because at::has_internal_overlap returns
@@ -366,8 +309,8 @@ public:
TorchMlirOpVector LowerNarrow(const torch::lazy::Narrow* node) {
const torch::lazy::Output& input = node->operand(0);
torch::jit::Value* base = loctx()->GetOutputOp(input);
const auto& base_indices = node->base_indices();
const auto& sizes = node->sizes();
const auto& base_indices = node->base_indices;
const auto& sizes = node->sizes;
const torch::lazy::Shape& input_shape = input.shape();
CHECK_EQ(sizes.size(), base_indices.size());
CHECK_EQ(input_shape.dim(), base_indices.size());
@@ -383,12 +326,12 @@ public:
TorchMlirOpVector LowerPermute(const torch::lazy::Permute* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.push_back(node->dims());
arguments.push_back(node->dims);
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector LowerScalar(const torch::lazy::Scalar* node) {
const at::Scalar& value = node->value();
const at::Scalar& value = node->value;
const torch::lazy::Shape& shape = node->shape();
auto options =
at::TensorOptions()
@@ -399,20 +342,19 @@ public:
}
TorchMlirOpVector LowerSelect(const torch::lazy::Select* node) {
int64_t step = torch::lazy::Select::GetStride(
node->start(), node->end(), node->stride());
int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride);
torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0));
return {GenerateSlice(
/*base=*/base, /*dim=*/node->dim(),
/*start=*/node->start(), /*end=*/node->end(),
/*base=*/base, /*dim=*/node->dim,
/*start=*/node->start, /*end=*/node->end,
/*step=*/step)};
}
TorchMlirOpVector LowerSqueeze(const torch::lazy::Squeeze* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
if (node->dim() != -1) {
arguments.push_back(node->dim());
if (node->dim != -1) {
arguments.push_back(node->dim);
}
return LowerBuiltin(node, arguments);
}
@@ -421,11 +363,10 @@ public:
LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) {
torch::jit::Value* dest =
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
int64_t step = torch::lazy::Select::GetStride(
node->start(), node->end(), node->stride());
int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride);
torch::jit::Value* selected = GenerateSlice(
/*base=*/dest, /*dim=*/node->dim(), /*start=*/node->start(),
/*end=*/node->end(), /*step=*/step);
/*base=*/dest, /*dim=*/node->dim, /*start=*/node->start,
/*end=*/node->end, /*step=*/step);
GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1)));
return {dest};
}
@@ -434,7 +375,7 @@ public:
LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) {
torch::jit::Value* dest =
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
const auto& base_indices = node->base_indices();
const auto& base_indices = node->base_indices;
const torch::lazy::Output& source_argument = node->operand(1);
const torch::lazy::Shape& source_shape = source_argument.shape();
CHECK_EQ(source_shape.dim(), base_indices.size());
@@ -453,14 +394,14 @@ public:
TorchMlirOpVector LowerUnsqueeze(const torch::lazy::Unsqueeze* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.push_back(node->dim());
arguments.push_back(node->dim);
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector LowerView(const torch::lazy::View* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.push_back(node->output_size());
arguments.push_back(node->output_size);
return LowerBuiltin(at::aten::reshape, node->shapes(), arguments);
}


@@ -0,0 +1,28 @@
//===- generic.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/generic.cpp
//===----------------------------------------------------------------------===//
#include "generic.h"
namespace torch {
namespace lazy {
Generic::Generic(
OpKind op,
OpList operands,
Shape shape,
size_t num_outputs,
hash_t hash_seed)
: TorchMlirNode(op, operands, {std::move(shape)}, num_outputs, hash_seed),
hash_seed_(hash_seed) {}
} // namespace lazy
} // namespace torch


@@ -0,0 +1,39 @@
//===- generic.h ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/generic.h
//===----------------------------------------------------------------------===//
#pragma once
#include "../mlir_node.h"
namespace torch {
namespace lazy {
// Generic IR Node implementation for nodes which can simply be described by a
// specific OpKind and a lowering function. IR nodes carrying
metadata should not be using this class (and have the metadata
// captured by the LowerFn), but they should instead create a dedicated IR node.
// Doing the former would limit IR introspection.
class TORCH_API Generic : public TorchMlirNode {
public:
Generic(
OpKind op,
OpList operands,
Shape shape,
size_t num_outputs = 1,
hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9));
private:
hash_t hash_seed_;
};
} // namespace lazy
} // namespace torch
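Note (illustrative, not part of this commit): wrapping an otherwise unmodelled op in a Generic node, mirroring what TorchMlirIrBuilder::MakeGeneric does; the op and the caller-supplied shape here are placeholders for the sketch.

    // Sketch: building a Generic node for a hypothetical aten::tanh use.
    torch::lazy::NodePtr MakeTanhGeneric(const torch::lazy::Value& input,
                                         const torch::lazy::Shape& shape) {
      return torch::lazy::MakeNode<torch::lazy::Generic>(
          torch::lazy::OpKind(at::aten::tanh), torch::lazy::OpList{input}, shape);
    }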