[LTC] Support torch.ones/zeros/arange ops (#2440)

pull/2479/head snapshot-20230922.969
Gleb Kazantaev 2023-09-21 13:25:14 -04:00 committed by GitHub
parent b9847b1904
commit 059041e0fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 262 additions and 113 deletions

View File

@ -467,7 +467,8 @@ class GenTorchMlirLTC:
node_base="torch::lazy::TorchMlirNode",
node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
tensor_class=self.tensor_class,
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h",
create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor",
shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
lazy_ir_generator=GenMlirLazyIr,
)

View File

@ -3,12 +3,6 @@ blacklist:
# It also doesn't have confusing `unsafe` argument.
- _index_put_impl
# Ops with list of tensors output
- split.Tensor
- split_with_sizes
- unbind.int
- chunk
# Additional ops which autogen is supported for but don't compile yet
- _convolution
- detach
@ -18,42 +12,28 @@ blacklist:
# Disabled for consistency with TS backend
- lift_fresh_copy
- new_empty
- rsub
- slice.Tensor # Disabled in favour of slice_copy.Tensor
- zeros
- ones
- arange
- arange.start
- arange.start_step
- fill.Scalar
- scalar_tensor
# Disabled in favour of functionalized alternatives
- _reshape_alias
- expand
- permute
- select.int
- squeeze
- squeeze.dim
- t
- transpose.int
- expand
- squeeze
- unsqueeze
- view
- slice.Tensor
- split.Tensor
- split_with_sizes
- unbind.int
whitelist:
# Enabled for consistency with TS backend
- arange.start_out
# List of supported ops that we don't want to do the full codegen for
supported:
# - bernoulli
# - bernoulli_
- _to_copy
- clone
- empty.memory_format
- empty_strided
- fill_.Scalar
- _unsafe_view
- unbind_copy.int
- split_copy.Tensor
@ -80,10 +60,10 @@ supported:
- _trilinear
- linalg_pinv.atol_rtol_tensor
- logsumexp.out
- t
# List of ops that will take in symints for the size instead of ints
symint:
- empty.memory_format
- new_empty_strided
- expand_copy
- narrow_copy
@ -91,7 +71,6 @@ symint:
- slice_copy.Tensor
- split_copy.Tensor
- slice_scatter
- view
- view_copy
- as_strided_copy
- as_strided_scatter

View File

