mirror of https://github.com/llvm/torch-mlir
parent
b9847b1904
commit
059041e0fe
|
@ -467,7 +467,8 @@ class GenTorchMlirLTC:
|
|||
node_base="torch::lazy::TorchMlirNode",
|
||||
node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
|
||||
tensor_class=self.tensor_class,
|
||||
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
|
||||
tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h",
|
||||
create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor",
|
||||
shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
|
||||
lazy_ir_generator=GenMlirLazyIr,
|
||||
)
|
||||
|
|
|
@ -3,12 +3,6 @@ blacklist:
|
|||
# It also doesn't have confusing `unsafe` argument.
|
||||
- _index_put_impl
|
||||
|
||||
# Ops with list of tensors output
|
||||
- split.Tensor
|
||||
- split_with_sizes
|
||||
- unbind.int
|
||||
- chunk
|
||||
|
||||
# Additional ops which autogen is supported for but don't compile yet
|
||||
- _convolution
|
||||
- detach
|
||||
|
@ -18,42 +12,28 @@ blacklist:
|
|||
|
||||
# Disabled for consistency with TS backend
|
||||
- lift_fresh_copy
|
||||
- new_empty
|
||||
- rsub
|
||||
- slice.Tensor # Disabled in favour of slice_copy.Tensor
|
||||
- zeros
|
||||
- ones
|
||||
- arange
|
||||
- arange.start
|
||||
- arange.start_step
|
||||
- fill.Scalar
|
||||
- scalar_tensor
|
||||
|
||||
# Disabled in favour of functionalized alternatives
|
||||
- _reshape_alias
|
||||
- expand
|
||||
- permute
|
||||
- select.int
|
||||
- squeeze
|
||||
- squeeze.dim
|
||||
- t
|
||||
- transpose.int
|
||||
- expand
|
||||
- squeeze
|
||||
- unsqueeze
|
||||
- view
|
||||
- slice.Tensor
|
||||
- split.Tensor
|
||||
- split_with_sizes
|
||||
- unbind.int
|
||||
|
||||
whitelist:
|
||||
# Enabled for consistency with TS backend
|
||||
- arange.start_out
|
||||
|
||||
# List of supported ops that we don't want to do the full codegen for
|
||||
supported:
|
||||
# - bernoulli
|
||||
# - bernoulli_
|
||||
- _to_copy
|
||||
- clone
|
||||
- empty.memory_format
|
||||
- empty_strided
|
||||
- fill_.Scalar
|
||||
- _unsafe_view
|
||||
- unbind_copy.int
|
||||
- split_copy.Tensor
|
||||
|
@ -80,10 +60,10 @@ supported:
|
|||
- _trilinear
|
||||
- linalg_pinv.atol_rtol_tensor
|
||||
- logsumexp.out
|
||||
- t
|
||||
|
||||
# List of ops that will take in symints for the size instead of ints
|
||||
symint:
|
||||
- empty.memory_format
|
||||
- new_empty_strided
|
||||
- expand_copy
|
||||
- narrow_copy
|
||||
|
@ -91,7 +71,6 @@ symint:
|
|||
- slice_copy.Tensor
|
||||
- split_copy.Tensor
|
||||
- slice_scatter
|
||||
- view
|
||||
- view_copy
|
||||
- as_strided_copy
|
||||
- as_strided_scatter
|
||||
|
|
|
@ -1384,7 +1384,6 @@ LTC_XFAIL_SET = {
|
|||
"ConvolutionBackwardModule2DPadded_basic",
|
||||
"VarMeanCorrectionModule_basic",
|
||||
"VarMeanCorrectionNoneModule_basic",
|
||||
"PrimsConvertElementTypeModule_basic",
|
||||
"ElementwisePreluModule_basic",
|
||||
"VarMeanBiasedModule_basic",
|
||||
"VarMeanUnbiasedModule_basic",
|
||||
|
|
|
@ -4490,6 +4490,56 @@ def Torch_AtenRandnLikeOp : Torch_Op<"aten.randn_like", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRandomOp : Torch_Op<"aten.random", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::random : (Tensor, Generator?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalGeneratorType:$generator
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRandomOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenRandomOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRandomFromOp : Torch_Op<"aten.random.from", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$from,
|
||||
AnyTorchOptionalIntType:$to,
|
||||
AnyTorchOptionalGeneratorType:$generator
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRandomFromOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenRandomFromOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -8934,6 +8984,31 @@ def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenResizeOp : Torch_Op<"aten.resize", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::resize : (Tensor, int[], int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$size,
|
||||
AnyTorchOptionalIntType:$memory_format
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenResizeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenResizeOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
|
|
|
@ -69,6 +69,7 @@ add_library(torch_mlir_ltc_backend SHARED
|
|||
backend_impl.cpp
|
||||
dynamic_ir.cpp
|
||||
mlir_node.cpp
|
||||
tensor.cpp
|
||||
ops/device_data.cpp
|
||||
ops/generic.cpp
|
||||
ops/index.cpp
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include <torch/csrc/lazy/core/tensor_util.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "generated/LazyIr.h"
|
||||
#include "generated/LazyNativeFunctions.h"
|
||||
#include "generated/shape_inference.h"
|
||||
#include "ops/to_copy.h"
|
||||
|
@ -143,32 +144,6 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
|
|||
|
||||
} // namespace
|
||||
|
||||
// at::Tensor LazyNativeFunctions::bernoulli(
|
||||
// const at::Tensor& self, c10::optional<at::Generator> generator) {
|
||||
// TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
// if (generator.has_value() && generator->defined()) {
|
||||
// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli has generator value");
|
||||
// }
|
||||
// auto self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
|
||||
// UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
// // return torch::lazy::CreateAtenFromLtcTensor(
|
||||
// // torch::lazy::bernoulli(self_tensor));
|
||||
// }
|
||||
|
||||
// at::Tensor& LazyNativeFunctions::bernoulli_(
|
||||
// at::Tensor& self, double p, c10::optional<at::Generator> generator) {
|
||||
// TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
// if (generator.has_value() && generator->defined()) {
|
||||
// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli_ has generator value");
|
||||
// }
|
||||
// auto self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
|
||||
// UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
// // torch::lazy::bernoulli_(self_tensor, p);
|
||||
// // return self;
|
||||
// }
|
||||
|
||||
// clone is special in LT because we make it a no-op.
|
||||
// This should be safe to do, because every operator in the LT is functional.
|
||||
at::Tensor LazyNativeFunctions::clone(
|
||||
|
@ -352,64 +327,17 @@ at::Tensor LazyNativeFunctions::_to_copy(
|
|||
}
|
||||
};
|
||||
|
||||
at::Tensor LazyNativeFunctions::empty_symint(
|
||||
at::SymIntArrayRef sym_size,
|
||||
c10::optional<at::ScalarType> dtype,
|
||||
c10::optional<at::Layout> layout,
|
||||
c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<at::MemoryFormat> memory_format) {
|
||||
// TODO: support this directly
|
||||
auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
|
||||
const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType();
|
||||
at::TensorOptions options = at::TensorOptions()
|
||||
.device(c10::Device(device_type))
|
||||
.layout(layout)
|
||||
.pinned_memory(pin_memory)
|
||||
.dtype(dtype);
|
||||
auto x_result = at::empty(size, options, memory_format);
|
||||
auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device));
|
||||
// See Note [Lazy Tensor Functionalization]
|
||||
if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
|
||||
c10::DispatchKey::Functionalize)) {
|
||||
// Invariant: if the functionalization key is in the exclude set, then we're expected
|
||||
// to return an ordinary tensor, which will be "lifted" into a functional wrapper later.
|
||||
return tensor;
|
||||
} else {
|
||||
auto wrapped = at::functionalization::impl::to_functional_tensor(tensor);
|
||||
return wrapped;
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::empty_strided(
|
||||
at::IntArrayRef size, at::IntArrayRef stride,
|
||||
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
||||
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
at::Tensor t = empty_symint(
|
||||
c10::fromIntArrayRefSlow(size),
|
||||
dtype, layout, device, pin_memory, c10::nullopt);
|
||||
return t.as_strided(size, stride, /*storage_offset=*/0);
|
||||
}
|
||||
|
||||
at::Tensor&
|
||||
LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
|
||||
|
||||
torch::lazy::Value constant =
|
||||
torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(
|
||||
value, self_tensor->shape(), self_tensor->GetDevice());
|
||||
self_tensor->SetInPlaceIrValue(std::move(constant));
|
||||
return self;
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::_unsafe_view(
|
||||
const at::Tensor& self, at::IntArrayRef size) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size));
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
return at::functionalization::functionalize_aten_op<ATEN_OP(t)>::call(self);
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> LazyNativeFunctions::unbind_copy(const at::Tensor & self, int64_t dim) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
auto common_device = torch::lazy::GetBackendDevice(self);
|
||||
|
@ -643,9 +571,18 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint(
|
|||
c10::optional<at::Layout> layout,
|
||||
c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
return at::functionalization::
|
||||
functionalize_aten_op_symint<ATEN_OP(new_empty_strided)>::call(
|
||||
self, size, stride, dtype, layout, device, pin_memory);
|
||||
if (!device || device->type() == c10::DeviceType::Lazy) {
|
||||
return at::functionalization::functionalize_aten_op_symint<
|
||||
ATEN_OP(new_empty_strided)>::call(self, size, stride, dtype, layout,
|
||||
device, pin_memory);
|
||||
}
|
||||
// For cases when device != lazy, for example: lazy_tensor.new_empty_strided(..., "cpu")
|
||||
// we need to avoid explicit functionalization. To do that we create regular cpu tensors.
|
||||
at::Tensor t = at::empty_symint(
|
||||
size, (dtype ? dtype : c10::optional<at::ScalarType>(self.scalar_type())),
|
||||
(layout ? layout : c10::optional<at::Layout>(self.layout())), device,
|
||||
pin_memory, c10::nullopt);
|
||||
return t.as_strided_symint(size, stride, /*storage_offset=*/0);
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::narrow_copy_symint(
|
||||
|
@ -729,4 +666,4 @@ at::Tensor& LazyNativeFunctions::logsumexp_out(
|
|||
void InitializeAtenBindings() {}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
} // namespace torch
|
|
@ -265,6 +265,33 @@ std::vector<torch::lazy::Shape> compute_shape_eye(
|
|||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_arange(
|
||||
const at::Scalar& end, c10::optional<at::ScalarType> dtype,
|
||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
auto out_meta =
|
||||
at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_arange(
|
||||
const at::Scalar& start, const at::Scalar& end,
|
||||
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
||||
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||
auto out_meta = at::arange(start, end, dtype, layout, c10::Device(c10::kMeta),
|
||||
pin_memory);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_arange(
|
||||
const at::Scalar& start, const at::Scalar& end, const at::Scalar& step,
|
||||
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
||||
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||
auto out_meta = at::arange(start, end, step, dtype, layout,
|
||||
c10::Device(c10::kMeta), pin_memory);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_full(
|
||||
at::IntArrayRef size, const at::Scalar& fill_value,
|
||||
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
||||
|
@ -273,6 +300,44 @@ std::vector<torch::lazy::Shape> compute_shape_full(
|
|||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_ones(
|
||||
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
|
||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
return {
|
||||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_zeros(
|
||||
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
|
||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
return {
|
||||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_empty(
|
||||
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
|
||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<at::MemoryFormat> memory_format) {
|
||||
return {
|
||||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_empty_strided(
|
||||
at::IntArrayRef size, at::IntArrayRef stride,
|
||||
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
||||
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||
return {
|
||||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor& self,
|
||||
const at::Scalar& value) {
|
||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor& self,
|
||||
const at::Tensor& value) {
|
||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
|
@ -302,11 +367,24 @@ std::vector<torch::lazy::Shape> compute_shape_randint(
|
|||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_resize(
|
||||
const at::Tensor & self, at::IntArrayRef size,
|
||||
c10::optional<at::MemoryFormat> memory_format) {
|
||||
return {Shape(self.scalar_type(), size.vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_bernoulli(
|
||||
const at::Tensor& self, const at::Tensor &p,
|
||||
c10::optional<at::Generator> generator) {
|
||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_scalar_tensor(
|
||||
const at::Scalar & s, c10::optional<at::ScalarType> dtype,
|
||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,29 @@
|
|||
//===- tensor.cpp ---------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
|
||||
#include "tensor.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
at::Tensor CreateFunctionalizedAtenFromLtcTensor(
|
||||
const LazyTensorPtr& ltc_tensor) {
|
||||
at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor);
|
||||
if (!c10::impl::tls_is_dispatch_key_excluded(
|
||||
c10::DispatchKey::Functionalize) &&
|
||||
!at::functionalization::impl::isFunctionalTensor(tensor)) {
|
||||
return at::functionalization::impl::to_functional_tensor(tensor);
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,24 @@
|
|||
//===- tensor.h -----------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/lazy/core/tensor.h>
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
// Ops like torch.ones/zeros etc. which produce new tensor as an output
|
||||
// should have explicit tensor functinoalization. Otherwise we can get
|
||||
// unfanctionalized primitives or in the worst case if we apply inplace
|
||||
// operations to unfunctionalized tensor it won't be captured in LTC graph.
|
||||
TORCH_API at::Tensor CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor);
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -28,6 +28,11 @@ using namespace torch::lazy;
|
|||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
/// Returns true if a string begins with another.
|
||||
inline bool beginswith(const std::string& s, const std::string& t) {
|
||||
return s.size() >= t.size() && s.compare(0, t.size(), t) == 0;
|
||||
}
|
||||
|
||||
struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
|
||||
ReferenceLazyBackendDeviceType(c10::DeviceType device_type)
|
||||
: device_type_(device_type) {}
|
||||
|
@ -104,7 +109,25 @@ public:
|
|||
//
|
||||
// JIT Execution adopted from:
|
||||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
|
||||
torch::jit::GraphExecutor graph_executor(mlir_computation->graph(), "");
|
||||
std::shared_ptr<torch::jit::Graph> graph = mlir_computation->graph();
|
||||
for (auto* node : graph->nodes()) {
|
||||
// Convert any lazy devices to cpu devices to ensure
|
||||
// that the values are actually computed
|
||||
if (node->outputs().size() == 1 &&
|
||||
node->output()->type()->kind() ==
|
||||
c10::TypeKind::DeviceObjType) {
|
||||
auto value_sym = torch::jit::Symbol::attr("value");
|
||||
TORCH_CHECK(node->hasAttribute(value_sym),
|
||||
"Expected node to have 'value' attribute.");
|
||||
TORCH_CHECK(node->kindOf(value_sym) == torch::jit::AttributeKind::s,
|
||||
"Expected 'value' attribute to be a string.");
|
||||
if (beginswith(node->s(value_sym), "lazy")) {
|
||||
node->s_(value_sym, "cpu");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
torch::jit::GraphExecutor graph_executor(graph, "");
|
||||
std::vector<torch::jit::IValue> stack;
|
||||
for (const auto& argument : arguments) {
|
||||
const auto mlir_data =
|
||||
|
|
|
@ -359,6 +359,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::random : (Tensor, Generator?) -> (Tensor)")
|
||||
emit("aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)")
|
||||
|
||||
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
|
||||
|
@ -571,6 +573,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
|
||||
emit("aten::resize : (Tensor, int[], int?) -> (Tensor)")
|
||||
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
|
||||
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
|
||||
|
|
Loading…
Reference in New Issue