mirror of https://github.com/llvm/torch-mlir
E2E HuggingFace Bert using LTC Backend (#912)
* Update native function definitions
* Add ops to support BERT lowering
  - Add empty_strided and as_strided
  - Restore zeros_like to the op blacklist (without this, tensors are unintentionally created with a CPU device rather than a lazy one)
  - Check for composite implicit ops and add device data IR
  - Also fix codegen for functionalization
* Add autogen to CMakeLists
* Remove PyTorch submodule
* Reduce BERT model size
* Print Mark Step status in the Torch-MLIR LTC debug string
* Apply fixes to work with latest upstream/main
  - Pass importOptions into getMlirTypeFromTorchType during NodeImporter::importNode; without this, the created tensor type may be mismatched, since ImportOptions may cause vtensor to be used instead of tensor
* Update shape inference functions
  - Fix compute_shape_native_batch_norm when mean and var are uninitialized; previously, fewer than three shapes were returned if either mean or var did not exist. They are now initialized with a vector matching the number of channels.
  - Implement compute_shape_mul
  - Fix a bug in the reshape shape inference error message
* Get the MLIR backend more consistent with the TS backend
  - Remove LazyNativeFunctions::_unsafe_view from autogen
  - Blacklist ops to make the JIT graph more like the output of the TS backend
  - Print the graph when an SSA value has a mismatch of types and results
  - Remove normalize_index from LazyShapeInference
  - Fix seeds for the LTC example models
* Update and clean up shape inference functions
  - Prune shape inference functions
  - Add shape inference function for GenerateSlice
  - Add shape inference function for GenerateCopy

Co-authored-by: Henry Tu <henry.tu@cerebras.net>
parent 0c35e607b3
commit d9aee0d7a7
@ -11,6 +11,7 @@ libtorch*
/build/
__pycache__
*.pyc
.pytype
@ -1,9 +1,11 @@
import argparse
import hashlib
import importlib
import os
import subprocess
import sys
import warnings
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from shutil import which

@ -11,18 +13,16 @@ from textwrap import dedent
import yaml

TORCH_MLIR_DIR = Path(__file__).parent.parent.resolve()
TORCH_DIR = TORCH_MLIR_DIR.joinpath("externals", "pytorch")

sys.path.append(str(TORCH_DIR))

# PyTorch's LTC backend autogen script
import torchgen
import torchgen.dest.lazy_ir
import torchgen.gen_lazy_tensor
from torchgen.api.lazy import LazyIrSchema
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
from torchgen.model import NativeFunctionsGroup

TORCH_DIR = Path(importlib.util.find_spec('torch').origin).resolve().parent.parent
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent


def isOptionalCType(arg):
    return str(type(arg)) == "<class 'torchgen.api.types.OptionalCType'>"
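Since the PyTorch submodule is removed, the script now resolves paths from the installed torch package instead of externals/pytorch. A minimal sketch of that lookup, mirroring the lines above:

# Resolve the directory containing the installed torch package (site-packages),
# as the updated script does, instead of the removed externals/pytorch checkout.
import importlib.util
from pathlib import Path

TORCH_DIR = Path(importlib.util.find_spec("torch").origin).resolve().parent.parent
print(TORCH_DIR)  # e.g. .../site-packages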
@ -42,20 +42,29 @@ def generate_native_functions(
    grouped_native_functions = get_grouped_native_functions(native_functions)

    def get_native_function_name(f):
        func = f.func if hasattr(f, "func") else f.functional.func
        return str(func.name)
        func = f if hasattr(f, "func") else f.functional
        return str(func.func.name), func

    aten_funcs = set(map(get_native_function_name, grouped_native_functions))
    def get_opnames(ops):
        opnames = defaultdict(set)
        for op in ops:
            opname = op.split(".")[0]
            opnames[opname].add(op)
        return opnames

    native_functions = dict(map(get_native_function_name, native_functions))
    grouped_native_functions = dict(map(get_native_function_name, grouped_native_functions))
    aten_funcs = get_opnames(set(grouped_native_functions.keys()))

    with config_path.open() as f:
        config = yaml.load(f, yaml.CLoader)

    # List of unsupported ops in LTC autogen because of some error
    blacklist = config.get("blacklist", [])
    blacklist = set(config.get("blacklist", []))

    # List of supported ops that we don't want to do the full codegen for
    # primarily view ops
    supported = config.get("supported", [])
    supported = set(config.get("supported", []))

    # List of non-native ops to do IR codegen for
    non_native = config.get("non_native", [])
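For reference, a minimal sketch (with hypothetical op names) of how the overload grouping above behaves; overloads such as add.Tensor and add.Scalar are collected under their base name:

from collections import defaultdict

def get_opnames(ops):
    # Mirrors the helper above: group overloads under their base op name.
    opnames = defaultdict(set)
    for op in ops:
        opnames[op.split(".")[0]].add(op)
    return opnames

print(dict(get_opnames({"add.Tensor", "add.Scalar", "mul.Tensor"})))
# {'add': {'add.Tensor', 'add.Scalar'}, 'mul': {'mul.Tensor'}}  (set ordering may vary)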
@ -65,49 +74,54 @@ def generate_native_functions(
|
|||
else:
|
||||
cmd = ["grep", "-o", r"aten::[0-9a-zA-Z_\.]\+"]
|
||||
|
||||
output = (
|
||||
subprocess.check_output(
|
||||
torch_ops = set(
|
||||
op[6:]
|
||||
for op in subprocess.check_output(
|
||||
cmd + [str(torch_ops_file)],
|
||||
encoding="utf-8",
|
||||
)
|
||||
.strip()
|
||||
.split(os.linesep)
|
||||
)
|
||||
torch_opnames = get_opnames(torch_ops)
|
||||
|
||||
# process ops list
|
||||
ops = []
|
||||
supported_ops = []
|
||||
skipped = []
|
||||
ops = set()
|
||||
composite_implicit = set()
|
||||
|
||||
for op in output:
|
||||
op = op[6:]
|
||||
opname = op.split(".")[0]
|
||||
|
||||
if opname in blacklist or op in blacklist:
|
||||
for op in torch_ops:
|
||||
if op not in native_functions:
|
||||
continue
|
||||
|
||||
if opname in supported:
|
||||
supported_ops.append(op)
|
||||
func = native_functions[op]
|
||||
base = func.func.name.name.base
|
||||
|
||||
if base in blacklist or op in blacklist:
|
||||
continue
|
||||
if base in supported or op in supported:
|
||||
continue
|
||||
|
||||
if op not in aten_funcs:
|
||||
skipped.append(op)
|
||||
continue
|
||||
if func.has_composite_implicit_autograd_kernel and f"{op}_backward" not in torch_ops:
|
||||
composite_implicit.add(op)
|
||||
elif func.func.name.name.inplace:
|
||||
for autogen in func.autogen:
|
||||
if "functional" in autogen.overload_name:
|
||||
ops.add(str(autogen))
|
||||
else:
|
||||
ops.add(op)
|
||||
|
||||
ops.append(op)
|
||||
|
||||
opnames = sorted(set(ops))
|
||||
skipped = set(torch_ops) - ops - supported - composite_implicit
|
||||
|
||||
# Additional ops to support that are not supported by Torch-MLIR explicitly
|
||||
supported_ops.extend(config.get("additional_ops", []))
|
||||
supported |= set(config.get("additional_ops", []))
|
||||
|
||||
with out_file.open("w") as f:
|
||||
yaml.dump(
|
||||
{
|
||||
"backend": "Lazy",
|
||||
"cpp_namespace": "torch::lazy",
|
||||
"full_codegen": opnames,
|
||||
"supported": sorted(supported_ops),
|
||||
"full_codegen": sorted(ops),
|
||||
"supported": sorted(supported),
|
||||
"non_native": non_native,
|
||||
},
|
||||
f,
|
||||
|
@ -117,10 +131,15 @@ def generate_native_functions(
|
|||
dedent(
|
||||
"""
|
||||
|
||||
# Composite implicit ops (supported by Torch-MLIR but not differentiable)
|
||||
{composite_implicit}
|
||||
# Skipped ops (supported by Torch-MLIR but no equivalent native function)
|
||||
{skipped}
|
||||
"""
|
||||
).format(
|
||||
composite_implicit=os.linesep.join(f"# - {op}" for op in sorted(composite_implicit)),
|
||||
skipped=os.linesep.join(f"# - {op}" for op in sorted(skipped)),
|
||||
)
|
||||
+ os.linesep.join(f"# - {op}" for op in sorted(skipped))
|
||||
)
|
||||
|
||||
return parsed_yaml, grouped_native_functions
|
||||
|
@ -129,11 +148,13 @@ def generate_native_functions(
|
|||
@dataclass(frozen=True)
|
||||
class GenMlirLazyIr(torchgen.dest.GenLazyIR):
|
||||
|
||||
def lowering_function(self, schema, declaration_only=True):
|
||||
def lowering_function(self, schema):
|
||||
signature = "TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override"
|
||||
|
||||
if declaration_only:
|
||||
if schema.properties.LowerDeclOnly:
|
||||
return f"{signature};"
|
||||
elif not schema.properties.Lower:
|
||||
return ""
|
||||
|
||||
emplace_arguments = []
|
||||
for arg in schema.positional_args:
|
||||
|
@ -213,7 +234,7 @@ def generate_backend(
|
|||
import re
|
||||
|
||||
sig_re = re.compile(
|
||||
r"std::vector<Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)"
|
||||
r"std::vector<torch::lazy::Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)"
|
||||
)
|
||||
global_signatures = {}
|
||||
|
||||
|
@ -307,25 +328,30 @@ def main(args):
|
|||
)
|
||||
assert backend_path.is_dir()
|
||||
|
||||
torchgen_path = Path(torchgen.__path__[0]).resolve()
|
||||
assert torchgen_path.is_dir()
|
||||
|
||||
prev_hash = None
|
||||
hash_file = TORCH_MLIR_DIR.joinpath("generated_backend.hash")
|
||||
if hash_file.exists():
|
||||
prev_hash = hash_file.read_text().strip()
|
||||
|
||||
m = hashlib.sha256()
|
||||
m.update(script_path.read_bytes())
|
||||
m.update(config_path.read_bytes())
|
||||
m.update(torch_ops_file.read_bytes())
|
||||
if native_functions.exists():
|
||||
m.update(native_functions.read_bytes())
|
||||
|
||||
shape_inference_headers = backend_path.joinpath("LazyShapeInference.h")
|
||||
if shape_inference_headers.exists():
|
||||
m.update(shape_inference_headers.read_bytes())
|
||||
|
||||
shape_inference_defs = backend_path.joinpath("LazyShapeInference.cpp")
|
||||
if shape_inference_defs.exists():
|
||||
m.update(shape_inference_defs.read_bytes())
|
||||
# Add file contents to hash
|
||||
for path in (
|
||||
script_path,
|
||||
config_path,
|
||||
torch_ops_file,
|
||||
native_functions,
|
||||
backend_path.joinpath("LazyShapeInference.h"),
|
||||
backend_path.joinpath("LazyShapeInference.cpp"),
|
||||
torchgen_path.joinpath("dest", "lazy_ir.py"),
|
||||
torchgen_path.joinpath("api", "lazy.py"),
|
||||
torchgen_path.joinpath("model.py"),
|
||||
):
|
||||
if path.exists():
|
||||
m.update(path.read_bytes())
|
||||
|
||||
new_hash = m.hexdigest().strip()
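The digest above gates codegen: the backend is only regenerated when one of the hashed input files has changed since the hash stored in generated_backend.hash. A condensed sketch of that pattern; the helper name needs_regen is hypothetical:

import hashlib
from pathlib import Path

def needs_regen(inputs, hash_file: Path) -> bool:
    # Hash the contents of every existing input file and compare against the
    # previously stored digest.
    m = hashlib.sha256()
    for path in inputs:
        if path.exists():
            m.update(path.read_bytes())
    new_hash = m.hexdigest().strip()
    prev_hash = hash_file.read_text().strip() if hash_file.exists() else None
    return new_hash != prev_hash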
|
||||
|
||||
|
|
|
@ -1,22 +1,11 @@
|
|||
blacklist:
|
||||
# List of unsupported ops in LTC autogen because of some error
|
||||
- arange # Error: Code below assumes there is at least one tensor arg
|
||||
- contiguous # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
|
||||
- empty_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
|
||||
- full # Error: Code below assumes there is at least one tensor arg
|
||||
- index.Tensor # Error: TODO not sure if there are other valid types to handle here
|
||||
- index_put # Error: TODO not sure if there are other valid types to handle here
|
||||
- index_put_ # Error: TODO not sure if there are other valid types to handle here
|
||||
- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here
|
||||
- ones # Error: Code below assumes there is at least one tensor arg
|
||||
- ones_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
|
||||
- resize_ # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
|
||||
- stack # Error: TODO not sure if there are other valid types to handle here
|
||||
- to.dtype # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
|
||||
- to.other # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
|
||||
- uniform_ # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
|
||||
- zeros # Error: Code below assumes there is at least one tensor arg
|
||||
- zeros_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
|
||||
|
||||
# Additional ops which autogen is supported for but don't compile yet
|
||||
- detach
|
||||
|
@ -24,26 +13,34 @@ blacklist:
|
|||
- size
|
||||
- where
|
||||
- copy_
|
||||
- _to_copy
|
||||
- log_softmax # Not inherently differentiable. Needs to be decomposed.
|
||||
- linear # Not inherently differentiable. Needs to be decomposed.
|
||||
|
||||
# Disabled for consistency with TS backend
|
||||
- rsub
|
||||
|
||||
# List of supported ops that we don't want to do the full codegen for
|
||||
# primarily view ops
|
||||
supported:
|
||||
# - bernoulli
|
||||
# - bernoulli_
|
||||
- as_strided
|
||||
- as_strided_
|
||||
- _to_copy
|
||||
- cat
|
||||
- clone
|
||||
- empty
|
||||
- empty.memory_format
|
||||
- empty_strided
|
||||
- expand
|
||||
- fill_
|
||||
- native_batch_norm_backward
|
||||
- fill_.Scalar
|
||||
- permute
|
||||
- select.int
|
||||
- slice.Tensor
|
||||
- squeeze
|
||||
- squeeze.dim
|
||||
- t
|
||||
- transpose.int
|
||||
- unsqueeze
|
||||
- view
|
||||
- _unsafe_view
|
||||
|
||||
additional_ops:
|
||||
# Additional ops to support that are not supported by Torch-MLIR explicitly
|
||||
|
@ -53,35 +50,38 @@ additional_ops:
|
|||
|
||||
# List of non native ops that we only want to do IR node class generation for
|
||||
non_native:
|
||||
- func: device_data(std::shared_ptr<BackendData> data) -> Tensor
|
||||
opkind: ltc_device_data
|
||||
cache_shape: false
|
||||
- func: scalar(at::Scalar value, at::ScalarType type) -> Tensor
|
||||
- func: scalar(Scalar value, ScalarType type) -> Tensor
|
||||
opkind: at::prim::Constant
|
||||
cache_shape: false
|
||||
- func: expand(Tensor input, std::vector<int64_t> size, bool is_scalar_expand) -> Tensor
|
||||
- func: view(Tensor input, std::vector<int64_t> output_size) -> Tensor
|
||||
cache_shape: false
|
||||
- func: cast(Tensor input, at::ScalarType dtype, optional<at::ScalarType> stype) -> Tensor
|
||||
properties:
|
||||
- ShapeCompute
|
||||
- TreatScalarsAsConstants
|
||||
- func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
|
||||
- func: view(Tensor input, int[] output_size) -> Tensor
|
||||
properties:
|
||||
- ShapeCompute
|
||||
- func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor
|
||||
opkind: ltc_cast
|
||||
cache_shape: false
|
||||
properties:
|
||||
- ShapeCompute
|
||||
|
||||
# View ops only required until proper functionalization pass is introduced into LTC
|
||||
- func: as_strided_view_update(Tensor target, Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
|
||||
- func: as_strided_view_update(Tensor target, Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
|
||||
opkind: ltc_as_strided_view_update
|
||||
- func: as_strided(Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
|
||||
- func: diagonal_view_update(Tensor target, Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
|
||||
- func: as_strided(Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
|
||||
- func: diagonal_view_update(Tensor target, Tensor input, int offset, int dim1, int dim2) -> Tensor
|
||||
opkind: ltc_diagonal_view_update
|
||||
cache_shape: false
|
||||
- func: diagonal(Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
|
||||
- func: narrow_view_update(Tensor input, Tensor source, std::vector<int64_t> base_indices) -> Tensor
|
||||
properties:
|
||||
- ShapeCompute
|
||||
- func: diagonal(Tensor input, int offset, int dim1, int dim2) -> Tensor
|
||||
- func: narrow_view_update(Tensor input, Tensor source, int[] base_indices) -> Tensor
|
||||
opkind: ltc_narrow_view_update
|
||||
- func: narrow(Tensor input, std::vector<int64_t> base_indices, std::vector<int64_t> sizes) -> Tensor
|
||||
- func: permute(Tensor input, std::vector<int64_t> dims) -> Tensor
|
||||
- func: resize(Tensor input, std::vector<int64_t> size) -> Tensor
|
||||
- func: select_view_update(Tensor target, Tensor source, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
|
||||
- func: narrow(Tensor input, int[] base_indices, int[] sizes) -> Tensor
|
||||
- func: permute(Tensor input, int[] dims) -> Tensor
|
||||
- func: resize(Tensor input, int[] size) -> Tensor
|
||||
- func: select_view_update(Tensor target, Tensor source, int dim, int start, int end, int stride) -> Tensor
|
||||
opkind: ltc_select_view_update
|
||||
cache_shape: false
|
||||
- func: select(Tensor input, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
|
||||
properties:
|
||||
- ShapeCompute
|
||||
- func: select(Tensor input, int dim, int start, int end, int stride) -> Tensor
|
||||
- func: squeeze(Tensor input, int dim) -> Tensor
|
||||
- func: unsqueeze(Tensor input, int dim) -> Tensor
|
||||
|
|
|
@ -19,7 +19,7 @@ from datasets import load_dataset
|
|||
from datasets.dataset_dict import DatasetDict
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import BertForSequenceClassification, \
|
||||
BertTokenizer, AdamW, get_scheduler
|
||||
BertConfig, BertTokenizer, AdamW, get_scheduler
|
||||
from typing import List
|
||||
|
||||
|
||||
|
@ -70,7 +70,7 @@ def train(model: BertForSequenceClassification,
|
|||
return losses
|
||||
|
||||
|
||||
def main(device, lower_only):
|
||||
def main(device, lower_only, full_size):
|
||||
if device in ("TS", "MLIR_EXAMPLE"):
|
||||
import torch._lazy
|
||||
|
||||
|
@ -95,8 +95,24 @@ def main(device, lower_only):

    train_dataloader = DataLoader(small_train_dataset, shuffle=True,
                                  batch_size=8)
    model = BertForSequenceClassification.from_pretrained('bert-base-cased',
                                                          num_labels=2)
    if full_size:
        model = BertForSequenceClassification.from_pretrained('bert-base-cased',
                                                               num_labels=2)
    else:
        configuration = BertConfig(
            vocab_size=28996,
            hidden_size=32,
            num_hidden_layers=1,
            num_attention_heads=2,
            intermediate_size=32,
            hidden_act='gelu',
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
            max_position_embeddings=512,
            layer_norm_eps=1.0e-05,
        )
        model = BertForSequenceClassification(configuration)

    model.to(device)

    num_epochs = 3
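As a rough illustration (parameter counts are approximate), the reduced configuration above yields a model on the order of 1M parameters versus roughly 110M for pretrained 'bert-base-cased', which keeps the lazy-tensor trace small; the full-size model remains reachable through the --full_size flag added below.

# Hypothetical size check for the reduced configuration used above.
from transformers import BertConfig, BertForSequenceClassification

small = BertForSequenceClassification(BertConfig(
    vocab_size=28996, hidden_size=32, num_hidden_layers=1,
    num_attention_heads=2, intermediate_size=32, max_position_embeddings=512))
print(sum(p.numel() for p in small.parameters()))  # roughly 1M parameters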
@ -115,6 +131,8 @@ def main(device, lower_only):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(0)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
|
@ -131,5 +149,12 @@ if __name__ == "__main__":
|
|||
default=False,
|
||||
help="Only get backend printout -- do not execute computation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--full_size",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Use full sized BERT model instead of one with smaller parameterization",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args.device, args.lower_only)
|
||||
main(args.device, args.lower_only, args.full_size)
|
||||
|
|
|
@ -73,6 +73,8 @@ def main(device):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(0)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
|
|
|
@ -1 +0,0 @@
Subproject commit 9f3d6a00a76c567d7c046eabc60ae7a578f7bbde
@ -18,6 +18,17 @@ include_directories(BEFORE
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")

# Generate Lazy IR Nodes
execute_process(
  COMMAND ${Python3_EXECUTABLE} build_tools/autogen_ltc_backend.py -f
  WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
)
add_custom_target(
  generate_ltc_sources
  ALL
  COMMENT "Generating Lazy Tensor Core IR Nodes"
  COMMAND ${Python3_EXECUTABLE} build_tools/autogen_ltc_backend.py
  WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})

add_library(torch_mlir_ltc_backend SHARED
  base_lazy_backend/backend_impl.cpp
@ -29,7 +40,10 @@ add_library(torch_mlir_ltc_backend SHARED
|
|||
base_lazy_backend/mlir_native_functions.cpp
|
||||
base_lazy_backend/mlir_node.cpp
|
||||
base_lazy_backend/mlir_node_lowering.cpp
|
||||
base_lazy_backend/ops/device_data.cpp
|
||||
base_lazy_backend/ops/generic.cpp
|
||||
)
|
||||
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
|
||||
|
||||
add_dependencies(torch_mlir_ltc_backend
|
||||
TorchMLIRJITIRImporter
|
||||
|
|
|
@ -7,271 +7,98 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "LazyShapeInference.h"
|
||||
#include "../utils/exception.h"
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <cmath>
|
||||
|
||||
#include "../utils/exception.h"
|
||||
#include "LazyShapeInference.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
// TODO(henrytu): Upstream these shape inference functions to PyTorch in the future.
|
||||
|
||||
// Turns any negative index positive (assuming it's valid)
|
||||
int64_t normalize_index(int64_t index, unsigned dims) {
|
||||
return index < 0 ? (int64_t)dims + index : index;
|
||||
std::vector<torch::lazy::Shape>
|
||||
compute_shape_div(const at::Tensor& self, const at::Scalar & other) {
|
||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<Shape>
|
||||
compute_shape_dropout(const at::Tensor& input, double p, bool train) {
|
||||
return {Shape(input.scalar_type(), input.sizes().vec())};
|
||||
std::vector<torch::lazy::Shape>
|
||||
compute_shape_mul(const at::Tensor& self, const at::Scalar& other) {
|
||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<Shape> compute_shape_layer_norm(
|
||||
const at::Tensor& input, at::IntArrayRef normalized_shape,
|
||||
const c10::optional<at::Tensor>& weight,
|
||||
const c10::optional<at::Tensor>& bias, double eps, bool cudnn_enable) {
|
||||
return {Shape(input.scalar_type(), input.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<Shape>
|
||||
compute_shape_matmul(const at::Tensor& self, const at::Tensor& other) {
|
||||
std::vector<int64_t> sizes;
|
||||
|
||||
auto self_sizes = self.sizes().vec();
|
||||
auto other_sizes = other.sizes().vec();
|
||||
|
||||
// For tensors with dimensions >2, the leading dimensions are for batch info.
|
||||
// The last 2 (or 1 in the case of a single dim tensor) dimensions are the
|
||||
// matrix dimensions themselves, which is checked to ensure the matmul op
|
||||
// is legal.
|
||||
//
|
||||
// Example:
|
||||
// [1, 2, 3, 4] -> [1, 2] batch dims and [3, 4] matrix
|
||||
// [1, 4, 5] -> [1] batch dims and [4, 5] matrix
|
||||
// [4, 5] -> [] batch dims and [4, 5] matrix
|
||||
// [5] -> [] batch dims and [5] matrix
|
||||
//
|
||||
// We'll start by splitting the shapes as described above.
|
||||
auto partition_shape = [](at::ArrayRef<int64_t> sizes) {
|
||||
if (sizes.size() <= 2) {
|
||||
return std::make_pair(
|
||||
std::vector<int64_t>(),
|
||||
std::vector<int64_t>(sizes.begin(), sizes.end()));
|
||||
} else {
|
||||
std::size_t partition_idx = sizes.size() - 2;
|
||||
return std::make_pair(
|
||||
std::vector<int64_t>(sizes.begin(), sizes.begin() + partition_idx),
|
||||
std::vector<int64_t>(sizes.begin() + partition_idx, sizes.end()));
|
||||
}
|
||||
};
|
||||
auto [self_batch_sizes, self_matrix_sizes] = partition_shape(self_sizes);
|
||||
auto [other_batch_sizes, other_matrix_sizes] = partition_shape(other_sizes);
|
||||
|
||||
// Insert batch dimensions.
|
||||
// The final list of sizes will be based on the tensor w/ more dims.
|
||||
// Individual dimension sizes are "right justified" as we iterate thru
|
||||
// to pick the larger dimension between them.
|
||||
// 0 1 1 3 4
|
||||
// 5 1 2
|
||||
// ---------
|
||||
// 0 1 5 3 4 <- Result
|
||||
int64_t self_size, other_size;
|
||||
std::size_t num_batch_dim =
|
||||
std::max(self_batch_sizes.size(), other_batch_sizes.size());
|
||||
auto get_batch_dim = [&](std::vector<int64_t> batch_sizes, std::size_t dim) {
|
||||
long idx = dim - num_batch_dim + batch_sizes.size();
|
||||
// Negative index means out of bounds, which defaults to a dim size of 1.
|
||||
return idx < 0 ? 1 : batch_sizes[idx];
|
||||
};
|
||||
for (std::size_t i = 0; i < num_batch_dim; i++) {
|
||||
self_size = get_batch_dim(self_batch_sizes, i);
|
||||
other_size = get_batch_dim(other_batch_sizes, i);
|
||||
|
||||
TORCH_CHECK(
|
||||
self_size == 1 || other_size == 1 || self_size == other_size,
|
||||
"At trailing dimension ", i, ", expected for dimensions ",
|
||||
"to either match or have one of them equal one, but got ", self_size,
|
||||
" and ", other_size, " instead!");
|
||||
|
||||
sizes.push_back(std::max(self_size, other_size));
|
||||
}
|
||||
|
||||
// Keep track of the inner dimensions of matmul to validate op is valid.
|
||||
std::pair<int64_t, int64_t> inner_sizes;
|
||||
if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 1) {
|
||||
// Dot-Product -- scalar output, so no dimensions inserted
|
||||
inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]);
|
||||
} else if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 2) {
|
||||
// Vector-Matrix product (m) @ (m, n) -> (n)
|
||||
inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]);
|
||||
|
||||
sizes.push_back(other_matrix_sizes[1]);
|
||||
} else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 1) {
|
||||
// Matrix-Vector product (m, n) @ (n) -> (m)
|
||||
inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]);
|
||||
|
||||
sizes.push_back(self_matrix_sizes[0]);
|
||||
} else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 2) {
|
||||
// Matrix-Matrix product (m, n) @ (n, o) -> (m, o)
|
||||
inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]);
|
||||
|
||||
sizes.push_back(self_matrix_sizes[0]);
|
||||
sizes.push_back(other_matrix_sizes[1]);
|
||||
} else {
|
||||
// By this time, self_matrix_sizes and other_matrix_sizes should have at
|
||||
// most 2 dims, so if this is executed something has gone wrong...
|
||||
TORCH_CHECK(false, "Invalid matmul shape combination!");
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
inner_sizes.first == inner_sizes.second, "Inner dimension of matrix (",
|
||||
inner_sizes.first, ") does not ", "match (", inner_sizes.second, ")!");
|
||||
|
||||
return {Shape(self.scalar_type(), sizes)};
|
||||
}
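A hypothetical Python sketch of the same shape rule (batch dimensions broadcast, trailing one or two dimensions treated as the matrix), omitting the inner-dimension checks performed above:

def matmul_result_shape(self_sizes, other_sizes):
    # Split leading batch dims from the trailing matrix dims (last 1 or 2 dims).
    split = lambda s: (list(s[:-2]), list(s[-2:])) if len(s) > 2 else ([], list(s))
    self_batch, self_mat = split(self_sizes)
    other_batch, other_mat = split(other_sizes)
    # Right-justify and broadcast the batch dims (validation omitted here).
    n = max(len(self_batch), len(other_batch))
    pad = lambda b: [1] * (n - len(b)) + b
    sizes = [max(a, b) for a, b in zip(pad(self_batch), pad(other_batch))]
    # Matrix part: (m, k) @ (k, n) -> (m, n); 1-D operands contribute no dim.
    if len(self_mat) == 2:
        sizes.append(self_mat[0])
    if len(other_mat) == 2:
        sizes.append(other_mat[1])
    return sizes

print(matmul_result_shape([1, 2, 3, 4], [1, 4, 5]))  # -> [1, 2, 3, 5]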
|
||||
|
||||
std::vector<Shape> compute_shape_native_batch_norm(
|
||||
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
|
||||
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
|
||||
const c10::optional<at::Tensor>& bias,
|
||||
const c10::optional<at::Tensor>& running_mean,
|
||||
const c10::optional<at::Tensor>& running_var, bool training,
|
||||
double momentum, double eps) {
|
||||
std::vector<Shape> shapes;
|
||||
std::vector<torch::lazy::Shape> shapes;
|
||||
shapes.reserve(3);
|
||||
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
|
||||
|
||||
// A separate mean and var needs to be kept for each channel.
|
||||
TORCH_CHECK(
|
||||
input.sizes().size() >= 2,
|
||||
"Input tensor must have at least batch and channel dimensions!");
|
||||
int64_t num_features = input.sizes().vec()[1];
|
||||
|
||||
if (running_mean.has_value()) {
|
||||
shapes.emplace_back(
|
||||
running_mean.value().scalar_type(), running_mean.value().sizes().vec());
|
||||
if (running_var.has_value()) {
|
||||
shapes.emplace_back(
|
||||
running_var.value().scalar_type(), running_var.value().sizes().vec());
|
||||
}
|
||||
} else {
|
||||
shapes.emplace_back(
|
||||
at::get_default_dtype_as_scalartype(),
|
||||
std::vector<int64_t>{num_features});
|
||||
}
|
||||
|
||||
if (running_var.has_value()) {
|
||||
shapes.emplace_back(
|
||||
running_var.value().scalar_type(), running_var.value().sizes().vec());
|
||||
} else {
|
||||
shapes.emplace_back(
|
||||
at::get_default_dtype_as_scalartype(),
|
||||
std::vector<int64_t>{num_features});
|
||||
}
|
||||
return shapes;
|
||||
}
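A hypothetical illustration of the fix described in the commit message: even when the running stats are absent, three shapes must be returned, with the mean/var entries sized to the number of channels:

import torch

# With no running_mean/running_var, native_batch_norm still returns per-channel
# save_mean and save_invstd tensors of shape [C].
x = torch.randn(8, 32, 7, 7)
weight = torch.ones(32)
bias = torch.zeros(32)
out, save_mean, save_invstd = torch.native_batch_norm(
    x, weight, bias, None, None, True, 0.1, 1e-5)
print(out.shape, save_mean.shape, save_invstd.shape)
# torch.Size([8, 32, 7, 7]) torch.Size([32]) torch.Size([32])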
|
||||
|
||||
std::vector<Shape>
|
||||
compute_shape_reshape(const at::Tensor& self, at::IntArrayRef shape) {
|
||||
// Make a copy of the desired output shape.
|
||||
std::vector<int64_t> sizes(shape.begin(), shape.end());
|
||||
|
||||
// Product of all sizes in input shape is the number of entries in tensor.
|
||||
int64_t num_entries = 1;
|
||||
for (int64_t i : self.sizes().vec()) {
|
||||
num_entries *= i;
|
||||
}
|
||||
|
||||
// Validate the number of entries in the desired shape. If there is a wildcard
|
||||
// dimension, we need to find it now in order to populate it.
|
||||
long wildcard_idx = -1;
|
||||
int64_t num_concrete_entries = 1;
|
||||
for (std::size_t idx = 0; idx < sizes.size(); idx++) {
|
||||
if (sizes[idx] != -1) {
|
||||
num_concrete_entries *= sizes[idx];
|
||||
} else {
|
||||
TORCH_CHECK(wildcard_idx == -1, "only one dimension can be inferred");
|
||||
wildcard_idx = idx;
|
||||
}
|
||||
}
|
||||
|
||||
if (wildcard_idx == -1) {
|
||||
// No wildcard, the shape should already be known.
|
||||
TORCH_CHECK(
|
||||
num_entries == num_concrete_entries, "shape `[", sizes,
|
||||
"]` is invalid for input of size ", num_concrete_entries);
|
||||
} else {
|
||||
// There is one dimension which is not explicitly declared -- we need to
|
||||
// infer.
|
||||
TORCH_CHECK(
|
||||
num_entries % num_concrete_entries == 0, "shape `[", sizes,
|
||||
"]` is invalid for input of size ", num_concrete_entries);
|
||||
|
||||
sizes[wildcard_idx] = num_entries / num_concrete_entries;
|
||||
}
|
||||
|
||||
return {Shape(self.scalar_type(), sizes)};
|
||||
}
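A hypothetical Python sketch of the wildcard (-1) inference performed by the reshape shape function above:

def infer_reshape(input_sizes, target_sizes):
    # Total number of entries in the input tensor.
    num_entries = 1
    for s in input_sizes:
        num_entries *= s
    sizes = list(target_sizes)
    concrete = 1
    wildcard = -1
    for i, s in enumerate(sizes):
        if s == -1:
            assert wildcard == -1, "only one dimension can be inferred"
            wildcard = i
        else:
            concrete *= s
    if wildcard == -1:
        assert num_entries == concrete, "shape is invalid for input size"
    else:
        assert num_entries % concrete == 0, "shape is invalid for input size"
        sizes[wildcard] = num_entries // concrete
    return sizes

print(infer_reshape([2, 3, 4], [-1, 6]))  # -> [4, 6]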
|
||||
|
||||
std::vector<Shape> compute_shape_rsub(
|
||||
const at::Tensor& self, const at::Scalar& other, const at::Scalar& alpha) {
|
||||
// Since other is scalar, the result will match tensor shape.
|
||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<Shape>
|
||||
compute_shape_select(const at::Tensor& self, int64_t dim, int64_t index) {
|
||||
auto original_shape = self.sizes().vec();
|
||||
std::vector<int64_t> sizes(original_shape.begin(), original_shape.end());
|
||||
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
|
||||
const at::Tensor& grad_out, const at::Tensor& input,
|
||||
const c10::optional<at::Tensor>& weight,
|
||||
const c10::optional<at::Tensor>& running_mean,
|
||||
const c10::optional<at::Tensor>& running_var,
|
||||
const c10::optional<at::Tensor>& save_mean,
|
||||
const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
|
||||
::std::array<bool, 3> output_mask) {
|
||||
std::vector<torch::lazy::Shape> shapes;
|
||||
shapes.reserve(3);
|
||||
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
|
||||
|
||||
// A separate mean and var needs to be kept for each channel.
|
||||
TORCH_CHECK(
|
||||
dim < (int64_t)sizes.size(), "Dimension ", dim,
|
||||
" is out of bounds for tensor with ", sizes.size(), " dimensions!");
|
||||
TORCH_CHECK(
|
||||
index < sizes[dim], "Index ", index,
|
||||
" is out of bounds for dimension of size ", sizes[dim]);
|
||||
sizes.erase(sizes.begin() + dim);
|
||||
input.sizes().size() >= 2,
|
||||
"Input tensor must have at least batch and channel dimensions!");
|
||||
int64_t num_features = input.sizes().vec()[1];
|
||||
|
||||
return {Shape(self.scalar_type(), sizes)};
|
||||
// `weight` and `bias` are vectors of length C (number of channels)`
|
||||
shapes.emplace_back(
|
||||
at::get_default_dtype_as_scalartype(),
|
||||
std::vector<int64_t>{num_features});
|
||||
shapes.emplace_back(
|
||||
at::get_default_dtype_as_scalartype(),
|
||||
std::vector<int64_t>{num_features});
|
||||
|
||||
return shapes;
|
||||
}
|
||||
|
||||
std::vector<Shape> compute_shape_slice(
|
||||
const at::Tensor& self, int64_t dim, c10::optional<int64_t> start,
|
||||
c10::optional<int64_t> end, int64_t step) {
|
||||
auto original_shape = self.sizes().vec();
|
||||
std::vector<int64_t> sizes(original_shape.begin(), original_shape.end());
|
||||
|
||||
int64_t dim_size = sizes[dim];
|
||||
|
||||
// Index may be negative, so we must normalize it.
|
||||
int64_t start_norm = normalize_index(start.value(), dim_size);
|
||||
int64_t end_norm = normalize_index(end.value(), dim_size);
|
||||
|
||||
if (start_norm >= end_norm || start_norm >= dim_size || end_norm <= 0) {
|
||||
// Slice is out of bounds, nothing in range.
|
||||
sizes[dim] = 0;
|
||||
} else {
|
||||
// Clamp upper and lower bound to valid indices.
|
||||
start_norm = std::max((int64_t)0, start_norm);
|
||||
end_norm = std::min(dim_size, end_norm);
|
||||
|
||||
// Final size is determined by step and interval size.
|
||||
sizes[dim] = std::ceil((double)(end_norm - start_norm) / (double)step);
|
||||
}
|
||||
|
||||
return {Shape(self.scalar_type(), sizes)};
|
||||
}
|
||||
|
||||
std::vector<Shape> compute_shape_softmax(
|
||||
const at::Tensor& self, int64_t dim, c10::optional<at::ScalarType> dtype) {
|
||||
std::vector<torch::lazy::Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||
if (dtype.has_value()) {
|
||||
return {Shape(dtype.value(), self.sizes().vec())};
|
||||
return {Shape(*dtype, size)};
|
||||
}
|
||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<Shape>
|
||||
compute_shape_transpose(const at::Tensor& self, int64_t dim0, int64_t dim1) {
|
||||
auto original_shape = self.sizes().vec();
|
||||
std::vector<int64_t> sizes{original_shape.begin(), original_shape.end()};
|
||||
|
||||
// Index may be negative, so we must normalize it. We create new variables
|
||||
// instead of replacing the existing ones so that in the case of an error,
|
||||
// the original values can be printed out.
|
||||
int64_t dim0_norm = normalize_index(dim0, sizes.size());
|
||||
int64_t dim1_norm = normalize_index(dim1, sizes.size());
|
||||
|
||||
// Verify dimensions are valid.
|
||||
TORCH_CHECK(
|
||||
0 <= dim0_norm && dim0_norm < (int64_t)sizes.size(), "dim0 has value ",
|
||||
dim0, ", but there are only ", sizes.size(), " tensor dimensions");
|
||||
TORCH_CHECK(
|
||||
0 <= dim1_norm && dim1_norm < (int64_t)sizes.size(), "dim1 has value ",
|
||||
dim1, ", but there are only ", sizes.size(), " tensor dimensions");
|
||||
|
||||
// Swap shapes at dimensions.
|
||||
std::swap(sizes[dim0_norm], sizes[dim1_norm]);
|
||||
|
||||
return {Shape(self.scalar_type(), sizes)};
|
||||
return {Shape(self.scalar_type(), size)};
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
|
|
|
@ -22,74 +22,48 @@ namespace lazy {
|
|||
|
||||
// clang-format off
|
||||
|
||||
TORCH_API std::vector<Shape> compute_shape___and__(const at::Tensor & self, const at::Tensor & other);
|
||||
TORCH_API std::vector<Shape> compute_shape__reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride);
|
||||
TORCH_API std::vector<Shape> compute_shape__shape_as_tensor(const at::Tensor & self);
|
||||
TORCH_API std::vector<Shape> compute_shape__unsafe_view(const at::Tensor & self, at::IntArrayRef size);
|
||||
TORCH_API std::vector<Shape> compute_shape_abs(const at::Tensor & self);
|
||||
TORCH_API std::vector<Shape> compute_shape_adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size);
|
||||
TORCH_API std::vector<Shape> compute_shape_add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
|
||||
TORCH_API std::vector<Shape> compute_shape_add_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
|
||||
TORCH_API std::vector<Shape> compute_shape_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps, bool cudnn_enabled);
|
||||
TORCH_API std::vector<Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optional<at::Generator> generator);
|
||||
TORCH_API std::vector<Shape> compute_shape_bernoulli_(at::Tensor & self, const at::Tensor & p, c10::optional<at::Generator> generator);
|
||||
TORCH_API std::vector<Shape> compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional<at::Generator> generator);
|
||||
TORCH_API std::vector<Shape> compute_shape_bincount(const at::Tensor & self, const c10::optional<at::Tensor> & weights, int64_t minlength);
|
||||
TORCH_API std::vector<Shape> compute_shape_broadcast_to(const at::Tensor & self, at::IntArrayRef size);
|
||||
TORCH_API std::vector<Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
|
||||
TORCH_API std::vector<Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
|
||||
TORCH_API std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
|
||||
TORCH_API std::vector<Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
|
||||
TORCH_API std::vector<Shape> compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
|
||||
TORCH_API std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
|
||||
TORCH_API std::vector<Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
|
||||
TORCH_API std::vector<Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_div_(at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_dropout(const at::Tensor & input, double p, bool train);
|
||||
TORCH_API std::vector<Shape> compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse);
|
||||
TORCH_API std::vector<Shape> compute_shape_expand_as(const at::Tensor & self, const at::Tensor & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_flatten(const at::Tensor & self, int64_t start_dim, int64_t end_dim);
|
||||
TORCH_API std::vector<Shape> compute_shape_floor_divide(const at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_fmod(const at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_full_like(const at::Tensor & self, const at::Scalar & fill_value, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
|
||||
TORCH_API std::vector<Shape> compute_shape_hardswish(const at::Tensor & self);
|
||||
TORCH_API std::vector<Shape> compute_shape_hardtanh(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val);
|
||||
TORCH_API std::vector<Shape> compute_shape_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index);
|
||||
TORCH_API std::vector<Shape> compute_shape_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps, bool cudnn_enable);
|
||||
TORCH_API std::vector<Shape> compute_shape_log_softmax(const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype);
|
||||
TORCH_API std::vector<Shape> compute_shape_logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim);
|
||||
TORCH_API std::vector<Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
|
||||
TORCH_API std::vector<Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
|
||||
TORCH_API std::vector<Shape> compute_shape_masked_select(const at::Tensor & self, const at::Tensor & mask);
|
||||
TORCH_API std::vector<Shape> compute_shape_matmul(const at::Tensor & self, const at::Tensor & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_max(const at::Tensor & self);
|
||||
TORCH_API std::vector<Shape> compute_shape_max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode);
|
||||
TORCH_API std::vector<Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
|
||||
TORCH_API std::vector<Shape> compute_shape_mul(const at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_mul_(at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
|
||||
TORCH_API std::vector<Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
|
||||
TORCH_API std::vector<Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
|
||||
TORCH_API std::vector<Shape> compute_shape_new_ones(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
|
||||
TORCH_API std::vector<Shape> compute_shape_new_zeros(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
|
||||
TORCH_API std::vector<Shape> compute_shape_rand_like(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
|
||||
TORCH_API std::vector<Shape> compute_shape_relu(const at::Tensor & self);
|
||||
TORCH_API std::vector<Shape> compute_shape_relu_(at::Tensor & self);
|
||||
TORCH_API std::vector<Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
|
||||
TORCH_API std::vector<Shape> compute_shape_reshape(const at::Tensor & self, at::IntArrayRef shape);
|
||||
TORCH_API std::vector<Shape> compute_shape_rsub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
|
||||
TORCH_API std::vector<Shape> compute_shape_select(const at::Tensor & self, int64_t dim, int64_t index);
|
||||
TORCH_API std::vector<Shape> compute_shape_slice(const at::Tensor & self, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step);
|
||||
TORCH_API std::vector<Shape> compute_shape_softmax(const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype);
|
||||
TORCH_API std::vector<Shape> compute_shape_square(const at::Tensor & self);
|
||||
TORCH_API std::vector<Shape> compute_shape_std(const at::Tensor & self, bool unbiased);
|
||||
TORCH_API std::vector<Shape> compute_shape_sub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
|
||||
TORCH_API std::vector<Shape> compute_shape_sub_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
|
||||
TORCH_API std::vector<Shape> compute_shape_sum(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
|
||||
TORCH_API std::vector<Shape> compute_shape_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1);
|
||||
TORCH_API std::vector<Shape> compute_shape_type_as(const at::Tensor & self, const at::Tensor & other);
|
||||
TORCH_API std::vector<Shape> compute_shape_var(const at::Tensor & self, bool unbiased);
|
||||
TORCH_API std::vector<Shape> compute_shape_zero_(at::Tensor & self);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape___and__(const at::Tensor & self, const at::Tensor & other);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape__reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_abs(const at::Tensor & self);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optional<at::Generator> generator);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli_functional(const at::Tensor & self, const at::Tensor & p, c10::optional<at::Generator> generator);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bincount(const at::Tensor & self, const c10::optional<at::Tensor> & weights, int64_t minlength);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_embedding_dense_backward(const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_flip(const at::Tensor & self, at::IntArrayRef dims);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_fmod(const at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardswish(const at::Tensor & self);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardtanh(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_logical_or(const at::Tensor & self, const at::Tensor & other);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_select(const at::Tensor & self, const at::Tensor & mask);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_max(const at::Tensor & self);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_mul(const at::Tensor & self, const at::Scalar & other);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout(const at::Tensor & input, double p, c10::optional<bool> train);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_relu(const at::Tensor & self);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_resize_functional(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_sub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_sum(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_uniform_functional(const at::Tensor & self, double from, double to, c10::optional<at::Generator> generator);
|
||||
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_functional(const at::Tensor & self);
|
||||
|
||||
|
||||
// clang-format on
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "backend_impl.h"
|
||||
#include "ir_builder.h"
|
||||
#include "mlir_lowering_context.h"
|
||||
#include "ops/device_data.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
@ -112,7 +113,7 @@ TorchMlirBackendImpl::GetComputationDataFromNode(Node* node) const {
|
|||
if (!device_data_node) {
|
||||
return nullptr;
|
||||
}
|
||||
return device_data_node->data;
|
||||
return device_data_node->data();
|
||||
}
|
||||
|
||||
at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "dynamic_ir.h"
|
||||
#include "generated/LazyNonNativeIr.h"
|
||||
#include "mlir_node.h"
|
||||
#include "ops/device_data.h"
|
||||
#include "ops/generic.h"
|
||||
|
||||
// This file contains the TorchMlir IrBuilder
|
||||
|
@ -35,7 +36,7 @@ struct TorchMlirIrBuilder : IrBuilder {
|
|||
NodePtr MakeExpand(const Value& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand) const override { return MakeNode<Expand>(input0, size, is_scalar_expand); }
|
||||
NodePtr MakeView(const Value& input0, const std::vector<int64_t>& output_size) const override { return MakeNode<View>(input0, output_size); }
|
||||
NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype = c10::nullopt) const override { return MakeNode<Cast>(input0, dtype, stype); }
|
||||
NodePtr MakeTensorList(const OpList& inputs) const override { return MakeNode<TensorList>(inputs); }
|
||||
NodePtr MakeTensorList(const OpList& inputs) const override { return MakeNode<TorchMlirTensorList>(inputs); }
|
||||
NodePtr MakeGeneric(const OpKind& op, const OpList& operands, const Shape& shape, const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) const override { return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed); }
|
||||
|
||||
// view ops
|
||||
|
|
|
@ -331,8 +331,12 @@ const std::string TorchMlirComputation::to_string() const {
|
|||
ss << "Input/Output Alias Mapping: \n";
|
||||
for (InputOutputAlias input_output_alias : input_output_aliases_) {
|
||||
ss << "Output: " << input_output_alias.output_index
|
||||
<< " -> Input param: " << input_output_alias.param_number << std::endl;
|
||||
<< " -> Input param: " << input_output_alias.param_number << "\n";
|
||||
}
|
||||
ss << "\n";
|
||||
|
||||
// Mark Step
|
||||
ss << "In Mark Step: " << (in_mark_step ? "true" : "false") << "\n";
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "../utils/sys_utils.h"
|
||||
#include "LazyShapeInference.h"
|
||||
#include "generated/LazyNativeFunctions.h"
|
||||
#include "ops/to_copy.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
@@ -166,6 +167,81 @@ torch::lazy::LazyTensorPtr create_view(
  return input->CreateViewTensor(std::move(view_info));
}

torch::lazy::ViewInfo CreateAsStridedViewInfo(
    const torch::lazy::Shape& input_shape, std::vector<int64_t> size,
    std::vector<int64_t> stride, c10::optional<int64_t> storage_offset) {
  torch::lazy::Shape result_shape =
      torch::lazy::Shape(input_shape.scalar_type(), size);
  torch::lazy::AsStridedInfo as_strided_info;
  as_strided_info.stride = std::move(stride);
  if (storage_offset) {
    as_strided_info.offset = *storage_offset;
  }
  return torch::lazy::ViewInfo(
      torch::lazy::ViewInfo::Type::kAsStrided, std::move(result_shape),
      input_shape, std::move(as_strided_info));
}

torch::lazy::LazyTensorPtr lazy_narrow(
    const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
    int64_t length) {
  auto input_shape = input->shape();
  dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
  torch::lazy::Shape narrow_shape = input_shape;
  narrow_shape.set_size(dim, length);

  torch::lazy::ViewInfo::Type view_type =
      (input_shape.Get().numel() == narrow_shape.numel())
          ? torch::lazy::ViewInfo::Type::kReshape
          : torch::lazy::ViewInfo::Type::kNarrow;
  torch::lazy::ViewInfo view_info(
      view_type, std::move(narrow_shape), input_shape);
  view_info.indices[dim] =
      torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start);
  return input->CreateViewTensor(std::move(view_info));
}

torch::lazy::LazyTensorPtr lazy_view(
    const torch::lazy::LazyTensorPtr& input,
    c10::ArrayRef<int64_t> output_size) {
  auto input_shape = input->shape().Get();
  torch::lazy::Shape shape = torch::lazy::Shape(
      input_shape.scalar_type(),
      at::infer_size(output_size, input_shape.numel()));
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape);
  return input->CreateViewTensor(std::move(view_info));
}

torch::lazy::LazyTensorPtr lazy_select(
    const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t index) {
  auto shape = input->shape();
  dim = torch::lazy::GetCanonicalDimensionIndex(dim, shape.Get().dim());
  torch::lazy::LazyTensorPtr result = lazy_narrow(input, dim, index, 1);
  auto new_dims = torch::lazy::DropDimensions(shape.Get().sizes(), {dim});
  return lazy_view(result, new_dims);
}

torch::lazy::LazyTensorPtr lazy_slice(
    const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start,
    int64_t end, int64_t step) {
  auto input_shape = input->shape();
  dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim());
  start =
      torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start);
  end = torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, end);
  // PyTorch allows tensor[-1:0] to return a 0-dim tensor.
  if (start > end) {
    end = start;
  }
  step = std::min(step, end - start);

  torch::lazy::SelectInfo select = {dim, start, end, step};
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kSelect, input_shape, select);
  return input->CreateViewTensor(std::move(view_info));
}

} // namespace

// at::Tensor LazyNativeFunctions::bernoulli(
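The helpers above follow eager PyTorch view semantics: lazy_select is built from lazy_narrow plus a lazy_view that drops the selected dimension. A minimal eager-mode sketch of that equivalence (plain PyTorch, not part of this commit; the tensor shape is arbitrary):

import torch

t = torch.arange(24).reshape(2, 3, 4)

# select(dim=1, index=2) is narrow(dim=1, start=2, length=1) followed by
# dropping the singleton dimension, which is what lazy_select does via
# lazy_narrow + lazy_view.
selected = t.select(1, 2)
narrowed = t.narrow(1, 2, 1).reshape(2, 4)
assert torch.equal(selected, narrowed)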
@@ -194,6 +270,44 @@ torch::lazy::LazyTensorPtr create_view(
//   // return self;
// }

at::Tensor LazyNativeFunctions::as_strided(
    const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
    c10::optional<int64_t> storage_offset) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
  auto xsize = torch::lazy::ToI64Vector(size);
  auto xstride = torch::lazy::ToI64Vector(stride);
  if (!torch::lazy::StrideIsSupported(xstride)) {
    UNIMPLEMENTED_FUNCTION_ERROR();
  }
  return torch::lazy::CreateAtenFromLtcTensor(
      self_tensor->CreateViewTensor(CreateAsStridedViewInfo(
          self_tensor->shape(), std::move(xsize), std::move(xstride),
          storage_offset)));
}

const at::Tensor& LazyNativeFunctions::as_strided_(
    const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
    c10::optional<int64_t> storage_offset) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  auto xsize = torch::lazy::ToI64Vector(size);
  auto xstride = torch::lazy::ToI64Vector(stride);
  if (!torch::lazy::StrideIsSupported(xstride)) {
    UNIMPLEMENTED_FUNCTION_ERROR();
  }
  if (self_tensor->data()->view == nullptr) {
    self_tensor->SetIrValue(torch::lazy::MakeAsStrided(
        self_tensor->GetIrValue(), std::move(xsize), std::move(xstride),
        storage_offset.value_or(0)));
  } else {
    auto input_shape = self_tensor->shape();
    self_tensor->SetSubView(CreateAsStridedViewInfo(
        input_shape, std::move(xsize), std::move(xstride), storage_offset));
  }
  return self;
}

at::Tensor LazyNativeFunctions::cat(at::TensorList tensors, int64_t dim) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto lazy_tensors = torch::lazy::GetLtcTensors(tensors);
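For reference, the eager as_strided behavior this lowering reproduces as a kAsStrided view — the same (size, stride, storage_offset) triple that CreateAsStridedViewInfo captures. A sketch only; the concrete sizes and strides are arbitrary, and lazy tensors additionally reject strides that StrideIsSupported rules out:

import torch

base = torch.arange(12.)
# Reinterpret the same storage as a 3x4 row-major matrix with offset 0.
view = base.as_strided(size=(3, 4), stride=(4, 1), storage_offset=0)
assert view.shape == (3, 4)
assert view[1, 2] == base[1 * 4 + 2]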
@@ -298,6 +412,99 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize(
  return dst;
}

at::Tensor LazyNativeFunctions::_to_copy(
    const at::Tensor& self, c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout, c10::optional<at::Device> device,
    c10::optional<bool> pin_memory, bool non_blocking,
    c10::optional<at::MemoryFormat> memory_format) {
  PRINT_FUNCTION();
  auto options = self.options();
  if (dtype) {
    // I put each of these setters in a conditional instead of doing `self.options().dtype(dtype).layout(layout)...
    // because calling .dtype(nullopt) on an options() that already has dtype appears to wipe it
    options = options.dtype(dtype);
  }
  if (layout) {
    options = options.layout(layout);
  }
  if (memory_format) {
    options = options.memory_format(memory_format);
  }
  if (pin_memory) {
    // TODO(whc) can we honor 'pin_memory' in some/all cases?
    options = options.pinned_memory(pin_memory);
    TORCH_WARN_ONCE("Pinned memory used in lazy _to_copy, check if the "
                    "behavior is as intended");
  }

  TORCH_LAZY_FN_COUNTER("lazy::");
  auto lazy_self = torch::lazy::TryGetLtcTensor(self);
  if (!lazy_self && device && device->type() == c10::kLazy) {
    // Case 1: eager->lazy (we create a new lazy tensor)

    auto eager_tensor =
        self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
    lazy_self = torch::lazy::GetOrCreateLtcTensor(
        eager_tensor, torch::lazy::atenDeviceToBackendDevice(*device));
    return torch::lazy::CreateAtenFromLtcTensor(lazy_self);
  } else if (device && device->type() != c10::kLazy) {
    // Case 2: lazy->eager (forces a graph break since we are materializing a tensor)

    TORCH_INTERNAL_ASSERT(lazy_self);
    auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
    options = options.device(device);
    auto moved_eager_tensor =
        eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
    return moved_eager_tensor;
  } else if (
      device && device->type() == c10::kLazy && device->has_index() &&
      device->index() != self.device().index()) {
    // Case 3: lazy:0 -> lazy:1

    // TODO(whc) what do we actually want to do here?
    //   option 1: materialize, move eager tensor, create new lazy tensor
    //     - this should be our default, as it is what would happen before we implemented _to_copy
    //     - actually combines case 1 + case 2
    //   option 2: support multiple devices inside one lazy/TS executor (case 4)
    //     - but: we may have other assumptions that there is just one device per executor? so don't take this lightly

    TORCH_INTERNAL_ASSERT(lazy_self);
    auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
    // we move the eager tensor to the 'eager' equivalent of our lazy device
    // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is what we use
    auto eager_device = c10::Device(
        torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index());
    options = options.device(eager_device);
    auto moved_eager_tensor =
        eager_tensor.to(options, /*non_blocking=*/false, /*copy=*/true);
    lazy_self = torch::lazy::GetOrCreateLtcTensor(
        moved_eager_tensor,
        torch::lazy::atenDeviceToBackendDevice(eager_device));
    return torch::lazy::CreateAtenFromLtcTensor(lazy_self);

  } else {
    // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy graph)

    // Note: captured _to_copy will be executed with real eager tensors, not lazy tensors.
    // We DO NOT want to burn 'lazy:0' as the device into this captured IR, or we will try to
    // convert an eager tensor back to a lazy one inside the torchscript executor
    // lazy:0 -> lazy:1 is handled in case3, so we can safely drop the device argument
    device = c10::nullopt;

    auto shapes = torch::lazy::compute_shape__to_copy(
        self, dtype, layout, device, pin_memory, non_blocking, memory_format);
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    auto node = torch::lazy::MakeNode<ToCopy>(
        lazy_self->GetIrValue(), dtype, layout, device, pin_memory,
        non_blocking, memory_format, std::move(shapes));

    auto result =
        torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(
            std::move(node), lazy_self->GetDevice()));
    return result;
  }
};

at::Tensor LazyNativeFunctions::empty(
    at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout, c10::optional<at::Device> device,
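In user code these four cases are reached through Tensor.to(), which dispatches to _to_copy. A rough illustration; only the eager dtype conversion below actually runs without further setup, and the lazy-device behavior described in the comments assumes the LTC backend has been initialized first:

import torch

x = torch.randn(2, 3)            # eager tensor
y = x.to(dtype=torch.float64)    # dispatches to _to_copy; dtype-only change
assert y.dtype == torch.float64 and y.device == x.device

# With the LTC backend initialized, x.to('lazy') would correspond to case 1
# (eager -> lazy) and x_lazy.to('cpu') to case 2 (lazy -> eager, forcing a
# graph break); a lazy-to-lazy dtype change would stay in the traced graph
# (case 4). These mappings are an assumption of this sketch.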
@@ -313,6 +520,15 @@ at::Tensor LazyNativeFunctions::empty(
  return CreateLtcTensor(x_result, GetLtcDevice(device));
}

at::Tensor LazyNativeFunctions::empty_strided(
    at::IntArrayRef size, at::IntArrayRef stride,
    c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
    c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  at::Tensor t = empty(size, dtype, layout, device, pin_memory, c10::nullopt);
  return LazyNativeFunctions::as_strided(t, size, stride, /*storage_offset=*/0);
}

at::Tensor LazyNativeFunctions::expand(
    const at::Tensor& self, at::IntArrayRef size, bool implicit) {
  TORCH_LAZY_FN_COUNTER("lazy::");
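empty_strided is lowered here as empty followed by as_strided with a zero storage offset. The eager equivalence it relies on can be sketched as follows (sizes and strides chosen arbitrarily):

import torch

size, stride = (2, 3), (3, 1)
a = torch.empty_strided(size, stride)
b = torch.empty(size).as_strided(size, stride)
# The values are uninitialized, but the metadata should agree.
assert a.shape == b.shape and a.stride() == b.stride()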
@@ -355,6 +571,23 @@ LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) {
      self_tensor->CreateViewTensor(std::move(view_info)));
}

at::Tensor LazyNativeFunctions::select(
    const at::Tensor& self, int64_t dim, int64_t index) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return torch::lazy::CreateAtenFromLtcTensor(
      lazy_select(torch::lazy::TryGetLtcTensor(self), dim, index));
}

at::Tensor LazyNativeFunctions::slice(
    const at::Tensor& self, int64_t dim, c10::optional<int64_t> start,
    c10::optional<int64_t> end, int64_t step) {
  int64_t start_val = start.has_value() ? start.value() : 0;
  int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
  TORCH_LAZY_FN_COUNTER("lazy::");
  return torch::lazy::CreateAtenFromLtcTensor(lazy_slice(
      torch::lazy::TryGetLtcTensor(self), dim, start_val, end_val, step));
}

at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) {
  return squeeze(self, -1);
}
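The slice wrapper substitutes 0 and INT64_MAX for missing start/end and leaves normalization and clamping to lazy_slice / GetCanonicalPosition. Eager slicing clamps the same way, which is what this sketch (not part of the commit) checks:

import torch

t = torch.arange(10)
# An oversized end is clamped to the dimension size, and a negative start is
# normalized, matching the GetCanonicalPosition handling above.
assert torch.equal(t[3:2**62], t[3:10])
assert torch.equal(t[-4:], t[6:10])
# PyTorch also allows an empty slice such as t[5:2]; the result has size 0.
assert t[5:2].numel() == 0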
@@ -390,6 +623,21 @@ at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
      self_tensor->CreateViewTensor(std::move(view_info)));
}

at::Tensor LazyNativeFunctions::transpose(
    const at::Tensor& self, int64_t dim0, int64_t dim1) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);

  auto input_shape = self_tensor->shape();
  auto permute_dims = torch::lazy::MakeTransposePermutation(
      /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim());
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims);

  return torch::lazy::CreateAtenFromLtcTensor(
      self_tensor->CreateViewTensor(std::move(view_info)));
}

at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);
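transpose is lowered as a kPermute view: MakeTransposePermutation builds the identity permutation with dim0 and dim1 swapped. The corresponding eager relation, as a quick sketch:

import torch

t = torch.randn(2, 3, 4)
# Swapping dims 0 and 2 of a rank-3 tensor is the permutation (2, 1, 0).
assert torch.equal(t.transpose(0, 2), t.permute(2, 1, 0))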
@@ -418,6 +666,21 @@ LazyNativeFunctions::view(const at::Tensor& self, at::IntArrayRef size) {
      self_tensor->CreateViewTensor(std::move(view_info)));
}

at::Tensor LazyNativeFunctions::_unsafe_view(
    const at::Tensor& self, at::IntArrayRef size) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self);

  auto input_shape = self_tensor->shape().Get();
  torch::lazy::Shape shape = torch::lazy::Shape(
      input_shape.scalar_type(),
      at::infer_size(torch::lazy::ToI64Vector(size), input_shape.numel()));
  torch::lazy::ViewInfo view_info(
      torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape);
  return torch::lazy::CreateAtenFromLtcTensor(
      self_tensor->CreateViewTensor(std::move(view_info)));
}

void InitializeAtenBindings() {}

} // namespace lazy
@@ -76,15 +76,23 @@ TorchMlirOpVector TorchMlirNode::Lower(
   return {};
 }
 
-TensorList::TensorList(OpList values)
+OpKind TorchMlirTensorList::ClassOpKind() {
+  // Note: this OpKind is separate from ltc_ops.h since it would be a circular
+  // import otherwise
+  static const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
+  return tensor_list_opkind;
+}
+
+TorchMlirTensorList::TorchMlirTensorList(OpList values)
     : TorchMlirNode(
-          /*op=*/tensor_list_opkind,
+          /*op=*/TorchMlirTensorList::ClassOpKind(),
           /*operands=*/values,
           /*shapes=*/std::vector<Shape>(),
           /*num_outputs=*/1,
           /*hash_seed=*/kHashSeed) {}
 
-torch::lazy::TorchMlirOpVector TensorList::Lower(
+torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower(
     TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
   std::vector<torch::jit::Value*> tensor_list;
   CHECK(!operands().empty());
@@ -42,6 +42,8 @@ public:
   TorchMlirNode(
       OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed);
 
+  ~TorchMlirNode() override = default;
+
   hash_t hash() const override;
 
   hash_t shapeHash() const override;
@@ -58,9 +60,6 @@ private:
   hash_t dag_hash_;
 };
 
-// Note: this OpKind is separate from ltc_ops.h since it would be a circular
-// import otherwise
-const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
-
 // TensorList represents an at::TensorList which is a vector[Tensor] but is also
 // a first-class IValue and can be fed as a single input to a TS program. It is
@@ -77,9 +76,11 @@ const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
 // TODO(whc) once Shape() API is moved to Node base, also make it virtual, and
 // then implement it as NotImplemented for TensorList, also fixing the assertion
 // that would fail.
-struct TORCH_API TensorList : public TorchMlirNode {
-  TensorList() = delete;
-  TensorList(OpList values);
+struct TORCH_API TorchMlirTensorList : public TorchMlirNode {
+  static OpKind ClassOpKind();
+
+  TorchMlirTensorList() = delete;
+  TorchMlirTensorList(OpList values);
 
   torch::lazy::TorchMlirOpVector Lower(
       TorchMlirFunction function,
@@ -14,6 +14,7 @@
 #include "generated/LazyNonNativeIr.h"
 #include "mlir_lowering_context.h"
 #include "mlir_node.h"
+#include "ops/device_data.h"
 
 #include <ATen/Functions.h>
 #include <c10/core/ScalarType.h>
@@ -61,8 +62,8 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
   for (jit::Value* value : results) {
     if (value->type()->kind() == c10::TypeKind::TensorType) {
       TORCH_CHECK(
-          tensor_type_idx < tensor_types.size(),
-          "Tensor corresponding to JIT SSA value %", value->debugName(),
+          tensor_type_idx < tensor_types.size(), function->graph()->toString(),
+          "\nTensor corresponding to JIT SSA value %", value->debugName(),
           " corresponds to result #", tensor_type_idx, ", but we only have ",
           tensor_types.size(), " known types!");
 
@@ -100,6 +101,79 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
      function, sym, tensor_types, arguments, kwarguments);
}

c10::TensorType& cast_tensor_type(c10::TypePtr value_type) {
  auto tensor_type = value_type->cast<c10::TensorType>();
  TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!");

  return *tensor_type.get();
}

c10::optional<std::vector<int64_t>>
get_tensor_type_shape(c10::TensorType& tensor_type) {
  auto& symbolic_shape = tensor_type.symbolic_sizes();
  if (!symbolic_shape.rank()) {
    return c10::nullopt;
  }

  // Get current tensor shape.
  std::vector<int64_t> dims;
  dims.resize(*symbolic_shape.rank());
  for (size_t i = 0; i < dims.size(); ++i) {
    auto shape_symbol = symbolic_shape[i];
    dims[i] = shape_symbol.is_static() ? shape_symbol.static_size() : -1;
  }

  return dims;
}

std::vector<torch::lazy::Shape> compute_shape_copy(c10::TypePtr value_type) {
  c10::TensorType& tensor_type = cast_tensor_type(value_type);

  auto maybe_dims = get_tensor_type_shape(tensor_type);
  TORCH_CHECK(maybe_dims.has_value(), "Cannot copy unranked tensor!");

  auto scalar_type = tensor_type.scalarType();
  TORCH_CHECK(
      scalar_type.has_value(), "Unable to copy due to lack of scalar type!");
  return {Shape(scalar_type.value(), maybe_dims.value())};
}

std::vector<torch::lazy::Shape> compute_shape_slice(
    c10::TypePtr value_type, int64_t dim, int64_t start, int64_t end,
    int64_t step) {
  c10::TensorType& tensor_type = cast_tensor_type(value_type);

  auto maybe_dims = get_tensor_type_shape(tensor_type);
  TORCH_CHECK(maybe_dims.has_value(), "Cannot slice unranked tensor!");

  std::vector<int64_t> dims = maybe_dims.value();
  int64_t num_dims = dims[dim];

  // Index may be negative, so we must normalize it.
  auto normalize_index = [](int64_t index, unsigned num_dims) {
    return index < 0 ? (int64_t)num_dims + index : index;
  };
  start = normalize_index(start, num_dims);
  end = normalize_index(end, num_dims);

  if (start >= end || start >= num_dims || end <= 0) {
    // Slice is out of bounds, nothing in range.
    dims[dim] = 0;
  } else {
    // Clamp upper and lower bound to valid indices.
    start = std::max((int64_t)0, start);
    end = std::min(num_dims, end);

    // Final size is determined by step and interval size.
    dims[dim] = std::ceil((double)(end - start) / (double)step);
  }

  auto scalar_type = tensor_type.scalarType();
  TORCH_CHECK(
      scalar_type.has_value(), "Unable to slice due to lack of scalar type!");
  return {Shape(scalar_type.value(), dims)};
}

class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
public:
  TorchMlirNodeLowering(
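compute_shape_slice reduces to: normalize negative indices, clamp to [0, num_dims], then take ceil((end - start) / step) as the new extent along dim. A standalone cross-check against eager slicing (not part of the commit; slice_len is a hypothetical helper mirroring the C++ above):

import math
import torch

def slice_len(num_dims, start, end, step):
    # Mirrors compute_shape_slice: normalize, bounds-check, then ceil-divide.
    start = start + num_dims if start < 0 else start
    end = end + num_dims if end < 0 else end
    if start >= end or start >= num_dims or end <= 0:
        return 0
    start, end = max(0, start), min(num_dims, end)
    return math.ceil((end - start) / step)

t = torch.arange(10)
for start, end, step in [(0, 10, 1), (2, 9, 3), (-4, 100, 2), (5, 2, 1)]:
    assert t[start:end:step].numel() == slice_len(10, start, end, step)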
@@ -213,14 +287,14 @@ public:
       const torch::lazy::DeviceData* device_data_node =
           torch::lazy::NodeCast<torch::lazy::DeviceData>(
               node, *torch::lazy::ltc_device_data);
-      auto infoptr = device_data_node->data->info();
+      auto infoptr = device_data_node->data()->info();
       auto deviceDataInfoPtr =
           (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
       if (GRAPH_DUMP_ENABLED) {
         LOG(ERROR) << "Lowering device data node, tensor id "
                    << deviceDataInfoPtr->tensor_id << std::endl;
       }
-      return {loctx()->GetParameter(device_data_node->data)};
+      return {loctx()->GetParameter(device_data_node->data())};
     }
 
     std::vector<torch::jit::NamedValue> arguments;
@@ -421,8 +495,8 @@ public:
     arguments.emplace_back(destination);
     arguments.emplace_back(source);
     LowerBuiltin(
-        at::aten::copy_, c10::ArrayRef<Shape>({/*shape goes here*/}),
-        arguments);
+        at::aten::copy_,
+        c10::ArrayRef<Shape>(compute_shape_copy(source->type())), arguments);
   }
 
   torch::jit::Value* GenerateSlice(
@@ -434,8 +508,11 @@ public:
     arguments.emplace_back(start);
     arguments.emplace_back(end);
     arguments.emplace_back(step);
 
     TorchMlirOpVector selected = LowerBuiltin(
-        at::aten::slice, c10::ArrayRef<Shape>({/*shape goes here*/}),
+        at::aten::slice,
+        c10::ArrayRef<Shape>(
+            compute_shape_slice(base->type(), dim, start, end, step)),
         arguments);
     CHECK_EQ(selected.size(), 1);
     return selected.front();
@@ -0,0 +1,41 @@
#include <sstream>

#include <torch/csrc/lazy/core/ir_builder.h>

#include "device_data.h"

namespace torch {
namespace lazy {

DeviceData::DeviceData(std::shared_ptr<BackendData> data)
    : TorchMlirNode(
          ClassOpKind(),
          data->shape(),
          /*num_outputs=*/1,
          /*hash_seed=*/static_cast<uint32_t>(101)),
      data_(std::move(data)) {}

std::string DeviceData::ToString() const {
  std::stringstream ss;
  ss << TorchMlirNode::ToString() << ", device=" << data_->device();
  return ss.str();
}

const DeviceData* DeviceData::Cast(const Node* node) {
  return NodeCast<DeviceData>(node);
}

NodePtr DeviceData::Create(std::shared_ptr<BackendData> data) {
  NodePtr node = ReuseOrMakeNode<DeviceData>(data);
  // ReuseOrMakeNode may return a reused node which has the same shape,
  // however, we need to replace the old data_ with the new one.
  // Ditching the old data_ is safe because tracing is done iteration
  // by iteration, and after we launch the async device execution for the
  // previous iteration, data_ in DeviceData nodes are not needed anymore.
  DeviceData* device_data = static_cast<DeviceData*>(node.get());
  device_data->SetData(data);
  return node;
}

} // namespace lazy
} // namespace torch
@@ -0,0 +1,48 @@
#pragma once

#include "../mlir_node.h"

#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>


namespace torch {
namespace lazy {

class TORCH_API DeviceData : public TorchMlirNode {
public:
  static OpKind ClassOpKind() {
    return ltc_device_data;
  }

  explicit DeviceData(std::shared_ptr<BackendData> data);

  // A DeviceData node can be reused if the shape matches,
  // but we will substitute the actual data_ pointer under
  // the hood.
  bool CanBeReused(std::shared_ptr<BackendData> data) const {
    return data_->shape() == data->shape();
  }

  std::string ToString() const override;

  const std::shared_ptr<BackendData>& data() const {
    return data_;
  }

  void SetData(std::shared_ptr<BackendData> data) {
    data_ = data;
  }

  static const DeviceData* Cast(const Node* node);

  // To reuse IR nodes, use this method to create DeviceData nodes
  // instead of calling the constructor directly.
  static NodePtr Create(std::shared_ptr<BackendData> data);

private:
  std::shared_ptr<BackendData> data_;
};

} // namespace lazy
} // namespace torch
@@ -0,0 +1,101 @@
//===- to_copy.h ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// this file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ops/to_copy.h
//===----------------------------------------------------------------------===//

#pragma once

#include "../mlir_node.h"

namespace torch {
namespace lazy {


// This IR was copied from code-generated output, but the entire _to_copy operator
// cannot be trivially code generated since it is only desirable to capture IR for
// certain permutations of _to_copy (e.g. dtype), and for the others it is difficult to even invoke
// the aten/eager fallback necessitating directly implementing the right to(device) behavior
class ToCopy : public torch::lazy::TorchMlirNode {
public:
  ToCopy(const torch::lazy::Value& self, const c10::optional<at::ScalarType>& dtype, const c10::optional<at::Layout>& layout, const c10::optional<at::Device>& device, const c10::optional<bool>& pin_memory, const bool& non_blocking, const c10::optional<at::MemoryFormat>& memory_format, std::vector<torch::lazy::Shape>&& shapes)
      : torch::lazy::TorchMlirNode(torch::lazy::OpKind(at::aten::_to_copy),
            {self}, std::move(shapes),
            /* num_outputs */ 1,
            torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, memory_format)),

        dtype(dtype),
        layout(layout),
        device(device),
        pin_memory(pin_memory),
        non_blocking(non_blocking),
        memory_format(memory_format) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << torch::lazy::TorchMlirNode::ToString();
    if (dtype.has_value()) {
      ss << ", dtype=" << dtype.value();
    } else {
      ss << ", dtype=null";
    }
    if (layout.has_value()) {
      ss << ", layout=" << layout.value();
    } else {
      ss << ", layout=null";
    }
    if (device.has_value()) {
      ss << ", device=" << device.value();
    } else {
      ss << ", device=null";
    }
    if (pin_memory.has_value()) {
      ss << ", pin_memory=" << pin_memory.value();
    } else {
      ss << ", pin_memory=null";
    }
    ss << ", non_blocking=" << non_blocking;
    if (memory_format.has_value()) {
      ss << ", memory_format=" << memory_format.value();
    } else {
      ss << ", memory_format=null";
    }
    return ss.str();
  }

  torch::lazy::TorchMlirOpVector Lower(TorchMlirFunction function,
      torch::lazy::TorchMlirLoweringContext* loctx) const override {
    std::vector<torch::jit::NamedValue> arguments;
    std::vector<torch::jit::NamedValue> kwarguments;
    arguments.reserve(1);
    kwarguments.reserve(6);
    size_t i = 0;
    arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
    kwarguments.emplace_back("dtype", dtype);
    kwarguments.emplace_back("layout", layout);
    kwarguments.emplace_back("device", device);
    kwarguments.emplace_back("pin_memory", pin_memory);
    kwarguments.emplace_back("non_blocking", non_blocking);
    kwarguments.emplace_back("memory_format", memory_format);
    torch::lazy::TorchMlirOpVector _to_copy_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
    CHECK_EQ(_to_copy_out.size(), 1);

    return _to_copy_out;
  }

  c10::optional<at::ScalarType> dtype;
  c10::optional<at::Layout> layout;
  c10::optional<at::Device> device;
  c10::optional<bool> pin_memory;
  bool non_blocking;
  c10::optional<at::MemoryFormat> memory_format;
};
} // namespace lazy
} // namespace torch
@@ -133,7 +133,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
     auto containedTypes = c10::fmap(
         node->output()->type()->cast<c10::TupleType>()->containedTypes(),
         [&](const c10::TypePtr &t) {
-          MlirType type = getMlirTypeFromTorchType(loc, t);
+          MlirType type = getMlirTypeFromTorchType(loc, t, importOptions);
           if (mlirTypeIsNull(type)) {
             throw mlir_diagnostic_emitted();
           }