@ -1384,7 +1384,6 @@ LTC_XFAIL_SET = {
"ConvolutionBackwardModule2DPadded_basic",
"VarMeanCorrectionModule_basic",
"VarMeanCorrectionNoneModule_basic",
"PrimsConvertElementTypeModule_basic",
"ElementwisePreluModule_basic",
"VarMeanBiasedModule_basic",
"VarMeanUnbiasedModule_basic",

View File

@ -4490,6 +4490,56 @@ def Torch_AtenRandnLikeOp : Torch_Op<"aten.randn_like", [
}];
}
def Torch_AtenRandomOp : Torch_Op<"aten.random", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::random : (Tensor, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRandomOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenRandomOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenRandomFromOp : Torch_Op<"aten.random.from", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$from,
AnyTorchOptionalIntType:$to,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRandomFromOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenRandomFromOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
AllowsTypeRefinement,
HasValueSemantics,
@ -8934,6 +8984,31 @@ def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [
}];
}
def Torch_AtenResizeOp : Torch_Op<"aten.resize", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::resize : (Tensor, int[], int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$size,
AnyTorchOptionalIntType:$memory_format
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenResizeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenResizeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [
AllowsTypeRefinement
]> {

View File

@ -69,6 +69,7 @@ add_library(torch_mlir_ltc_backend SHARED
backend_impl.cpp
dynamic_ir.cpp
mlir_node.cpp
tensor.cpp
ops/device_data.cpp
ops/generic.cpp
ops/index.cpp

View File

@ -30,6 +30,7 @@
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/library.h>
#include "generated/LazyIr.h"
#include "generated/LazyNativeFunctions.h"
#include "generated/shape_inference.h"
#include "ops/to_copy.h"
@ -143,32 +144,6 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
} // namespace
// at::Tensor LazyNativeFunctions::bernoulli(
// const at::Tensor& self, c10::optional<at::Generator> generator) {
// TORCH_LAZY_FN_COUNTER("lazy::");
// if (generator.has_value() && generator->defined()) {
// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli has generator value");
// }
// auto self_tensor = torch::lazy::TryGetLtcTensor(self);
// UNIMPLEMENTED_FUNCTION_ERROR();
// // return torch::lazy::CreateAtenFromLtcTensor(
// // torch::lazy::bernoulli(self_tensor));
// }
// at::Tensor& LazyNativeFunctions::bernoulli_(
// at::Tensor& self, double p, c10::optional<at::Generator> generator) {
// TORCH_LAZY_FN_COUNTER("lazy::");
// if (generator.has_value() && generator->defined()) {
// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli_ has generator value");
// }
// auto self_tensor = torch::lazy::TryGetLtcTensor(self);
// UNIMPLEMENTED_FUNCTION_ERROR();
// // torch::lazy::bernoulli_(self_tensor, p);
// // return self;
// }
// clone is special in LT because we make it a no-op.
// This should be safe to do, because every operator in the LT is functional.
at::Tensor LazyNativeFunctions::clone(
@ -352,64 +327,17 @@ at::Tensor LazyNativeFunctions::_to_copy(
}
};
at::Tensor LazyNativeFunctions::empty_symint(
at::SymIntArrayRef sym_size,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory,
c10::optional<at::MemoryFormat> memory_format) {
// TODO: support this directly
auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType();
at::TensorOptions options = at::TensorOptions()
.device(c10::Device(device_type))
.layout(layout)
.pinned_memory(pin_memory)
.dtype(dtype);
auto x_result = at::empty(size, options, memory_format);
auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device));
// See Note [Lazy Tensor Functionalization]
if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
c10::DispatchKey::Functionalize)) {
// Invariant: if the functionalization key is in the exclude set, then we're expected
// to return an ordinary tensor, which will be "lifted" into a functional wrapper later.
return tensor;
} else {
auto wrapped = at::functionalization::impl::to_functional_tensor(tensor);
return wrapped;
}
}
at::Tensor LazyNativeFunctions::empty_strided(
at::IntArrayRef size, at::IntArrayRef stride,
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER("lazy::");
at::Tensor t = empty_symint(
c10::fromIntArrayRefSlow(size),
dtype, layout, device, pin_memory, c10::nullopt);
return t.as_strided(size, stride, /*storage_offset=*/0);
}
at::Tensor&
LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
torch::lazy::Value constant =
torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(
value, self_tensor->shape(), self_tensor->GetDevice());
self_tensor->SetInPlaceIrValue(std::move(constant));
return self;
}
at::Tensor LazyNativeFunctions::_unsafe_view(
const at::Tensor& self, at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::");
return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size));
}
at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("lazy::");
return at::functionalization::functionalize_aten_op<ATEN_OP(t)>::call(self);
}
std::vector<at::Tensor> LazyNativeFunctions::unbind_copy(const at::Tensor & self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto common_device = torch::lazy::GetBackendDevice(self);
@ -643,9 +571,18 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint(
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return at::functionalization::
functionalize_aten_op_symint<ATEN_OP(new_empty_strided)>::call(
self, size, stride, dtype, layout, device, pin_memory);
if (!device || device->type() == c10::DeviceType::Lazy) {
return at::functionalization::functionalize_aten_op_symint<
ATEN_OP(new_empty_strided)>::call(self, size, stride, dtype, layout,
device, pin_memory);
}
// For cases when device != lazy, for example: lazy_tensor.new_empty_strided(..., "cpu")
// we need to avoid explicit functionalization. To do that we create regular cpu tensors.
at::Tensor t = at::empty_symint(
size, (dtype ? dtype : c10::optional<at::ScalarType>(self.scalar_type())),
(layout ? layout : c10::optional<at::Layout>(self.layout())), device,
pin_memory, c10::nullopt);
return t.as_strided_symint(size, stride, /*storage_offset=*/0);
}
at::Tensor LazyNativeFunctions::narrow_copy_symint(
@ -729,4 +666,4 @@ at::Tensor& LazyNativeFunctions::logsumexp_out(
void InitializeAtenBindings() {}
} // namespace lazy
} // namespace torch
} // namespace torch

View File

