E2E HuggingFace Bert using LTC Backend (#912)

* Update native function definitions

* Add ops to support bert lowering

- Add empty_strided and as_strided

- Restore zeros_like to the op blacklist (without this, tensors are unintentionally created with a CPU device rather than a lazy device)

- Check for composite implicit ops and add device data IR (the op categorization is sketched after this list)

- Also fix codegen for functionalization
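
For context, the bucketing that the updated autogen script applies to each torch op can be summarized roughly as follows. This is a hedged, simplified sketch, not the exact code; the real logic lives in the build_tools/autogen_ltc_backend.py diff below and also handles in-place ops via their functional variants.

```python
# Hypothetical, simplified sketch of the op categorization done by the
# autogen script; variable names mirror the real script, details omitted.
def categorize(op, native_functions, blacklist, supported, torch_ops):
    func = native_functions[op]          # torchgen NativeFunction object
    base = func.func.name.name.base      # overload-less base name
    if base in blacklist or op in blacklist:
        return "blacklisted"             # excluded from autogen entirely
    if base in supported or op in supported:
        return "supported"               # hand-written kernel, no full codegen
    if (func.has_composite_implicit_autograd_kernel
            and f"{op}_backward" not in torch_ops):
        return "composite_implicit"      # decomposed upstream, only noted in a comment
    return "full_codegen"                # gets an LTC IR node + shape function
```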

* Add autogen to CMakeLists.txt

* Remove PyTorch submodule

* Reduced BERT model size

* Print Mark Step status in Torch MLIR LTC debug string

* Apply fixes to work with latest upstream/main

- Pass importOptions into getMlirTypeFromTorchType during NodeImporter::importNode

  Without this, the created tensor type may be mismatched, since ImportOptions may cause vtensor to be used instead of tensor

* Update shape inference functions

- Fixed compute_shape_native_batch_norm when mean and var are uninitialized

  Previously, fewer than three shapes would be returned if either mean or var didn't exist. Instead, we now initialize them with a vector matching the number of channels (see the sanity check after this list).

- Implemented compute_shape_mul

- Fixed bug in reshape shape inference error message
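
The reasoning behind the batch-norm fix can be spot-checked against eager PyTorch. This is only an illustrative snippet and assumes a reasonably recent PyTorch build; the exact saved-stat shapes may differ across versions.

```python
import torch

x = torch.randn(4, 3, 8, 8)                 # N=4 samples, C=3 channels
out, save_mean, save_invstd = torch.native_batch_norm(
    x, None, None,                           # weight, bias
    None, None,                              # running_mean, running_var absent
    True, 0.1, 1e-5)                         # training, momentum, eps
# Three tensors always come back; the saved mean/invstd are per-channel
# vectors of length C, which is what the shape function must report even
# when the running stats are uninitialized.
print(out.shape, save_mean.shape, save_invstd.shape)   # (4, 3, 8, 8), (3,), (3,)
```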

* Get MLIR backend more consistent with TS backend

- Remove LazyNativeFunctions::_unsafe_view from autogen

- Blacklist ops to make the JIT graph more like the output of the TS backend

- Print the graph when an SSA value has a mismatch between types and results

- Remove normalize_index from LazyShapeInference

- Fix seeds for LTC example models

* Update and clean up shape inference functions

- Prune shape inference functions

- Add shape inference function for GenerateSlice (the slice-size arithmetic is sketched after this list)

- Add shape inference function for GenerateCopy
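
As a reminder of the arithmetic that slice-style shape inference boils down to, here is a Python sketch for illustration only; it is not the actual GenerateSlice implementation but mirrors the old compute_shape_slice logic visible in the diff below.

```python
import math

def slice_dim_size(dim_size: int, start: int, end: int, step: int) -> int:
    # Negative indices count back from the end of the dimension.
    start = start + dim_size if start < 0 else start
    end = end + dim_size if end < 0 else end
    if start >= end or start >= dim_size or end <= 0:
        return 0                              # slice selects nothing
    start, end = max(0, start), min(dim_size, end)
    return math.ceil((end - start) / step)    # elements hit by the stride

assert slice_dim_size(10, 2, 8, 3) == 2       # indices 2 and 5
assert slice_dim_size(10, -3, 10, 1) == 3     # indices 7, 8, 9
```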

Co-authored-by: Henry Tu <henry.tu@cerebras.net>
Branch: pull/1125/head
Author: Jae Hoon (Antonio) Kim, 2022-06-07 14:38:50 -04:00, committed by Henry Tu
Parent: 0c35e607b3
Commit: d9aee0d7a7
20 changed files with 826 additions and 413 deletions

.gitignore

@@ -11,6 +11,7 @@ libtorch*
/build/
__pycache__
*.pyc
.pytype


@@ -1,9 +1,11 @@
import argparse
import hashlib
import importlib
import os
import subprocess
import sys
import warnings
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from shutil import which
@@ -11,18 +13,16 @@ from textwrap import dedent
import yaml
TORCH_MLIR_DIR = Path(__file__).parent.parent.resolve()
TORCH_DIR = TORCH_MLIR_DIR.joinpath("externals", "pytorch")
sys.path.append(str(TORCH_DIR))
# PyTorch's LTC backend autogen script
import torchgen
import torchgen.dest.lazy_ir
import torchgen.gen_lazy_tensor
from torchgen.api.lazy import LazyIrSchema
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
from torchgen.model import NativeFunctionsGroup
TORCH_DIR = Path(importlib.util.find_spec('torch').origin).resolve().parent.parent
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
def isOptionalCType(arg):
return str(type(arg)) == "<class 'torchgen.api.types.OptionalCType'>"
@@ -42,20 +42,29 @@ def generate_native_functions(
grouped_native_functions = get_grouped_native_functions(native_functions)
def get_native_function_name(f):
func = f.func if hasattr(f, "func") else f.functional.func
return str(func.name)
func = f if hasattr(f, "func") else f.functional
return str(func.func.name), func
aten_funcs = set(map(get_native_function_name, grouped_native_functions))
def get_opnames(ops):
opnames = defaultdict(set)
for op in ops:
opname = op.split(".")[0]
opnames[opname].add(op)
return opnames
native_functions = dict(map(get_native_function_name, native_functions))
grouped_native_functions = dict(map(get_native_function_name, grouped_native_functions))
aten_funcs = get_opnames(set(grouped_native_functions.keys()))
with config_path.open() as f:
config = yaml.load(f, yaml.CLoader)
# List of unsupported ops in LTC autogen because of some error
blacklist = config.get("blacklist", [])
blacklist = set(config.get("blacklist", []))
# List of supported ops that we don't want to do the full codegen for
# primarily view ops
supported = config.get("supported", [])
supported = set(config.get("supported", []))
# List of non-native ops to do IR codegen for
non_native = config.get("non_native", [])
@@ -65,49 +74,54 @@ def generate_native_functions(
else:
cmd = ["grep", "-o", r"aten::[0-9a-zA-Z_\.]\+"]
output = (
subprocess.check_output(
torch_ops = set(
op[6:]
for op in subprocess.check_output(
cmd + [str(torch_ops_file)],
encoding="utf-8",
)
.strip()
.split(os.linesep)
)
torch_opnames = get_opnames(torch_ops)
# process ops list
ops = []
supported_ops = []
skipped = []
ops = set()
composite_implicit = set()
for op in output:
op = op[6:]
opname = op.split(".")[0]
if opname in blacklist or op in blacklist:
for op in torch_ops:
if op not in native_functions:
continue
if opname in supported:
supported_ops.append(op)
func = native_functions[op]
base = func.func.name.name.base
if base in blacklist or op in blacklist:
continue
if base in supported or op in supported:
continue
if op not in aten_funcs:
skipped.append(op)
continue
if func.has_composite_implicit_autograd_kernel and f"{op}_backward" not in torch_ops:
composite_implicit.add(op)
elif func.func.name.name.inplace:
for autogen in func.autogen:
if "functional" in autogen.overload_name:
ops.add(str(autogen))
else:
ops.add(op)
ops.append(op)
opnames = sorted(set(ops))
skipped = set(torch_ops) - ops - supported - composite_implicit
# Additional ops to support that are not supported by Torch-MLIR explicitly
supported_ops.extend(config.get("additional_ops", []))
supported |= set(config.get("additional_ops", []))
with out_file.open("w") as f:
yaml.dump(
{
"backend": "Lazy",
"cpp_namespace": "torch::lazy",
"full_codegen": opnames,
"supported": sorted(supported_ops),
"full_codegen": sorted(ops),
"supported": sorted(supported),
"non_native": non_native,
},
f,
@@ -117,10 +131,15 @@ def generate_native_functions(
dedent(
"""
# Composite implicit ops (supported by Torch-MLIR but not differentiable)
{composite_implicit}
# Skipped ops (supported by Torch-MLIR but no equivalent native function)
{skipped}
"""
).format(
composite_implicit=os.linesep.join(f"# - {op}" for op in sorted(composite_implicit)),
skipped=os.linesep.join(f"# - {op}" for op in sorted(skipped)),
)
+ os.linesep.join(f"# - {op}" for op in sorted(skipped))
)
return parsed_yaml, grouped_native_functions
@@ -129,11 +148,13 @@ def generate_native_functions(
@dataclass(frozen=True)
class GenMlirLazyIr(torchgen.dest.GenLazyIR):
def lowering_function(self, schema, declaration_only=True):
def lowering_function(self, schema):
signature = "TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override"
if declaration_only:
if schema.properties.LowerDeclOnly:
return f"{signature};"
elif not schema.properties.Lower:
return ""
emplace_arguments = []
for arg in schema.positional_args:
@@ -213,7 +234,7 @@ def generate_backend(
import re
sig_re = re.compile(
r"std::vector<Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)"
r"std::vector<torch::lazy::Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)"
)
global_signatures = {}
@@ -307,25 +328,30 @@ def main(args):
)
assert backend_path.is_dir()
torchgen_path = Path(torchgen.__path__[0]).resolve()
assert torchgen_path.is_dir()
prev_hash = None
hash_file = TORCH_MLIR_DIR.joinpath("generated_backend.hash")
if hash_file.exists():
prev_hash = hash_file.read_text().strip()
m = hashlib.sha256()
m.update(script_path.read_bytes())
m.update(config_path.read_bytes())
m.update(torch_ops_file.read_bytes())
if native_functions.exists():
m.update(native_functions.read_bytes())
shape_inference_headers = backend_path.joinpath("LazyShapeInference.h")
if shape_inference_headers.exists():
m.update(shape_inference_headers.read_bytes())
shape_inference_defs = backend_path.joinpath("LazyShapeInference.cpp")
if shape_inference_defs.exists():
m.update(shape_inference_defs.read_bytes())
# Add file contents to hash
for path in (
script_path,
config_path,
torch_ops_file,
native_functions,
backend_path.joinpath("LazyShapeInference.h"),
backend_path.joinpath("LazyShapeInference.cpp"),
torchgen_path.joinpath("dest", "lazy_ir.py"),
torchgen_path.joinpath("api", "lazy.py"),
torchgen_path.joinpath("model.py"),
):
if path.exists():
m.update(path.read_bytes())
new_hash = m.hexdigest().strip()


@@ -1,22 +1,11 @@
blacklist:
# List of unsupported ops in LTC autogen because of some error
- arange # Error: Code below assumes there is at least one tensor arg
- contiguous # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- empty_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- full # Error: Code below assumes there is at least one tensor arg
- index.Tensor # Error: TODO not sure if there are other valid types to handle here
- index_put # Error: TODO not sure if there are other valid types to handle here
- index_put_ # Error: TODO not sure if there are other valid types to handle here
- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here
- ones # Error: Code below assumes there is at least one tensor arg
- ones_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- resize_ # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- stack # Error: TODO not sure if there are other valid types to handle here
- to.dtype # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- to.other # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- uniform_ # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- zeros # Error: Code below assumes there is at least one tensor arg
- zeros_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
# Additional ops which autogen is supported for but don't compile yet
- detach
@@ -24,26 +13,34 @@ blacklist:
- size
- where
- copy_
- _to_copy
- log_softmax # Not inherently differentiable. Needs to be decomposed.
- linear # Not inherently differentiable. Needs to be decomposed.
# Disabled for consistency with TS backend
- rsub
# List of supported ops that we don't want to do the full codegen for
# primarily view ops
supported:
# - bernoulli
# - bernoulli_
- as_strided
- as_strided_
- _to_copy
- cat
- clone
- empty
- empty.memory_format
- empty_strided
- expand
- fill_
- native_batch_norm_backward
- fill_.Scalar
- permute
- select.int
- slice.Tensor
- squeeze
- squeeze.dim
- t
- transpose.int
- unsqueeze
- view
- _unsafe_view
additional_ops:
# Additional ops to support that are not supported by Torch-MLIR explicitly
@@ -53,35 +50,38 @@ additional_ops:
# List of non native ops that we only want to do IR node class generation for
non_native:
- func: device_data(std::shared_ptr<BackendData> data) -> Tensor
opkind: ltc_device_data
cache_shape: false
- func: scalar(at::Scalar value, at::ScalarType type) -> Tensor
- func: scalar(Scalar value, ScalarType type) -> Tensor
opkind: at::prim::Constant
cache_shape: false
- func: expand(Tensor input, std::vector<int64_t> size, bool is_scalar_expand) -> Tensor
- func: view(Tensor input, std::vector<int64_t> output_size) -> Tensor
cache_shape: false
- func: cast(Tensor input, at::ScalarType dtype, optional<at::ScalarType> stype) -> Tensor
properties:
- ShapeCompute
- TreatScalarsAsConstants
- func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
- func: view(Tensor input, int[] output_size) -> Tensor
properties:
- ShapeCompute
- func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor
opkind: ltc_cast
cache_shape: false
properties:
- ShapeCompute
# View ops only required until proper functionalization pass is introduced into LTC
- func: as_strided_view_update(Tensor target, Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
- func: as_strided_view_update(Tensor target, Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
opkind: ltc_as_strided_view_update
- func: as_strided(Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
- func: diagonal_view_update(Tensor target, Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
- func: as_strided(Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
- func: diagonal_view_update(Tensor target, Tensor input, int offset, int dim1, int dim2) -> Tensor
opkind: ltc_diagonal_view_update
cache_shape: false
- func: diagonal(Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
- func: narrow_view_update(Tensor input, Tensor source, std::vector<int64_t> base_indices) -> Tensor
properties:
- ShapeCompute
- func: diagonal(Tensor input, int offset, int dim1, int dim2) -> Tensor
- func: narrow_view_update(Tensor input, Tensor source, int[] base_indices) -> Tensor
opkind: ltc_narrow_view_update
- func: narrow(Tensor input, std::vector<int64_t> base_indices, std::vector<int64_t> sizes) -> Tensor
- func: permute(Tensor input, std::vector<int64_t> dims) -> Tensor
- func: resize(Tensor input, std::vector<int64_t> size) -> Tensor
- func: select_view_update(Tensor target, Tensor source, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
- func: narrow(Tensor input, int[] base_indices, int[] sizes) -> Tensor
- func: permute(Tensor input, int[] dims) -> Tensor
- func: resize(Tensor input, int[] size) -> Tensor
- func: select_view_update(Tensor target, Tensor source, int dim, int start, int end, int stride) -> Tensor
opkind: ltc_select_view_update
cache_shape: false
- func: select(Tensor input, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
properties:
- ShapeCompute
- func: select(Tensor input, int dim, int start, int end, int stride) -> Tensor
- func: squeeze(Tensor input, int dim) -> Tensor
- func: unsqueeze(Tensor input, int dim) -> Tensor


@@ -19,7 +19,7 @@ from datasets import load_dataset
from datasets.dataset_dict import DatasetDict
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, \
BertTokenizer, AdamW, get_scheduler
BertConfig, BertTokenizer, AdamW, get_scheduler
from typing import List
@@ -70,7 +70,7 @@ def train(model: BertForSequenceClassification,
return losses
def main(device, lower_only):
def main(device, lower_only, full_size):
if device in ("TS", "MLIR_EXAMPLE"):
import torch._lazy
@@ -95,8 +95,24 @@ def main(device, lower_only):
train_dataloader = DataLoader(small_train_dataset, shuffle=True,
batch_size=8)
model = BertForSequenceClassification.from_pretrained('bert-base-cased',
num_labels=2)
if full_size:
model = BertForSequenceClassification.from_pretrained('bert-base-cased',
num_labels=2)
else:
configuration = BertConfig(
vocab_size=28996,
hidden_size=32,
num_hidden_layers=1,
num_attention_heads=2,
intermediate_size=32,
hidden_act='gelu',
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
layer_norm_eps=1.0e-05,
)
model = BertForSequenceClassification(configuration)
model.to(device)
num_epochs = 3
@@ -115,6 +131,8 @@ def main(device, lower_only):
if __name__ == "__main__":
torch.manual_seed(0)
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
@@ -131,5 +149,12 @@ if __name__ == "__main__":
default=False,
help="Only get backend printout -- do not execute computation",
)
parser.add_argument(
"-f",
"--full_size",
action='store_true',
default=False,
help="Use full sized BERT model instead of one with smaller parameterization",
)
args = parser.parse_args()
main(args.device, args.lower_only)
main(args.device, args.lower_only, args.full_size)


@@ -73,6 +73,8 @@ def main(device):
if __name__ == "__main__":
torch.manual_seed(0)
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",

externals/pytorch

@@ -1 +0,0 @@
Subproject commit 9f3d6a00a76c567d7c046eabc60ae7a578f7bbde


@@ -18,6 +18,17 @@ include_directories(BEFORE
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")
# Generate Lazy IR Nodes
execute_process(
COMMAND ${Python3_EXECUTABLE} build_tools/autogen_ltc_backend.py -f
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
)
add_custom_target(
generate_ltc_sources
ALL
COMMENT "Generating Lazy Tensor Core IR Nodes"
COMMAND ${Python3_EXECUTABLE} build_tools/autogen_ltc_backend.py
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
add_library(torch_mlir_ltc_backend SHARED
base_lazy_backend/backend_impl.cpp
@@ -29,7 +40,10 @@ add_library(torch_mlir_ltc_backend SHARED
base_lazy_backend/mlir_native_functions.cpp
base_lazy_backend/mlir_node.cpp
base_lazy_backend/mlir_node_lowering.cpp
base_lazy_backend/ops/device_data.cpp
base_lazy_backend/ops/generic.cpp
)
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
add_dependencies(torch_mlir_ltc_backend
TorchMLIRJITIRImporter


@@ -7,271 +7,98 @@
//
//===----------------------------------------------------------------------===//
#include "LazyShapeInference.h"
#include "../utils/exception.h"
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <cmath>
#include "../utils/exception.h"
#include "LazyShapeInference.h"
namespace torch {
namespace lazy {
// TODO(henrytu): Upstream these shape inference functions to PyTorch in the future.
// Turns any negative index positive (assuming it's valid)
int64_t normalize_index(int64_t index, unsigned dims) {
return index < 0 ? (int64_t)dims + index : index;
std::vector<torch::lazy::Shape>
compute_shape_div(const at::Tensor& self, const at::Scalar & other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape>
compute_shape_dropout(const at::Tensor& input, double p, bool train) {
return {Shape(input.scalar_type(), input.sizes().vec())};
std::vector<torch::lazy::Shape>
compute_shape_mul(const at::Tensor& self, const at::Scalar& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape> compute_shape_layer_norm(
const at::Tensor& input, at::IntArrayRef normalized_shape,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias, double eps, bool cudnn_enable) {
return {Shape(input.scalar_type(), input.sizes().vec())};
}
std::vector<Shape>
compute_shape_matmul(const at::Tensor& self, const at::Tensor& other) {
std::vector<int64_t> sizes;
auto self_sizes = self.sizes().vec();
auto other_sizes = other.sizes().vec();
// For tensors with dimensions >2, the leading dimensions are for batch info.
// The last 2 (or 1 in the case of a single dim tensor) dimensions are the
// matrix dimensions themselves, which is checked to ensure the matmul op
// is legal.
//
// Example:
// [1, 2, 3, 4] -> [1, 2] batch dims and [3, 4] matrix
// [1, 4, 5] -> [1] batch dims and [4, 5] matrix
// [4, 5] -> [] batch dims and [4, 5] matrix
// [5] -> [] batch dims and [5] matrix
//
// We'll start by splitting the shapes as described above.
auto partition_shape = [](at::ArrayRef<int64_t> sizes) {
if (sizes.size() <= 2) {
return std::make_pair(
std::vector<int64_t>(),
std::vector<int64_t>(sizes.begin(), sizes.end()));
} else {
std::size_t partition_idx = sizes.size() - 2;
return std::make_pair(
std::vector<int64_t>(sizes.begin(), sizes.begin() + partition_idx),
std::vector<int64_t>(sizes.begin() + partition_idx, sizes.end()));
}
};
auto [self_batch_sizes, self_matrix_sizes] = partition_shape(self_sizes);
auto [other_batch_sizes, other_matrix_sizes] = partition_shape(other_sizes);
// Insert batch dimensions.
// The final list of sizes will be based on the tensor w/ more dims.
// Individual dimension sizes are "right justified" as we iterate thru
// to pick the larger dimension between them.
// 0 1 1 3 4
// 5 1 2
// ---------
// 0 1 5 3 4 <- Result
int64_t self_size, other_size;
std::size_t num_batch_dim =
std::max(self_batch_sizes.size(), other_batch_sizes.size());
auto get_batch_dim = [&](std::vector<int64_t> batch_sizes, std::size_t dim) {
long idx = dim - num_batch_dim + batch_sizes.size();
// Negative index means out of bounds, which defaults to a dim size of 1.
return idx < 0 ? 1 : batch_sizes[idx];
};
for (std::size_t i = 0; i < num_batch_dim; i++) {
self_size = get_batch_dim(self_batch_sizes, i);
other_size = get_batch_dim(other_batch_sizes, i);
TORCH_CHECK(
self_size == 1 || other_size == 1 || self_size == other_size,
"At trailing dimension ", i, ", expected for dimensions ",
"to either match or have one of them equal one, but got ", self_size,
" and ", other_size, " instead!");
sizes.push_back(std::max(self_size, other_size));
}
// Keep track of the inner dimensions of matmul to validate op is valid.
std::pair<int64_t, int64_t> inner_sizes;
if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 1) {
// Dot-Product -- scalar output, so no dimensions inserted
inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]);
} else if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 2) {
// Vector-Matrix product (m) @ (m, n) -> (n)
inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]);
sizes.push_back(other_matrix_sizes[1]);
} else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 1) {
// Matrix-Vector product (m, n) @ (n) -> (m)
inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]);
sizes.push_back(self_matrix_sizes[0]);
} else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 2) {
// Matrix-Matrix product (m, n) @ (n, o) -> (m, o)
inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]);
sizes.push_back(self_matrix_sizes[0]);
sizes.push_back(other_matrix_sizes[1]);
} else {
// By this time, self_matrix_sizes and other_matrix_sizes should have at
// most 2 dims, so if this is executed something has gone wrong...
TORCH_CHECK(false, "Invalid matmul shape combination!");
}
TORCH_CHECK(
inner_sizes.first == inner_sizes.second, "Inner dimension of matrix (",
inner_sizes.first, ") does not ", "match (", inner_sizes.second, ")!");
return {Shape(self.scalar_type(), sizes)};
}
std::vector<Shape> compute_shape_native_batch_norm(
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var, bool training,
double momentum, double eps) {
std::vector<Shape> shapes;
std::vector<torch::lazy::Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
// A separate mean and var needs to be kept for each channel.
TORCH_CHECK(
input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
int64_t num_features = input.sizes().vec()[1];
if (running_mean.has_value()) {
shapes.emplace_back(
running_mean.value().scalar_type(), running_mean.value().sizes().vec());
if (running_var.has_value()) {
shapes.emplace_back(
running_var.value().scalar_type(), running_var.value().sizes().vec());
}
} else {
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
}
if (running_var.has_value()) {
shapes.emplace_back(
running_var.value().scalar_type(), running_var.value().sizes().vec());
} else {
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
}
return shapes;
}
std::vector<Shape>
compute_shape_reshape(const at::Tensor& self, at::IntArrayRef shape) {
// Make a copy of the desired output shape.
std::vector<int64_t> sizes(shape.begin(), shape.end());
// Product of all sizes in input shape is the number of entries in tensor.
int64_t num_entries = 1;
for (int64_t i : self.sizes().vec()) {
num_entries *= i;
}
// Validate the number of entries in the desired shape. If there is a wildcard
// dimension, we need to find it now in order to populate it.
long wildcard_idx = -1;
int64_t num_concrete_entries = 1;
for (std::size_t idx = 0; idx < sizes.size(); idx++) {
if (sizes[idx] != -1) {
num_concrete_entries *= sizes[idx];
} else {
TORCH_CHECK(wildcard_idx == -1, "only one dimension can be inferred");
wildcard_idx = idx;
}
}
if (wildcard_idx == -1) {
// No wildcard, the shape should already be known.
TORCH_CHECK(
num_entries == num_concrete_entries, "shape `[", sizes,
"]` is invalid for input of size ", num_concrete_entries);
} else {
// There is one dimension which is not explicitly declared -- we need to
// infer.
TORCH_CHECK(
num_entries % num_concrete_entries == 0, "shape `[", sizes,
"]` is invalid for input of size ", num_concrete_entries);
sizes[wildcard_idx] = num_entries / num_concrete_entries;
}
return {Shape(self.scalar_type(), sizes)};
}
std::vector<Shape> compute_shape_rsub(
const at::Tensor& self, const at::Scalar& other, const at::Scalar& alpha) {
// Since other is scalar, the result will match tensor shape.
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape>
compute_shape_select(const at::Tensor& self, int64_t dim, int64_t index) {
auto original_shape = self.sizes().vec();
std::vector<int64_t> sizes(original_shape.begin(), original_shape.end());
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
const at::Tensor& grad_out, const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var,
const c10::optional<at::Tensor>& save_mean,
const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
::std::array<bool, 3> output_mask) {
std::vector<torch::lazy::Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
// A separate mean and var needs to be kept for each channel.
TORCH_CHECK(
dim < (int64_t)sizes.size(), "Dimension ", dim,
" is out of bounds for tensor with ", sizes.size(), " dimensions!");
TORCH_CHECK(
index < sizes[dim], "Index ", index,
" is out of bounds for dimension of size ", sizes[dim]);
sizes.erase(sizes.begin() + dim);
input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
int64_t num_features = input.sizes().vec()[1];
return {Shape(self.scalar_type(), sizes)};
// `weight` and `bias` are vectors of length C (number of channels)`
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
return shapes;
}
std::vector<Shape> compute_shape_slice(
const at::Tensor& self, int64_t dim, c10::optional<int64_t> start,
c10::optional<int64_t> end, int64_t step) {
auto original_shape = self.sizes().vec();
std::vector<int64_t> sizes(original_shape.begin(), original_shape.end());
int64_t dim_size = sizes[dim];
// Index may be negative, so we must normalize it.
int64_t start_norm = normalize_index(start.value(), dim_size);
int64_t end_norm = normalize_index(end.value(), dim_size);
if (start_norm >= end_norm || start_norm >= dim_size || end_norm <= 0) {
// Slice is out of bounds, nothing in range.
sizes[dim] = 0;
} else {
// Clamp upper and lower bound to valid indices.
start_norm = std::max((int64_t)0, start_norm);
end_norm = std::min(dim_size, end_norm);
// Final size is determined by step and interval size.
sizes[dim] = std::ceil((double)(end_norm - start_norm) / (double)step);
}
return {Shape(self.scalar_type(), sizes)};
}
std::vector<Shape> compute_shape_softmax(
const at::Tensor& self, int64_t dim, c10::optional<at::ScalarType> dtype) {
std::vector<torch::lazy::Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
if (dtype.has_value()) {
return {Shape(dtype.value(), self.sizes().vec())};
return {Shape(*dtype, size)};
}
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape>
compute_shape_transpose(const at::Tensor& self, int64_t dim0, int64_t dim1) {
auto original_shape = self.sizes().vec();
std::vector<int64_t> sizes{original_shape.begin(), original_shape.end()};
// Index may be negative, so we must normalize it. We create new variables
// instead of replacing the existing ones so that in the case of an error,
// the original values can be printed out.
int64_t dim0_norm = normalize_index(dim0, sizes.size());
int64_t dim1_norm = normalize_index(dim1, sizes.size());
// Verify dimensions are valid.
TORCH_CHECK(
0 <= dim0_norm && dim0_norm < (int64_t)sizes.size(), "dim0 has value ",
dim0, ", but there are only ", sizes.size(), " tensor dimensions");
TORCH_CHECK(
0 <= dim1_norm && dim1_norm < (int64_t)sizes.size(), "dim1 has value ",
dim1, ", but there are only ", sizes.size(), " tensor dimensions");
// Swap shapes at dimensions.
std::swap(sizes[dim0_norm], sizes[dim1_norm]);
return {Shape(self.scalar_type(), sizes)};
return {Shape(self.scalar_type(), size)};
}
} // namespace lazy


@@ -22,74 +22,48 @@ namespace lazy {
// clang-format off
TORCH_API std::vector<Shape> compute_shape___and__(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape__reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride);
TORCH_API std::vector<Shape> compute_shape__shape_as_tensor(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape__unsafe_view(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_abs(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size);
TORCH_API std::vector<Shape> compute_shape_add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_add_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps, bool cudnn_enabled);
TORCH_API std::vector<Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<Shape> compute_shape_bernoulli_(at::Tensor & self, const at::Tensor & p, c10::optional<at::Generator> generator);
TORCH_API std::vector<Shape> compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional<at::Generator> generator);
TORCH_API std::vector<Shape> compute_shape_bincount(const at::Tensor & self, const c10::optional<at::Tensor> & weights, int64_t minlength);
TORCH_API std::vector<Shape> compute_shape_broadcast_to(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
TORCH_API std::vector<Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_div_(at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_dropout(const at::Tensor & input, double p, bool train);
TORCH_API std::vector<Shape> compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse);
TORCH_API std::vector<Shape> compute_shape_expand_as(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_flatten(const at::Tensor & self, int64_t start_dim, int64_t end_dim);
TORCH_API std::vector<Shape> compute_shape_floor_divide(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_fmod(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_full_like(const at::Tensor & self, const at::Scalar & fill_value, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<Shape> compute_shape_hardswish(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_hardtanh(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val);
TORCH_API std::vector<Shape> compute_shape_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index);
TORCH_API std::vector<Shape> compute_shape_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps, bool cudnn_enable);
TORCH_API std::vector<Shape> compute_shape_log_softmax(const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim);
TORCH_API std::vector<Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_masked_select(const at::Tensor & self, const at::Tensor & mask);
TORCH_API std::vector<Shape> compute_shape_matmul(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_max(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode);
TORCH_API std::vector<Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_mul(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_mul_(at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
TORCH_API std::vector<Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
TORCH_API std::vector<Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_new_ones(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_new_zeros(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_rand_like(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<Shape> compute_shape_relu(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_relu_(at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
TORCH_API std::vector<Shape> compute_shape_reshape(const at::Tensor & self, at::IntArrayRef shape);
TORCH_API std::vector<Shape> compute_shape_rsub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_select(const at::Tensor & self, int64_t dim, int64_t index);
TORCH_API std::vector<Shape> compute_shape_slice(const at::Tensor & self, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step);
TORCH_API std::vector<Shape> compute_shape_softmax(const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_square(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_std(const at::Tensor & self, bool unbiased);
TORCH_API std::vector<Shape> compute_shape_sub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_sub_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<Shape> compute_shape_sum(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<Shape> compute_shape_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1);
TORCH_API std::vector<Shape> compute_shape_type_as(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_var(const at::Tensor & self, bool unbiased);
TORCH_API std::vector<Shape> compute_shape_zero_(at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape___and__(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<torch::lazy::Shape> compute_shape__reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_abs(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli_functional(const at::Tensor & self, const at::Tensor & p, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bincount(const at::Tensor & self, const c10::optional<at::Tensor> & weights, int64_t minlength);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_embedding_dense_backward(const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_flip(const at::Tensor & self, at::IntArrayRef dims);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_fmod(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardswish(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardtanh(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_logical_or(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_select(const at::Tensor & self, const at::Tensor & mask);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_max(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_mul(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout(const at::Tensor & input, double p, c10::optional<bool> train);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_relu(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_resize_functional(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_sub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_sum(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_uniform_functional(const at::Tensor & self, double from, double to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_functional(const at::Tensor & self);
// clang-format on


@@ -20,6 +20,7 @@
#include "backend_impl.h"
#include "ir_builder.h"
#include "mlir_lowering_context.h"
#include "ops/device_data.h"
namespace torch {
namespace lazy {
@@ -112,7 +113,7 @@ TorchMlirBackendImpl::GetComputationDataFromNode(Node* node) const {
if (!device_data_node) {
return nullptr;
}
return device_data_node->data;
return device_data_node->data();
}
at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(


@@ -20,6 +20,7 @@
#include "dynamic_ir.h"
#include "generated/LazyNonNativeIr.h"
#include "mlir_node.h"
#include "ops/device_data.h"
#include "ops/generic.h"
// This file contains the TorchMlir IrBuilder
@@ -35,7 +36,7 @@ struct TorchMlirIrBuilder : IrBuilder {
NodePtr MakeExpand(const Value& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand) const override { return MakeNode<Expand>(input0, size, is_scalar_expand); }
NodePtr MakeView(const Value& input0, const std::vector<int64_t>& output_size) const override { return MakeNode<View>(input0, output_size); }
NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype = c10::nullopt) const override { return MakeNode<Cast>(input0, dtype, stype); }
NodePtr MakeTensorList(const OpList& inputs) const override { return MakeNode<TensorList>(inputs); }
NodePtr MakeTensorList(const OpList& inputs) const override { return MakeNode<TorchMlirTensorList>(inputs); }
NodePtr MakeGeneric(const OpKind& op, const OpList& operands, const Shape& shape, const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) const override { return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed); }
// view ops


@@ -331,8 +331,12 @@ const std::string TorchMlirComputation::to_string() const {
ss << "Input/Output Alias Mapping: \n";
for (InputOutputAlias input_output_alias : input_output_aliases_) {
ss << "Output: " << input_output_alias.output_index
<< " -> Input param: " << input_output_alias.param_number << std::endl;
<< " -> Input param: " << input_output_alias.param_number << "\n";
}
ss << "\n";
// Mark Step
ss << "In Mark Step: " << (in_mark_step ? "true" : "false") << "\n";
return ss.str();
}


@@ -31,6 +31,7 @@
#include "../utils/sys_utils.h"
#include "LazyShapeInference.h"
#include "generated/LazyNativeFunctions.h"
#include "ops/to_copy.h"
namespace torch {
namespace lazy {
@@ -166,6 +167,81 @@ torch::lazy::LazyTensorPtr create_view(
return input->CreateViewTensor(std::move(view_info));
}
torch::lazy::ViewInfo CreateAsStridedViewInfo(
const torch::lazy::Shape& input_shape, std::vector<int64_t> size,
std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
torch::lazy::Shape result_shape =
torch::lazy::Shape(input_shape.scalar_type(), size);
torch::lazy::AsStridedInfo as_strided_info;
as_strided_info.stride = std::move(stride);
if (storage_offset) {
as_strided_info.offset = *storage_offset;
}
return torch::lazy::ViewInfo(
torch::lazy::ViewInfo::Type::kAsStrided, std::move(result_shape),
input_shape, std::move(as_strided_info));
}
torch::lazy::LazyTensorPtr lazy_narrow(
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
int64_t length) {
auto input_shape = input->shape();
dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
torch::lazy::Shape narrow_shape = input_shape;
narrow_shape.set_size(dim, length);
torch::lazy::ViewInfo::Type view_type =
(input_shape.Get().numel() == narrow_shape.numel())
? torch::lazy::ViewInfo::Type::kReshape
: torch::lazy::ViewInfo::Type::kNarrow;
torch::lazy::ViewInfo view_info(
view_type, std::move(narrow_shape), input_shape);
view_info.indices[dim] =
torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start);
return input->CreateViewTensor(std::move(view_info));
}
torch::lazy::LazyTensorPtr lazy_view(
const torch::lazy::LazyTensorPtr& input,
c10::ArrayRef<int64_t> output_size) {
auto input_shape = input->shape().Get();
torch::lazy::Shape shape = torch::lazy::Shape(
input_shape.scalar_type(),
at::infer_size(output_size, input_shape.numel()));
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape);
return input->CreateViewTensor(std::move(view_info));
}
torch::lazy::LazyTensorPtr lazy_select(
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t index) {
auto shape = input->shape();
dim = torch::lazy::GetCanonicalDimensionIndex(dim, shape.Get().dim());
torch::lazy::LazyTensorPtr result = lazy_narrow(input, dim, index, 1);
auto new_dims = torch::lazy::DropDimensions(shape.Get().sizes(), {dim});
return lazy_view(result, new_dims);
}
torch::lazy::LazyTensorPtr lazy_slice(
const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
int64_t end, int64_t step) {
auto input_shape = input->shape();
dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
start =
torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start);
end = torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, end);
// PyTorch allows tensor[-1:0] to return a 0-dim tensor.
if (start > end) {
end = start;
}
step = std::min(step, end - start);
torch::lazy::SelectInfo select = {dim, start, end, step};
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kSelect, input_shape, select);
return input->CreateViewTensor(std::move(view_info));
}
} // namespace
// at::Tensor LazyNativeFunctions::bernoulli(
@@ -194,6 +270,44 @@ torch::lazy::LazyTensorPtr create_view(
// // return self;
// }
at::Tensor LazyNativeFunctions::as_strided(
const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
c10::optional<int64_t> storage_offset) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
auto xsize = torch::lazy::ToI64Vector(size);
auto xstride = torch::lazy::ToI64Vector(stride);
if (!torch::lazy::StrideIsSupported(xstride)) {
UNIMPLEMENTED_FUNCTION_ERROR();
}
return torch::lazy::CreateAtenFromLtcTensor(
self_tensor->CreateViewTensor(CreateAsStridedViewInfo(
self_tensor->shape(), std::move(xsize), std::move(xstride),
storage_offset)));
}
const at::Tensor& LazyNativeFunctions::as_strided_(
const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
c10::optional<int64_t> storage_offset) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
auto xsize = torch::lazy::ToI64Vector(size);
auto xstride = torch::lazy::ToI64Vector(stride);
if (!torch::lazy::StrideIsSupported(xstride)) {
UNIMPLEMENTED_FUNCTION_ERROR();
}
if (self_tensor->data()->view == nullptr) {
self_tensor->SetIrValue(torch::lazy::MakeAsStrided(
self_tensor->GetIrValue(), std::move(xsize), std::move(xstride),
storage_offset.value_or(0)));
} else {
auto input_shape = self_tensor->shape();
self_tensor->SetSubView(CreateAsStridedViewInfo(
input_shape, std::move(xsize), std::move(xstride), storage_offset));
}
return self;
}
at::Tensor LazyNativeFunctions::cat(at::TensorList tensors, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto lazy_tensors = torch::lazy::GetLtcTensors(tensors);
@@ -298,6 +412,99 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize(
return dst;
}
at::Tensor LazyNativeFunctions::_to_copy(
const at::Tensor& self, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory, bool non_blocking,
c10::optional<at::MemoryFormat> memory_format) {
PRINT_FUNCTION();
auto options = self.options();
if (dtype) {
// I put each of these setters in a conditional instead of doing `self.options().dtype(dtype).layout(layout)...
// because calling .dtype(nullopt) on an options() that already has dtype appears to wipe it
options = options.dtype(dtype);
}
if (layout) {
options = options.layout(layout);
}
if (memory_format) {
options = options.memory_format(memory_format);
}
if (pin_memory) {
// TODO(whc) can we honor 'pin_memory' in some/all cases?
options = options.pinned_memory(pin_memory);
TORCH_WARN_ONCE("Pinned memory used in lazy _to_copy, check if the "
"behavior is as intended");
}
TORCH_LAZY_FN_COUNTER("lazy::");
auto lazy_self = torch::lazy::TryGetLtcTensor(self);
if (!lazy_self && device && device->type() == c10::kLazy) {
// Case 1: eager->lazy (we create a new lazy tensor)
auto eager_tensor =
self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
lazy_self = torch::lazy::GetOrCreateLtcTensor(
eager_tensor, torch::lazy::atenDeviceToBackendDevice(*device));
return torch::lazy::CreateAtenFromLtcTensor(lazy_self);
} else if (device && device->type() != c10::kLazy) {
// Case 2: lazy->eager (forces a graph break since we are materializing a tensor)
TORCH_INTERNAL_ASSERT(lazy_self);
auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
options = options.device(device);
auto moved_eager_tensor =
eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
return moved_eager_tensor;
} else if (
device && device->type() == c10::kLazy && device->has_index() &&
device->index() != self.device().index()) {
// Case 3: lazy:0 -> lazy:1
// TODO(whc) what do we actually want to do here?
// option 1: materialize, move eager tensor, create new lazy tensor
// - this should be our default, as it is what would happen before we implemented _to_copy
// - actually combines case 1 + case 2
// option 2: support multiple devices inside one lazy/TS executor (case 4)
// - but: we may have other assumptions that there is just one device per executor? so don't take this lightly
TORCH_INTERNAL_ASSERT(lazy_self);
auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
// we move the eager tensor to the 'eager' equivalent of our lazy device
// e.g. if our device is lazy:1, the backend maps that to cuda:1, which is what we use
auto eager_device = c10::Device(
torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index());
options = options.device(eager_device);
auto moved_eager_tensor =
eager_tensor.to(options, /*non_blocking=*/false, /*copy=*/true);
lazy_self = torch::lazy::GetOrCreateLtcTensor(
moved_eager_tensor,
torch::lazy::atenDeviceToBackendDevice(eager_device));
return torch::lazy::CreateAtenFromLtcTensor(lazy_self);
} else {
// Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy graph)
// Note: captured _to_copy will be executed with real eager tensors, not lazy tensors.
// We DO NOT want to burn 'lazy:0' as the device into this captured IR, or we will try to
// convert an eager tensor back to a lazy one inside the torchscript executor
// lazy:0 -> lazy:1 is handled in case3, so we can safely drop the device argument
device = c10::nullopt;
auto shapes = torch::lazy::compute_shape__to_copy(
self, dtype, layout, device, pin_memory, non_blocking, memory_format);
TORCH_INTERNAL_ASSERT(shapes.size() == 1);
auto node = torch::lazy::MakeNode<ToCopy>(
lazy_self->GetIrValue(), dtype, layout, device, pin_memory,
non_blocking, memory_format, std::move(shapes));
auto result =
torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(
std::move(node), lazy_self->GetDevice()));
return result;
}
};
at::Tensor LazyNativeFunctions::empty(
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
@@ -313,6 +520,15 @@ at::Tensor LazyNativeFunctions::empty(
return CreateLtcTensor(x_result, GetLtcDevice(device));
}
at::Tensor LazyNativeFunctions::empty_strided(
at::IntArrayRef size, at::IntArrayRef stride,
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER("lazy::");
at::Tensor t = empty(size, dtype, layout, device, pin_memory, c10::nullopt);
return LazyNativeFunctions::as_strided(t, size, stride, /*storage_offset=*/0);
}
at::Tensor LazyNativeFunctions::expand(
const at::Tensor& self, at::IntArrayRef size, bool implicit) {
TORCH_LAZY_FN_COUNTER("lazy::");
@@ -355,6 +571,23 @@ LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) {
self_tensor->CreateViewTensor(std::move(view_info)));
}
at::Tensor LazyNativeFunctions::select(
const at::Tensor& self, int64_t dim, int64_t index) {
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(
lazy_select(torch::lazy::TryGetLtcTensor(self), dim, index));
}
at::Tensor LazyNativeFunctions::slice(
const at::Tensor& self, int64_t dim, c10::optional<int64_t> start,
c10::optional<int64_t> end, int64_t step) {
int64_t start_val = start.has_value() ? start.value() : 0;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
TORCH_LAZY_FN_COUNTER("lazy::");
return torch::lazy::CreateAtenFromLtcTensor(lazy_slice(
torch::lazy::TryGetLtcTensor(self), dim, start_val, end_val, step));
}
at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) {
return squeeze(self, -1);
}
@@ -390,6 +623,21 @@ at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
self_tensor->CreateViewTensor(std::move(view_info)));
}
at::Tensor LazyNativeFunctions::transpose(
const at::Tensor& self, int64_t dim0, int64_t dim1) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
auto input_shape = self_tensor->shape();
auto permute_dims = torch::lazy::MakeTransposePermutation(
/*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim());
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims);
return torch::lazy::CreateAtenFromLtcTensor(
self_tensor->CreateViewTensor(std::move(view_info)));
}
at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
@@ -418,6 +666,21 @@ LazyNativeFunctions::view(const at::Tensor& self, at::IntArrayRef size) {
self_tensor->CreateViewTensor(std::move(view_info)));
}
at::Tensor LazyNativeFunctions::_unsafe_view(
const at::Tensor& self, at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::");
torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
auto input_shape = self_tensor->shape().Get();
torch::lazy::Shape shape = torch::lazy::Shape(
input_shape.scalar_type(),
at::infer_size(torch::lazy::ToI64Vector(size), input_shape.numel()));
torch::lazy::ViewInfo view_info(
torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape);
return torch::lazy::CreateAtenFromLtcTensor(
self_tensor->CreateViewTensor(std::move(view_info)));
}
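A quick worked check of the reshape bookkeeping above, using made-up sizes:
// at::infer_size resolves at most one -1 entry against the input's element count.
// For a {6, 4} input (numel == 24) viewed as {-1, 8}:
//   at::infer_size({-1, 8}, 24) -> {3, 8}
// so the ViewInfo above records a kReshape from the original {6, 4} shape to {3, 8}.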
void InitializeAtenBindings() {}
} // namespace lazy

@@ -76,15 +76,23 @@ TorchMlirOpVector TorchMlirNode::Lower(
return {};
}
TensorList::TensorList(OpList values)
OpKind TorchMlirTensorList::ClassOpKind() {
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
// import otherwise
static const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
return tensor_list_opkind;
}
TorchMlirTensorList::TorchMlirTensorList(OpList values)
: TorchMlirNode(
/*op=*/tensor_list_opkind,
/*op=*/TorchMlirTensorList::ClassOpKind(),
/*operands=*/values,
/*shapes=*/std::vector<Shape>(),
/*num_outputs=*/1,
/*hash_seed=*/kHashSeed) {}
torch::lazy::TorchMlirOpVector TensorList::Lower(
torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::Value*> tensor_list;
CHECK(!operands().empty());

@@ -42,6 +42,8 @@ public:
TorchMlirNode(
OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed);
~TorchMlirNode() override = default;
hash_t hash() const override;
hash_t shapeHash() const override;
@@ -58,9 +60,6 @@ private:
hash_t dag_hash_;
};
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
// import otherwise
const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
// TensorList represents an at::TensorList which is a vector[Tensor] but is also
// a first-class IValue and can be fed as a single input to a TS program. It is
@@ -77,9 +76,11 @@ const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
// TODO(whc) once Shape() API is moved to Node base, also make it virtual, and
// then implement it as NotImplemented for TensorList, also fixing the assertion
// that would fail.
struct TORCH_API TensorList : public TorchMlirNode {
TensorList() = delete;
TensorList(OpList values);
struct TORCH_API TorchMlirTensorList : public TorchMlirNode {
static OpKind ClassOpKind();
TorchMlirTensorList() = delete;
TorchMlirTensorList(OpList values);
torch::lazy::TorchMlirOpVector Lower(
TorchMlirFunction function,

@@ -14,6 +14,7 @@
#include "generated/LazyNonNativeIr.h"
#include "mlir_lowering_context.h"
#include "mlir_node.h"
#include "ops/device_data.h"
#include <ATen/Functions.h>
#include <c10/core/ScalarType.h>
@@ -61,8 +62,8 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
for (jit::Value* value : results) {
if (value->type()->kind() == c10::TypeKind::TensorType) {
TORCH_CHECK(
tensor_type_idx < tensor_types.size(),
"Tensor corresponding to JIT SSA value %", value->debugName(),
tensor_type_idx < tensor_types.size(), function->graph()->toString(),
"\nTensor corresponding to JIT SSA value %", value->debugName(),
" corresponds to result #", tensor_type_idx, ", but we only have ",
tensor_types.size(), " known types!");
@@ -100,6 +101,79 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
function, sym, tensor_types, arguments, kwarguments);
}
c10::TensorType& cast_tensor_type(c10::TypePtr value_type) {
auto tensor_type = value_type->cast<c10::TensorType>();
TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!");
return *tensor_type.get();
}
c10::optional<std::vector<int64_t>>
get_tensor_type_shape(c10::TensorType& tensor_type) {
auto& symbolic_shape = tensor_type.symbolic_sizes();
if (!symbolic_shape.rank()) {
return c10::nullopt;
}
// Get current tensor shape.
std::vector<int64_t> dims;
dims.resize(*symbolic_shape.rank());
for (size_t i = 0; i < dims.size(); ++i) {
auto shape_symbol = symbolic_shape[i];
dims[i] = shape_symbol.is_static() ? shape_symbol.static_size() : -1;
}
return dims;
}
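For example (shapes invented for illustration): a JIT TensorType with symbolic sizes [N, 768], where N is unknown at trace time, yields dims = {-1, 768}, while an unranked TensorType yields c10::nullopt.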
std::vector<torch::lazy::Shape> compute_shape_copy(c10::TypePtr value_type) {
c10::TensorType& tensor_type = cast_tensor_type(value_type);
auto maybe_dims = get_tensor_type_shape(tensor_type);
TORCH_CHECK(maybe_dims.has_value(), "Cannot copy unranked tensor!");
auto scalar_type = tensor_type.scalarType();
TORCH_CHECK(
scalar_type.has_value(), "Unable to copy due to lack of scalar type!");
return {Shape(scalar_type.value(), maybe_dims.value())};
}
std::vector<torch::lazy::Shape> compute_shape_slice(
c10::TypePtr value_type, int64_t dim, int64_t start, int64_t end,
int64_t step) {
c10::TensorType& tensor_type = cast_tensor_type(value_type);
auto maybe_dims = get_tensor_type_shape(tensor_type);
TORCH_CHECK(maybe_dims.has_value(), "Cannot slice unranked tensor!");
std::vector<int64_t> dims = maybe_dims.value();
int64_t num_dims = dims[dim];
// Index may be negative, so we must normalize it.
auto normalize_index = [](int64_t index, unsigned num_dims) {
return index < 0 ? (int64_t)num_dims + index : index;
};
start = normalize_index(start, num_dims);
end = normalize_index(end, num_dims);
if (start >= end || start >= num_dims || end <= 0) {
// Slice is out of bounds, nothing in range.
dims[dim] = 0;
} else {
// Clamp upper and lower bound to valid indices.
start = std::max((int64_t)0, start);
end = std::min(num_dims, end);
// Final size is determined by step and interval size.
dims[dim] = std::ceil((double)(end - start) / (double)step);
}
auto scalar_type = tensor_type.scalarType();
TORCH_CHECK(
scalar_type.has_value(), "Unable to slice due to lack of scalar type!");
return {Shape(scalar_type.value(), dims)};
}
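A worked example of the slice arithmetic above, with hypothetical values:
// Given a Float tensor typed as {4, 10}, sliced along dim=1 with
// start=-3, end=9, step=2:
//   start   = normalize_index(-3, 10) = 7
//   end     = normalize_index( 9, 10) = 9
//   dims[1] = ceil((9 - 7) / 2.0)     = 1
// so the inferred output shape is {4, 1}.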
class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
public:
TorchMlirNodeLowering(
@@ -213,14 +287,14 @@ public:
const torch::lazy::DeviceData* device_data_node =
torch::lazy::NodeCast<torch::lazy::DeviceData>(
node, *torch::lazy::ltc_device_data);
auto infoptr = device_data_node->data->info();
auto infoptr = device_data_node->data()->info();
auto deviceDataInfoPtr =
(torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
if (GRAPH_DUMP_ENABLED) {
LOG(ERROR) << "Lowering device data node, tensor id "
<< deviceDataInfoPtr->tensor_id << std::endl;
}
return {loctx()->GetParameter(device_data_node->data)};
return {loctx()->GetParameter(device_data_node->data())};
}
std::vector<torch::jit::NamedValue> arguments;
@@ -421,8 +495,8 @@ public:
arguments.emplace_back(destination);
arguments.emplace_back(source);
LowerBuiltin(
at::aten::copy_, c10::ArrayRef<Shape>({/*shape goes here*/}),
arguments);
at::aten::copy_,
c10::ArrayRef<Shape>(compute_shape_copy(source->type())), arguments);
}
torch::jit::Value* GenerateSlice(
@@ -434,8 +508,11 @@ public:
arguments.emplace_back(start);
arguments.emplace_back(end);
arguments.emplace_back(step);
TorchMlirOpVector selected = LowerBuiltin(
at::aten::slice, c10::ArrayRef<Shape>({/*shape goes here*/}),
at::aten::slice,
c10::ArrayRef<Shape>(
compute_shape_slice(base->type(), dim, start, end, step)),
arguments);
CHECK_EQ(selected.size(), 1);
return selected.front();

@@ -0,0 +1,41 @@
#include <sstream>
#include <torch/csrc/lazy/core/ir_builder.h>
#include "device_data.h"
namespace torch {
namespace lazy {
DeviceData::DeviceData(std::shared_ptr<BackendData> data)
: TorchMlirNode(
ClassOpKind(),
data->shape(),
/*num_outputs=*/1,
/*hash_seed=*/static_cast<uint32_t>(101)),
data_(std::move(data)) {}
std::string DeviceData::ToString() const {
std::stringstream ss;
ss << TorchMlirNode::ToString() << ", device=" << data_->device();
return ss.str();
}
const DeviceData* DeviceData::Cast(const Node* node) {
return NodeCast<DeviceData>(node);
}
NodePtr DeviceData::Create(std::shared_ptr<BackendData> data) {
NodePtr node = ReuseOrMakeNode<DeviceData>(data);
// ReuseOrMakeNode may return a reused node with the same shape; however,
// we need to replace the old data_ with the new one.
// Ditching the old data_ is safe because tracing is done iteration by
// iteration, and once we launch the async device execution for the previous
// iteration, the data_ in DeviceData nodes is no longer needed.
DeviceData* device_data = static_cast<DeviceData*>(node.get());
device_data->SetData(data);
return node;
}
} // namespace lazy
} // namespace torch
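A minimal sketch of the intended call pattern (MakeDeviceDataNode is a hypothetical helper; how a real BackendData handle is obtained depends on the active backend):
#include "ops/device_data.h"

torch::lazy::NodePtr MakeDeviceDataNode(
    std::shared_ptr<torch::lazy::BackendData> backend_data) {
  // Prefer Create() over the constructor so that trace-time node reuse can kick in.
  torch::lazy::NodePtr node = torch::lazy::DeviceData::Create(backend_data);
  const torch::lazy::DeviceData* dd = torch::lazy::DeviceData::Cast(node.get());
  // Even if `node` was reused from an earlier trace, dd->data() now points at
  // backend_data, because Create() swaps it in via SetData().
  CHECK(dd->data() == backend_data);
  return node;
}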

@@ -0,0 +1,48 @@
#pragma once
#include "../mlir_node.h"
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
namespace torch {
namespace lazy {
class TORCH_API DeviceData : public TorchMlirNode {
public:
static OpKind ClassOpKind() {
return ltc_device_data;
}
explicit DeviceData(std::shared_ptr<BackendData> data);
// A DeviceData node can be reused if the shape matches,
// but we will substitute the actual data_ pointer under
// the hood.
bool CanBeReused(std::shared_ptr<BackendData> data) const {
return data_->shape() == data->shape();
}
std::string ToString() const override;
const std::shared_ptr<BackendData>& data() const {
return data_;
}
void SetData(std::shared_ptr<BackendData> data) {
data_ = data;
}
static const DeviceData* Cast(const Node* node);
// To reuse IR nodes, use this method to create DeviceData nodes
// instead of calling the constructor directly.
static NodePtr Create(std::shared_ptr<BackendData> data);
private:
std::shared_ptr<BackendData> data_;
};
} // namespace lazy
} // namespace torch

@@ -0,0 +1,101 @@
//===- to_copy.h ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// this file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ops/to_copy.h
//===----------------------------------------------------------------------===//
#pragma once
#include "../mlir_node.h"
namespace torch {
namespace lazy {
// This IR was copied from code-generated output, but the entire _to_copy operator
// cannot be trivially code-generated since it is only desirable to capture IR for
// certain permutations of _to_copy (e.g. dtype); for the others it is difficult to even
// invoke the aten/eager fallback, necessitating a direct implementation of the right
// to(device) behavior.
class ToCopy : public torch::lazy::TorchMlirNode {
public:
ToCopy(const torch::lazy::Value& self, const c10::optional<at::ScalarType>& dtype, const c10::optional<at::Layout>& layout, const c10::optional<at::Device>& device, const c10::optional<bool>& pin_memory, const bool& non_blocking, const c10::optional<at::MemoryFormat>& memory_format, std::vector<torch::lazy::Shape>&& shapes)
: torch::lazy::TorchMlirNode(torch::lazy::OpKind(at::aten::_to_copy),
{self}, std::move(shapes),
/* num_outputs */ 1,
torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, memory_format)),
dtype(dtype),
layout(layout),
device(device),
pin_memory(pin_memory),
non_blocking(non_blocking),
memory_format(memory_format) {}
std::string ToString() const override {
std::stringstream ss;
ss << torch::lazy::TorchMlirNode::ToString();
if (dtype.has_value()) {
ss << ", dtype=" << dtype.value();
} else {
ss << ", dtype=null";
}
if (layout.has_value()) {
ss << ", layout=" << layout.value();
} else {
ss << ", layout=null";
}
if (device.has_value()) {
ss << ", device=" << device.value();
} else {
ss << ", device=null";
}
if (pin_memory.has_value()) {
ss << ", pin_memory=" << pin_memory.value();
} else {
ss << ", pin_memory=null";
}
ss << ", non_blocking=" << non_blocking;
if (memory_format.has_value()) {
ss << ", memory_format=" << memory_format.value();
} else {
ss << ", memory_format=null";
}
return ss.str();
}
torch::lazy::TorchMlirOpVector Lower(TorchMlirFunction function,
torch::lazy::TorchMlirLoweringContext* loctx) const override {
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve(1);
kwarguments.reserve(6);
size_t i = 0;
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
kwarguments.emplace_back("dtype", dtype);
kwarguments.emplace_back("layout", layout);
kwarguments.emplace_back("device", device);
kwarguments.emplace_back("pin_memory", pin_memory);
kwarguments.emplace_back("non_blocking", non_blocking);
kwarguments.emplace_back("memory_format", memory_format);
torch::lazy::TorchMlirOpVector _to_copy_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
CHECK_EQ(_to_copy_out.size(), 1);
return _to_copy_out;
}
c10::optional<at::ScalarType> dtype;
c10::optional<at::Layout> layout;
c10::optional<at::Device> device;
c10::optional<bool> pin_memory;
bool non_blocking;
c10::optional<at::MemoryFormat> memory_format;
};
} // namespace lazy
} // namespace torch

@@ -133,7 +133,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
auto containedTypes = c10::fmap(
node->output()->type()->cast<c10::TupleType>()->containedTypes(),
[&](const c10::TypePtr &t) {
MlirType type = getMlirTypeFromTorchType(loc, t);
MlirType type = getMlirTypeFromTorchType(loc, t, importOptions);
if (mlirTypeIsNull(type)) {
throw mlir_diagnostic_emitted();
}