@ -265,6 +265,33 @@ std::vector<torch::lazy::Shape> compute_shape_eye(
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_arange(
const at::Scalar& end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
auto out_meta =
at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_arange(
const at::Scalar& start, const at::Scalar& end,
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
auto out_meta = at::arange(start, end, dtype, layout, c10::Device(c10::kMeta),
pin_memory);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_arange(
const at::Scalar& start, const at::Scalar& end, const at::Scalar& step,
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
auto out_meta = at::arange(start, end, step, dtype, layout,
c10::Device(c10::kMeta), pin_memory);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_full(
at::IntArrayRef size, const at::Scalar& fill_value,
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
@ -273,6 +300,44 @@ std::vector<torch::lazy::Shape> compute_shape_full(
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_ones(
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return {
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_zeros(
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return {
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_empty(
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory,
c10::optional<at::MemoryFormat> memory_format) {
return {
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_empty_strided(
at::IntArrayRef size, at::IntArrayRef stride,
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
return {
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor& self,
const at::Scalar& value) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor& self,
const at::Tensor& value) {
return {Shape(self.scalar_type(), self.sizes().vec())};
@ -302,11 +367,24 @@ std::vector<torch::lazy::Shape> compute_shape_randint(
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_resize(
const at::Tensor & self, at::IntArrayRef size,
c10::optional<at::MemoryFormat> memory_format) {
return {Shape(self.scalar_type(), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_bernoulli(
const at::Tensor& self, const at::Tensor &p,
c10::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_scalar_tensor(
const at::Scalar & s, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
}
} // namespace lazy
} // namespace torch

View File

@ -0,0 +1,29 @@
//===- tensor.cpp ---------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include <ATen/FunctionalTensorWrapper.h>
#include "tensor.h"
namespace torch {
namespace lazy {
at::Tensor CreateFunctionalizedAtenFromLtcTensor(
const LazyTensorPtr& ltc_tensor) {
at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor);
if (!c10::impl::tls_is_dispatch_key_excluded(
c10::DispatchKey::Functionalize) &&
!at::functionalization::impl::isFunctionalTensor(tensor)) {
return at::functionalization::impl::to_functional_tensor(tensor);
}
return tensor;
}
} // namespace lazy
} // namespace torch

View File

@ -0,0 +1,24 @@
//===- tensor.h -----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <torch/csrc/lazy/core/tensor.h>
namespace torch {
namespace lazy {
// Ops like torch.ones/zeros etc. which produce new tensor as an output
// should have explicit tensor functinoalization. Otherwise we can get
// unfanctionalized primitives or in the worst case if we apply inplace
// operations to unfunctionalized tensor it won't be captured in LTC graph.
TORCH_API at::Tensor CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor);
} // namespace lazy
} // namespace torch

View File

@ -28,6 +28,11 @@ using namespace torch::lazy;
namespace torch {
namespace lazy {
/// Returns true if a string begins with another.
inline bool beginswith(const std::string& s, const std::string& t) {
return s.size() >= t.size() && s.compare(0, t.size(), t) == 0;
}
struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
ReferenceLazyBackendDeviceType(c10::DeviceType device_type)
: device_type_(device_type) {}
@ -104,7 +109,25 @@ public:
//
// JIT Execution adopted from:
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
torch::jit::GraphExecutor graph_executor(mlir_computation->graph(), "");
std::shared_ptr<torch::jit::Graph> graph = mlir_computation->graph();
for (auto* node : graph->nodes()) {
// Convert any lazy devices to cpu devices to ensure
// that the values are actually computed
if (node->outputs().size() == 1 &&
node->output()->type()->kind() ==
c10::TypeKind::DeviceObjType) {
auto value_sym = torch::jit::Symbol::attr("value");
TORCH_CHECK(node->hasAttribute(value_sym),
"Expected node to have 'value' attribute.");
TORCH_CHECK(node->kindOf(value_sym) == torch::jit::AttributeKind::s,
"Expected 'value' attribute to be a string.");
if (beginswith(node->s(value_sym), "lazy")) {
node->s_(value_sym, "cpu");
}
}
}
torch::jit::GraphExecutor graph_executor(graph, "");
std::vector<torch::jit::IValue> stack;
for (const auto& argument : arguments) {
const auto mlir_data =

View File

@ -359,6 +359,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::random : (Tensor, Generator?) -> (Tensor)")
emit("aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)")
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
@ -571,6 +573,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
emit("aten::resize : (Tensor, int[], int?) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)