Add pytorch interface to ATen Dialect (#30)
This patch adds a pytorch interface to npcomp. This interface is modeled
after pytorch_xla and exposes the MLIR-based flow as a virtual device (similar
to a gpu device or the xla backend). Usage is intended to be something like:
dev = torch_mlir.mlir_device()
t0 = torch.randn((4,4), device=dev)
t1 = torch.randn((4,4), device=dev)
t2 = t0 + t1
t2_mlir = torch_mlir.get_mlir( t2 )
t2_cpu = t2.to('cpu')
In this case t2_cpu would contain the result of the computation, and t2_mlir
contains the MLIR description of the computation. Note that this also
properly returns backward paths synthesized by pytorch. There are several
parts of this:
1) A tensor type (implemented by tensor.* and tensor_impl.*)
2) The device modeling (aten_mlir_bridge.*, aten_mlir_device.*, aten_mlir_type*)
3) A temporary IR (implemented by ir.cpp)
There is also a reference lowering directly from the ATen dialect to C
function calls consisting of two parts:
1) The driver that uses the IR to generate MLIR, run passes, and compile the
result using mlir::ExecutionEngine (implemented by jit.cpp and
mlir_gen.cpp)
2) A runtime library implemented by lib/aten_ops.cpp. Most of the operations
are implemented by callbacks into the torch C++ libraries.
Some aspects of this are known to be less than optimal, in particular:
1) There are some function definitions that don't live in the file corresponding
to their declaration.
2) More aspects of this (e.g. the IR) seem like they should be automatically
generated.
3) It's unclear to me how much of the 'IR' is actually necessary, or whether
MLIR could be created on the fly.
Note that this code is licensed in a way similar to pytorch, with the
intention that eventually (when npcomp reaches some maturity) it should be
pushed there. (see frontends/pytorch/LICENSE) The code is also structured
much closer to the pytorch coding style than the LLVM coding style.
2020-08-22 02:22:47 +08:00
|
|
|
# -*- Python -*-
|
|
|
|
# This file is licensed under a pytorch-style license
|
|
|
|
# See frontends/pytorch/LICENSE for license information.
|
|
|
|
|
|
|
|
# Structured similarly to code from git@github.com:pytorch/xla.git
|
|
|
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import collections
|
|
|
|
import lark
|
|
|
|
import os
|
|
|
|
import re
|
|
|
|
import string
|
|
|
|
import sys
|
|
|
|
|
|
|
|
####
|
|
|
|
# This file parses the C++ signatures exported by pytorch and generates
|
|
|
|
# appropriate MLIR operations in a tablegen file. It also generates some of
|
|
|
|
# the more boilerplate parts of the pytorch integration. This may need to be
|
|
|
|
# run if pytorch versions change. Primarily this reads information from
|
|
|
|
# pytorch through RegistrationDeclarations.h and Functions.h. It also allows
|
|
|
|
# some local overrides (specified in aten_mlir_type.h).
|
|
|
|
# It generates aten_mlir_type_default.{cpp,h} and ATenOps.td, which will need
|
|
|
|
# to be moved to their appropriate places.
|
|
|
|
|
|
|
|
# To run:
|
|
|
|
# python3 gen_aten_dialect.py --output_folder=. \
|
|
|
|
# ../csrc/aten_mlir_type.h \
|
|
|
|
# ${TORCH_INSTALL_PREFIX}/include/ATen/RegistrationDeclarations.h \
|
|
|
|
# ${TORCH_INSTALL_PREFIX}/include/ATen/Functions.h
|
|
|
|
|
|
|
|
|
|
|
|
def namedtuple_with_defaults(typename, field_names, default_values=()):
  """Creates a namedtuple class in which every field has a default value.

  Fields default to None unless overridden through `default_values`.

  Args:
    typename: Name of the generated class.
    field_names: Field names in any form accepted by collections.namedtuple
      (a comma/space separated string or a sequence of strings).
    default_values: Either a mapping from field name to default value, or a
      sequence of defaults applied positionally starting at the first field.

  Returns:
    The generated namedtuple class.
  """
  # collections.Mapping was removed in Python 3.10; the ABC lives in
  # collections.abc (the fallback keeps very old interpreters working).
  try:
    from collections.abc import Mapping
  except ImportError:
    from collections import Mapping

  ntuple = collections.namedtuple(typename, field_names)
  # Make every field optional first, so a prototype can be built from a
  # partial set of defaults.
  ntuple.__new__.__defaults__ = (None,) * len(ntuple._fields)
  if isinstance(default_values, Mapping):
    prototype = ntuple(**default_values)
  else:
    prototype = ntuple(*default_values)
  ntuple.__new__.__defaults__ = tuple(prototype)
  return ntuple
|
|
|
|
|
|
|
|
|
|
|
|
class ArgTemplate(string.Template):
  """string.Template whose placeholder names may consist only of digits.

  The stock identifier pattern requires a leading letter or underscore. The
  out-function templates in this file refer to positional arguments as `$0`,
  `$1`, ..., so digits are accepted anywhere in a placeholder name here.
  """

  idpattern = r'[a-z0-9_]+'
|
|
|
|
|
|
|
|
|
|
|
|
# A C++ signature together with its ATen signature.
FuncDef = namedtuple_with_defaults('FuncDef', 'cpp_sig, aten_sig')

# Everything produced while generating the wrapper of a single function
# (parse trees, names, generated code and derived signatures).
FuncGen = namedtuple_with_defaults(
    'FuncGen',
    'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig, aten_sig',
)

# Per-function generation options; every field defaults to None.
FuncOpts = namedtuple_with_defaults(
    'FuncOpts',
    'ref_param, device_param, wparams, outfn_template, outfn_name, shape_check_indices',
)
|
|
|
|
|
|
|
|
_GRAMMAR = r"""
|
|
|
|
start: type fnname "(" params ")"
|
|
|
|
rtype: "(" rparams ")"
|
|
|
|
| TNAME
|
|
|
|
rparams: rparam
|
|
|
|
| rparam "," rparams
|
|
|
|
rparam: type param_name
|
|
|
|
type: CONST? core_type refspec?
|
|
|
|
fnname: CNAME
|
|
|
|
refspec: REF
|
|
|
|
| PTR
|
|
|
|
core_type: template
|
|
|
|
| TNAME
|
|
|
|
template: TNAME "<" typelist ">"
|
|
|
|
typelist: type
|
|
|
|
| type "," typelist
|
|
|
|
REF: "&"
|
|
|
|
PTR: "*"
|
|
|
|
CONST: "const"
|
|
|
|
TNAME: /[a-zA-Z0-9_:]+/
|
|
|
|
HEXNUMBER: /0x[0-9a-fA-F]+/
|
|
|
|
params: param
|
|
|
|
| param "," params
|
|
|
|
param: type param_name param_defval?
|
|
|
|
param_name: CNAME
|
|
|
|
|
|
|
|
param_defval: "=" init_value
|
|
|
|
init_value: "true"
|
|
|
|
| "false"
|
|
|
|
| "{}"
|
|
|
|
| NUMBER
|
|
|
|
| SIGNED_NUMBER
|
|
|
|
| HEXNUMBER
|
|
|
|
| ESCAPED_STRING
|
|
|
|
|
|
|
|
%import common.CNAME -> CNAME
|
|
|
|
%import common.NUMBER -> NUMBER
|
|
|
|
%import common.SIGNED_NUMBER -> SIGNED_NUMBER
|
|
|
|
%import common.ESCAPED_STRING -> ESCAPED_STRING
|
|
|
|
%import common.WS
|
|
|
|
%ignore WS
|
|
|
|
"""
|
|
|
|
|
|
|
|
# Two LALR parsers over the same grammar, both recording positions.
# _PARSER uses lark's default token filtering, so anonymous punctuation such
# as "(" and "," is absent from its trees. _XPARSER keeps every token, so a
# signature can be re-emitted token by token (see StringEmit / emit_string).
_PARSER = lark.Lark(_GRAMMAR, propagate_positions=True, parser='lalr')
_XPARSER = lark.Lark(
    _GRAMMAR,
    keep_all_tokens=True,
    propagate_positions=True,
    parser='lalr',
)
|
|
|
|
|
|
|
|
_TD_BLACKLIST = set([
|
|
|
|
'clone',
|
|
|
|
'to',
|
|
|
|
'copy_',
|
|
|
|
'copy',
|
|
|
|
'copy_from',
|
|
|
|
'_copy_from',
|
|
|
|
'_unsafe_view',
|
|
|
|
])
|
|
|
|
|
|
|
|
_TD_NO_OPSTATS_LIST = set([
|
|
|
|
'_log_softmax',
|
|
|
|
'_log_softmax_backward_data',
|
|
|
|
])
|
|
|
|
|
|
|
|
_FN_BLACKLIST = set([
|
|
|
|
'numel',
|
|
|
|
'ones',
|
|
|
|
'ones_like',
|
|
|
|
'result_type',
|
|
|
|
# 'zero_',
|
|
|
|
'zeros',
|
|
|
|
'zeros_like',
|
|
|
|
])
|
|
|
|
|
|
|
|
_FN_NO_DEBUG_ENTRY_LIST = set([
|
|
|
|
'empty',
|
|
|
|
'fill_',
|
|
|
|
'zero_',
|
|
|
|
])
|
|
|
|
|
|
|
|
_FN_BLACKLIST_REGEX = [
|
|
|
|
# ATEN functions
|
|
|
|
r'[^(]*cudnn',
|
|
|
|
# XLA/TPU functions
|
|
|
|
]
|
|
|
|
|
|
|
|
# Candidate out-variant functions (`<name>_out`). Each could be generated by
# calling the non-out variant and copying the result into the output
# argument(s) (see generate_aten_out). Keys are function names or map
# signatures (see create_map_sig).
#
# NOTE: this table is intentionally INACTIVE. Only _FN_OUT below is consulted
# by get_outfn_options, and it is empty. Copy an entry into _FN_OUT to enable
# out-variant rewriting for that function.
_FN_OUT_CANDIDATES = {
    'add_out':
        FuncOpts(),
    'arange_out(Tensor, Scalar, Scalar, Scalar) -> Tensor':
        FuncOpts(outfn_template=ArgTemplate(
            'ATenMLIRType::arange($1, $2, $3, $0.options())')),
    'bitwise_not_out':
        FuncOpts(),
    'clamp_out':
        FuncOpts(),
    'div_out':
        FuncOpts(),
    'gather_out':
        FuncOpts(),
    'kthvalue_out':
        FuncOpts(),
    'index_select_out':
        FuncOpts(),
    'log_out':
        FuncOpts(),
    'topk_out':
        FuncOpts(),
}

# Active out-variant table, keyed by function name or map signature.
# Deliberately empty: out-variant rewriting is currently disabled.
_FN_OUT = {}

# List of tuples with the regex match first, and the corresponding FuncOpts()
# second. Patterns are applied with re.match to the function name and to the
# map signature.
_FN_OUT_REGEX = []
|
|
|
|
|
|
|
|
# Functions whose wrapper simply forwards to another ATenMLIRType function
# (see generate_aten_remap). Keys are function names or map signatures.
# shape_check_indices lists operand index pairs whose shapes should match.
_FN_REMAP = {
    # Legacy TH comparison functions -> ATenMLIRType comparisons.
    '_th_eq(Tensor, Scalar) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::eq'),
    '_th_eq(Tensor, Tensor) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::eq'),
    '_th_ge(Tensor, Scalar) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::ge'),
    '_th_ge(Tensor, Tensor) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::ge'),
    '_th_gt(Tensor, Scalar) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::gt'),
    '_th_gt(Tensor, Tensor) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::gt'),
    '_th_le(Tensor, Scalar) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::le'),
    '_th_le(Tensor, Tensor) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::le'),
    '_th_lt(Tensor, Scalar) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::lt'),
    '_th_lt(Tensor, Tensor) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::lt'),
    '_th_ne(Tensor, Scalar) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::ne'),
    '_th_ne(Tensor, Tensor) -> Tensor': FuncOpts(outfn_name='ATenMLIRType::ne'),
    # Sparse/bitwise variants that also request an operand shape check.
    's__th_and(Tensor, Tensor) -> Tensor': FuncOpts(
        outfn_name='ATenMLIRType::__and__', shape_check_indices=((0, 1),)),
    's__th_or(Tensor, Tensor) -> Tensor': FuncOpts(
        outfn_name='ATenMLIRType::__or__', shape_check_indices=((0, 1),)),
    's__th_xor(Tensor, Tensor) -> Tensor': FuncOpts(
        outfn_name='ATenMLIRType::__xor__', shape_check_indices=((0, 1),)),
    # '_s_where(Tensor, Tensor, Tensor) -> Tensor':
    #     FuncOpts(
    #         outfn_name='ATenMLIRType::where',
    #         shape_check_indices=(
    #             (0, 1),
    #             (0, 2),
    #         )),
    's__th_eq(Tensor, Tensor) -> Tensor': FuncOpts(
        outfn_name='ATenMLIRType::eq', shape_check_indices=((0, 1),)),
}
|
|
|
|
|
|
|
|
_TYPE_NSMAP = {
|
|
|
|
'Tensor': 'at::Tensor',
|
|
|
|
'TensorList': 'at::TensorList',
|
|
|
|
'Scalar': 'at::Scalar',
|
|
|
|
'Storage': 'at::Storage',
|
|
|
|
'IntList': 'at::IntList',
|
|
|
|
'IntArrayRef': 'at::IntArrayRef',
|
|
|
|
'Generator': 'at::Generator',
|
|
|
|
'ScalarType': 'at::ScalarType',
|
|
|
|
'TensorOptions': 'at::TensorOptions',
|
|
|
|
'SparseTensorRef': 'at::SparseTensorRef',
|
|
|
|
'Device': 'c10::Device',
|
|
|
|
'optional': 'c10::optional',
|
|
|
|
'MemoryFormat': 'at::MemoryFormat',
|
|
|
|
'QScheme': 'at::QScheme',
|
|
|
|
'ConstQuantizerPtr': 'at::ConstQuantizerPtr',
|
|
|
|
'Dimname': 'at::Dimname', # namedtensor-only
|
|
|
|
'DimnameList': 'at::DimnameList', # namedtensor-only
|
|
|
|
}
|
|
|
|
|
|
|
|
_H_HEADER = """// Autogenerated file by {gen}. Do not edit directly!
|
|
|
|
|
|
|
|
#include <ATen/Tensor.h>
|
|
|
|
|
|
|
|
namespace torch_mlir {{
|
|
|
|
|
|
|
|
class ATenMLIRTypeDefault {{
|
|
|
|
public:
|
|
|
|
{hfuncs}
|
|
|
|
}};
|
|
|
|
|
|
|
|
void RegisterAtenTypeFunctions();
|
|
|
|
|
|
|
|
}} // namespace torch_mlir
|
|
|
|
"""
|
|
|
|
|
|
|
|
_CPP_HEADER = """// Autogenerated file by {gen}. Do not edit directly!
|
|
|
|
#include "aten_mlir_type_default.h"
|
|
|
|
|
|
|
|
#include <ATen/Context.h>
|
|
|
|
#include <ATen/Functions.h>
|
|
|
|
#include <ATen/core/op_registration/op_registration.h>
|
|
|
|
#include <ATen/CPUGenerator.h>
|
|
|
|
|
|
|
|
#include "aten_mlir_bridge.h"
|
|
|
|
#include "aten_mlir_type.h"
|
|
|
|
|
|
|
|
namespace torch_mlir {{
|
|
|
|
|
|
|
|
{funcs}
|
|
|
|
|
|
|
|
{regs}
|
|
|
|
}} // namespace torch_mlir
|
|
|
|
"""
|
|
|
|
|
|
|
|
_torch_mlir_FUNCTIONS = {}
|
|
|
|
|
|
|
|
_CTOR_FUNCTIONS = {
|
|
|
|
'empty': '.device(at::DeviceType::CPU)',
|
|
|
|
'linspace': '.device(at::DeviceType::CPU)',
|
|
|
|
'logspace': '.device(at::DeviceType::CPU)',
|
|
|
|
'rand': '.device(at::DeviceType::CPU)',
|
|
|
|
'rand_like': '.device(at::DeviceType::CPU)',
|
|
|
|
'randn': '.device(at::DeviceType::CPU)',
|
|
|
|
'randn_like': '.device(at::DeviceType::CPU)',
|
|
|
|
'randint': '.device(at::DeviceType::CPU)',
|
|
|
|
'randint_like': '.device(at::DeviceType::CPU)',
|
|
|
|
'randperm': '.device(at::DeviceType::CPU)',
|
|
|
|
'scalar_tensor': '.device(at::DeviceType::CPU)',
|
|
|
|
}
|
|
|
|
|
|
|
|
# Extra FuncOpts for specific functions, keyed by map signature
# (see create_map_sig).
_FUNCTION_OPTIONS = {
    # `self` is treated as written by slice (marked writeable in TensorFetcher).
    'slice(Tensor, int64_t, int64_t, int64_t, int64_t) -> Tensor': FuncOpts(
        wparams=['self']),
}

# Name of the generated local variable that holds the forwarded call's result.
_RESULT_NAME = 'x_result'
|
|
|
|
|
|
|
|
|
|
|
|
class Context(object):
  """Holds the text of the pytorch Functions.h header.

  The text is used to decide whether an ATen function exists as a free
  function in the `at::` namespace.
  """

  def __init__(self, functions):
    """Reads the whole header at path `functions` into memory."""
    with open(functions, 'r') as header:
      contents = header.read()
    self.functions_data = contents

  def get_function(self, name):
    """Returns 'at::<name>' when ' <name>(' occurs in the header, else None."""
    needle = ' {}('.format(name)
    if needle not in self.functions_data:
      return None
    return 'at::{}'.format(name)
|
|
|
|
|
|
|
|
|
|
|
|
class StringEmit(object):
  """Rebuilds a string from tokens that were lexed out of `sref`.

  Tokens are emitted with advance(). Any original text lying between two
  consecutively advanced tokens (usually whitespace) is copied along, so the
  original spacing survives. append() and skip() interrupt that copying.
  """

  def __init__(self, sref):
    self.sref = sref
    self.sval = ''
    # Offset in `sref` just past the last emitted token, or -1 when no
    # inter-token text should be copied before the next token.
    self.pos = -1

  def __repr__(self):
    return self.sval

  def advance(self, t):
    """Emits token `t`, preceded by the original text since the last token."""
    first = t.column - 1
    if self.pos >= 0 and first > self.pos:
      self.sval += self.sref[self.pos:first]
    self.sval += t.value
    self.pos = t.end_column - 1

  def skip(self, t):
    """Drops tree/token `t`; gap copying then resumes after its last token."""
    if self.pos >= 0:
      self.pos = last_match(t)
    else:
      self.pos = -1

  def append(self, s):
    """Appends literal text `s` and stops gap copying before the next token."""
    self.sval += s
    self.pos = -1
|
|
|
|
|
|
|
|
|
|
|
|
class TensorFetcher(object):
  """Collects Tensor arguments of a wrapper and emits C++ bundling them.

  Every added tensor is referred to in generated code as `<var_name>[i]`,
  where `<var_name>` is the list produced by the code from
  generate_fetches().
  """

  def __init__(self, var_name):
    self.var_name = var_name
    # Name of the generated std::vector<at::Tensor> holding the arguments.
    self.tvar_name = '{}_tensors'.format(self.var_name)
    self.tensors = []
    # Indices (into self.tensors) of tensors flagged as written.
    self.writeable = []

  def add(self, name, writeable):
    """Registers tensor argument `name`; returns the C++ expression for it."""
    if writeable:
      self.writeable.append(len(self.tensors))
    self.tensors.append(name)
    return '{}[{}]'.format(self.var_name, len(self.tensors) - 1)

  def generate_fetches(self):
    """Returns C++ that gathers the tensors and converts them via
    bridge::MLIRCreateTensorList."""
    code = ''
    code += ' std::vector<at::Tensor> {} = {{{}}};\n'.format(
        self.tvar_name, ', '.join(self.tensors))
    code += (' auto {} = bridge::MLIRCreateTensorList({});\n').format(
        self.var_name, self.tvar_name)
    return code

  def generate_updates(self):
    """Writing modified tensors back is not supported yet.

    This used to be guarded only by `assert (0)`. Running with `python -O`
    strips that assert, and the method then emitted XLA-specific
    write-back code (bridge::XlaUpdateTensors). It now always fails loudly.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError(
        'TensorFetcher.generate_updates: tensor write-back is not supported '
        '(writeable indices: {})'.format(self.writeable))
|
|
|
|
|
|
|
|
|
|
|
|
def list_get(l, n):
  """Returns l[n], or None when `n` is at or past the end of `l`."""
  if n >= len(l):
    return None
  return l[n]
|
|
|
|
|
|
|
|
|
|
|
|
def is_blacklisted_fn(fname, mapsig):
  """Tells whether no wrapper should be generated for a function.

  Either the function name or its map signature may appear in _FN_BLACKLIST
  or match (with re.match) a pattern in _FN_BLACKLIST_REGEX.
  """
  candidates = (fname, mapsig)
  if any(name in _FN_BLACKLIST for name in candidates):
    return True
  return any(
      re.match(pattern, name)
      for pattern in _FN_BLACKLIST_REGEX
      for name in candidates)
|
|
|
|
|
|
|
|
|
|
|
|
def get_outfn_options(fname, mapsig):
  """Returns the out-variant FuncOpts for a function, or None.

  Exact entries in _FN_OUT (by name, then by map signature) win over the
  first matching regex in _FN_OUT_REGEX.
  """
  for key in (fname, mapsig):
    opts = _FN_OUT.get(key)
    if opts is not None:
      return opts
  matches = (opts for pattern, opts in _FN_OUT_REGEX
             if re.match(pattern, fname) or re.match(pattern, mapsig))
  return next(matches, None)
|
|
|
|
|
|
|
|
|
|
|
|
def get_remapfn_options(fname, mapsig):
  """Returns the _FN_REMAP FuncOpts for a function (by name, then map
  signature), or None."""
  found = (_FN_REMAP.get(key) for key in (fname, mapsig))
  return next((opts for opts in found if opts is not None), None)
|
|
|
|
|
|
|
|
|
|
|
|
def is_write_param(fnopts, pname, defval):
  """Tells whether parameter `pname` is written by the function.

  Returns True when fnopts.wparams lists `pname`; otherwise `defval`.
  """
  if not fnopts or not fnopts.wparams:
    return defval
  return True if pname in fnopts.wparams else defval
|
|
|
|
|
|
|
|
|
|
|
|
def first_match(t):
  """Returns the start offset (column - 1) of the first token under `t`."""
  node = t
  while not isinstance(node, lark.lexer.Token):
    assert isinstance(node, lark.tree.Tree)
    node = node.children[0]
  return node.column - 1
|
|
|
|
|
|
|
|
|
|
|
|
def last_match(t):
  """Returns the end offset (end_column - 1) of the last token under `t`."""
  node = t
  while not isinstance(node, lark.lexer.Token):
    assert isinstance(node, lark.tree.Tree)
    node = node.children[-1]
  return node.end_column - 1
|
|
|
|
|
|
|
|
|
|
|
|
def for_every_token(t, fn):
  """Calls `fn` on every token under `t`, depth first, left to right."""
  if not isinstance(t, lark.lexer.Token):
    assert isinstance(t, lark.tree.Tree)
    for child in t.children:
      for_every_token(child, fn)
    return
  fn(t)
|
|
|
|
|
|
|
|
|
|
|
|
def emit_string(t, emit, emit_fn):
  """Emits the text of tree/token `t` into the StringEmit `emit`.

  emit_fn(node) decides per node:
    < 0  -> skip the node entirely (emit.skip);
    > 0  -> emit every token under it without consulting emit_fn again;
    == 0 -> emit a token, or recurse into a tree's children.
  """
  status = emit_fn(t)
  if status < 0:
    emit.skip(t)
  elif status > 0:
    for_every_token(t, emit.advance)
  elif isinstance(t, lark.lexer.Token):
    emit.advance(t)
  else:
    assert isinstance(t, lark.tree.Tree)
    for child in t.children:
      emit_string(child, emit, emit_fn)
|
|
|
|
|
|
|
|
|
|
|
|
def typed_child(t, n, ttype):
  """Returns child `n` of tree `t`, asserting it is a `ttype` subtree."""
  assert isinstance(t, lark.tree.Tree)
  assert n < len(t.children)
  child = t.children[n]
  assert isinstance(child, lark.tree.Tree)
  # On mismatch, show the whole parent tree to ease debugging.
  assert child.data == ttype, t.pretty()
  return child
|
|
|
|
|
|
|
|
|
|
|
|
def rewrite_sig(tree, orig_sig, emit_fn=lambda x: 0):
  """Re-emits `orig_sig` from its parse tree, filtered through `emit_fn`
  (see emit_string); by default everything is emitted."""
  emitter = StringEmit(orig_sig)
  emit_string(tree, emitter, emit_fn)
  return str(emitter)
|
|
|
|
|
|
|
|
|
|
|
|
def rewrite_signature(sig, tmap):
  """Returns `sig` with TNAME types mapped through `tmap` and default values
  dropped.

  The original spacing of `sig` is otherwise preserved.
  """

  def map_type_name(token):
    if token.type != 'TNAME':
      return
    replacement = tmap.get(token.value)
    if replacement is not None:
      # StringEmit copies token values, so the rename shows up in the output.
      token.value = replacement

  def drop_defaults(node):
    if isinstance(node, lark.lexer.Token):
      return 0
    return -1 if node.data == 'param_defval' else 0

  xtree = _XPARSER.parse(sig)
  for_every_token(xtree, map_type_name)
  return rewrite_sig(xtree, sig, emit_fn=drop_defaults)
|
|
|
|
|
|
|
|
|
|
|
|
def create_stdfunc_sig(tree, orig_sig):
  """Builds a C++ function-type string `RetType(ArgType, ...)`.

  Parameter names are dropped; everything else in the return type and
  parameter list is kept. `tree` must be a keep_all_tokens parse of
  `orig_sig` (the parameter list is read as child 3 of the root).
  """

  def without_names(node):
    if isinstance(node, lark.lexer.Token):
      return 0
    return -1 if node.data == 'param_name' else 0

  out = StringEmit(orig_sig)
  # Return type.
  emit_string(typed_child(tree, 0, 'type'), out, without_names)
  out.append('(')
  # Parameter types, without names.
  emit_string(typed_child(tree, 3, 'params'), out, without_names)
  out.append(')')
  return str(out)
|
|
|
|
|
|
|
|
|
|
|
|
def create_map_sig(tree, orig_sig):
  """Builds the lookup key used by the option tables.

  The result has the form `name(Type, ...) -> RetType` (compare the keys of
  _FUNCTION_OPTIONS). `const`, `&`, `*`, parameter names and default values
  are dropped. `tree` must be a keep_all_tokens parse of `orig_sig`.
  """

  def strip_decorations(node):
    if isinstance(node, lark.lexer.Token):
      return -1 if node.type in ['CONST', 'REF', 'PTR'] else 0
    return -1 if node.data in ['param_name', 'param_defval'] else 0

  out = StringEmit(orig_sig)
  # Function name.
  emit_string(typed_child(tree, 1, 'fnname'), out, strip_decorations)
  out.append('(')
  # Parameter types only.
  emit_string(typed_child(tree, 3, 'params'), out, strip_decorations)
  out.append(') -> ')
  # Return type.
  emit_string(typed_child(tree, 0, 'type'), out, strip_decorations)
  return str(out)
|
|
|
|
|
|
|
|
|
|
|
|
def type_core(t):
  """Returns the core name of a `type` tree, without const/&/*.

  For a template type such as std::tuple<...> the template name is returned.

  Raises:
    RuntimeError: if `t` has no core_type child.
  """
  assert isinstance(t, lark.tree.Tree)
  core = next((c for c in t.children
               if isinstance(c, lark.tree.Tree) and c.data == 'core_type'),
              None)
  if core is None:
    raise RuntimeError('Not a type tree: {}'.format(t))
  inner = core.children[0]
  if isinstance(inner, lark.lexer.Token):
    return inner.value
  assert isinstance(inner, lark.tree.Tree) and inner.data == 'template'
  return inner.children[0].value
|
|
|
|
|
|
|
|
|
|
|
|
def type_is_const(t):
  """Tells whether a `type` tree starts with the `const` keyword."""
  assert isinstance(t, lark.tree.Tree)
  head = t.children[0]
  if not isinstance(head, lark.lexer.Token):
    return False
  return head.value == 'const'
|
|
|
|
|
|
|
|
|
|
|
|
def type_is_refptr(t, kind):
  """Tells whether a `type` tree ends with the refspec `kind` ('&' or '*')."""
  assert isinstance(t, lark.tree.Tree)
  tail = t.children[-1]
  if not (isinstance(tail, lark.tree.Tree) and tail.data == 'refspec'):
    return False
  marker = tail.children[0]
  return isinstance(marker, lark.lexer.Token) and marker.value == kind
|
|
|
|
|
|
|
|
|
|
|
|
def extract_list(t, l):
  """Flattens a right-recursive list rule (`params`, `typelist`, ...) into `l`.

  Appends the first child of each nested list node to `l` and returns `l`.
  Nesting is only followed when a node has exactly two children, which holds
  for trees from _PARSER (punctuation filtered out).
  """
  node = t
  while True:
    assert isinstance(node, lark.tree.Tree)
    l.append(node.children[0])
    if len(node.children) != 2:
      break
    rest = node.children[1]
    if not (isinstance(rest, lark.tree.Tree) and rest.data == node.data):
      break
    node = rest
  return l
|
|
|
|
|
|
|
|
|
|
|
|
def tuple_type_list(t):
  """Returns the element `type` trees of a template type such as
  std::tuple<A, B>.

  Expects an unqualified type (core_type as first child) from _PARSER, where
  the typelist is the template's second child.
  """
  assert isinstance(t, lark.tree.Tree)
  core = t.children[0]
  assert isinstance(core, lark.tree.Tree) and core.data == 'core_type'
  tmpl = core.children[0]
  assert isinstance(tmpl, lark.tree.Tree) and tmpl.data == 'template'
  return extract_list(tmpl.children[1], [])
|
|
|
|
|
|
|
|
|
|
|
|
def get_function_name(t):
  """Returns the function name of a parsed declaration (child 1, `fnname`)."""
  assert isinstance(t, lark.tree.Tree)
  name_node = t.children[1]
  assert isinstance(name_node, lark.tree.Tree)
  assert name_node.data == 'fnname'
  name_token = name_node.children[0]
  return name_token.value
|
|
|
|
|
|
|
|
|
|
|
|
def get_function_signature(t, orig_sig, namefn):
  """Re-emits `orig_sig` with its function name replaced by namefn(name).

  The return type and the whole parameter list (including parameter names)
  are copied unchanged. `t` must be a keep_all_tokens parse of `orig_sig`.

  Returns:
    (new signature, original function name, new function name).
  """

  def keep_all(node):
    return 0

  out = StringEmit(orig_sig)
  # Return type.
  emit_string(typed_child(t, 0, 'type'), out, keep_all)
  name_token = typed_child(t, 1, 'fnname').children[0]
  new_name = namefn(name_token.value)
  out.append(' {}('.format(new_name))
  # Parameter list, names included.
  emit_string(typed_child(t, 3, 'params'), out, keep_all)
  out.append(')')
  return str(out), name_token.value, new_name
|
|
|
|
|
|
|
|
|
|
|
|
def get_parameters(t):
  """Returns the `param` trees of a declaration parsed with _PARSER.

  Child 2 is the parameter list because _PARSER filters out the parentheses.
  """
  assert isinstance(t, lark.tree.Tree)
  params_node = t.children[2]
  assert isinstance(params_node, lark.tree.Tree)
  assert params_node.data == 'params'
  return extract_list(params_node, [])
|
|
|
|
|
|
|
|
|
|
|
|
def get_rparameters(t):
  """Returns the named return parameters of a declaration; always empty.

  Extraction of `rparams` is not implemented: the grammar's `start` rule
  does not reference `rtype`/`rparams`. A new empty list is returned. The
  stray debug print of the child count, which wrote to stdout on every call,
  has been removed.
  """
  assert isinstance(t, lark.tree.Tree)
  # TODO: extract the `rparams` subtree once `start` can parse named
  # return parameters.
  return []
|
|
|
|
|
|
|
|
|
|
|
|
def param_name(t):
  """Returns the name of a `param` tree (its `param_name` child's token)."""
  assert isinstance(t, lark.tree.Tree)
  name_node = t.children[1]
  assert isinstance(name_node, lark.tree.Tree)
  assert name_node.data == 'param_name'
  name_token = name_node.children[0]
  assert isinstance(name_token, lark.lexer.Token)
  return name_token.value
|
|
|
|
|
|
|
|
|
|
|
|
def param_type(t):
  """Returns the `type` subtree (first child) of a `param` tree."""
  assert isinstance(t, lark.tree.Tree)
  type_node = t.children[0]
  assert isinstance(type_node, lark.tree.Tree)
  return type_node
|
|
|
|
|
|
|
|
|
|
|
|
def get_optional(fnopts, name, defval=None):
  """Returns attribute `name` of `fnopts`, falling back to `defval`.

  The fallback is used when fnopts is None, lacks the attribute, or the
  attribute's value is falsy.
  """
  value = None if fnopts is None else getattr(fnopts, name, None)
  return value or defval
|
|
|
|
|
|
|
|
|
|
|
|
def get_return_value(rtype, rname, param, var, ref_param, fnopts):
  """Returns the C++ expression to return for one (possibly tuple) result.

  Args:
    rtype: `type` tree of the value being returned.
    rname: C++ expression holding the forwarded call's result.
    param: `param` tree at the same position; required for const/reference
      returns, which return that parameter.
    var: unused; kept for interface compatibility.
    ref_param: `param` tree whose device a new tensor is placed on, unless
      fnopts sets `device_param`.
    fnopts: FuncOpts or None.
  """
  crtype = type_core(rtype)
  if type_is_const(rtype) or type_is_refptr(rtype, '&'):
    # Const and reference returns hand back the matching input parameter
    # rather than the forwarded call's result.
    assert param
    return param_name(param)
  if crtype != 'Tensor':
    return rname
  # A Tensor returned by value wraps the local result in a new MLIR tensor.
  # The reference parameter is consulted only when no device_param is
  # configured; evaluating it eagerly asserted on a missing ref_param even
  # when device_param made it irrelevant.
  device = get_optional(fnopts, 'device_param')
  if device is None:
    device = param_name(ref_param)
  return 'bridge::CreateMLIRTensor({}, bridge::GetMLIRDevice({}))'.format(
      rname, device)
|
|
|
|
|
|
|
|
|
|
|
|
def get_reference_param(params, fnopts=None):
  """Picks the parameter whose device the result tensor(s) will use.

  Priority:
    1. the parameter named by fnopts.ref_param (returned immediately);
    2. the first Tensor parameter named `self` or declared const;
    3. the last Tensor parameter, or, when there is none, the first
       TensorOptions/TensorList parameter.
  Returns None when nothing qualifies.
  """
  wanted = get_optional(fnopts, 'ref_param')
  preferred = None
  fallback = None
  for p in params:
    ptype = param_type(p)
    kind = type_core(ptype)
    pname = param_name(p)
    if wanted == pname:
      return p
    if kind == 'Tensor':
      if not preferred and (pname == 'self' or type_is_const(ptype)):
        preferred = p
      fallback = p
    elif kind in ('TensorOptions', 'TensorList') and not fallback:
      fallback = p
  return preferred or fallback
|
|
|
|
|
|
|
|
|
|
|
|
def get_tuple_return(rtype, rtype_str, rname, params, param_vars, ref_param,
                     fnopts):
  """Builds `rtype_str(e0, e1, ...)` from the elements of tuple result `rname`.

  Element i is std::get<i>(rname), converted by get_return_value together
  with the i-th parameter, if there is one.
  """
  elements = []
  for index, elem_type in enumerate(tuple_type_list(rtype)):
    elements.append(
        get_return_value(elem_type, 'std::get<{}>({})'.format(index, rname),
                         list_get(params, index), list_get(param_vars, index),
                         ref_param, fnopts))
  return '{}({})'.format(rtype_str, ', '.join(elements))
|
|
|
|
|
|
|
|
|
|
|
|
def get_return_type_str(t, orig_sig):
  """Returns the return-type text of `orig_sig` (everything before the name).

  `t` must be the parse tree of `orig_sig`, because token columns index into
  it. Whitespace between the return type and the function name is removed.
  """
  assert isinstance(t, lark.tree.Tree)
  fname = t.children[1]
  assert isinstance(fname, lark.tree.Tree)
  assert fname.data == 'fnname'
  token = fname.children[0]
  assert isinstance(token, lark.lexer.Token)
  # token.column is 1-based. The previous fixed `column - 2` slice assumed
  # exactly one separator character: it left trailing blanks for wider
  # spacing and cut a real character (e.g. the `&` of `Tensor&foo(...)`)
  # when there was none. Slicing up to the name and stripping handles both.
  return orig_sig[:token.column - 1].rstrip()
|
|
|
|
|
|
|
|
|
|
|
|
def generate_entry_debug_code(t, fname, params, fname_ns='aten'):
  """Returns C++ that prints `<fname_ns>::<fname>` on entry to a wrapper.

  Functions listed in _FN_NO_DEBUG_ENTRY_LIST get no trace (empty string).
  `t` and `params` are currently unused.
  """
  if fname in _FN_NO_DEBUG_ENTRY_LIST:
    return ''
  return ' std::cout << "{}::{}" << std::endl;\n'.format(fname_ns, fname)
|
|
|
|
|
|
|
|
|
|
|
|
def generate_exit_debug_code(t, fname, rname, params, param_vars):
  """Returns C++ to emit before a wrapper returns; currently nothing.

  All arguments are unused and kept for symmetry with
  generate_entry_debug_code.
  """
  return str()
|
|
|
|
|
|
|
|
|
|
|
|
def generate_return_stmt(t, rtype_str, fname, rname, params, param_vars,
                         ref_param, fnopts):
  """Returns the C++ `return` statement that ends a generated wrapper.

  std::tuple results are rebuilt element by element (get_tuple_return).
  Tensor results go through get_return_value. std::vector and all other
  types return `rname` unchanged. A plain `void` function gets no statement
  (empty string).
  """
  assert isinstance(t, lark.tree.Tree)
  rtype = t.children[0]
  core = type_core(rtype)
  if core == 'void' and not type_is_refptr(rtype, '*'):
    return ''
  if core == 'std::tuple':
    value = get_tuple_return(rtype, rtype_str, rname, params, param_vars,
                             ref_param, fnopts)
  elif core == 'Tensor':
    value = get_return_value(rtype, rname, params[0], param_vars[0],
                             ref_param, fnopts)
  else:
    # std::vector (and everything else, including void*) is returned as-is.
    value = rname
  return ' return {};\n'.format(value)
|
|
|
|
|
|
|
|
|
|
|
|
def generate_result_assignment(t, rname):
  """Returns the `auto&& <rname> = ` prefix for the forwarded call, or ''
  when the function returns plain void."""
  assert isinstance(t, lark.tree.Tree)
  rtype = t.children[0]
  returns_nothing = (type_core(rtype) == 'void' and
                     not type_is_refptr(rtype, '*'))
  return '' if returns_nothing else 'auto&& {} = '.format(rname)
|
|
|
|
|
|
|
|
|
|
|
|
def get_handling_function(ctx, fname, the_ref_param, param_vars):
  """Returns the C++ call expression that performs `fname`.

  A target from _torch_mlir_FUNCTIONS, or else a free `at::` function found
  by ctx.get_function, is called with all arguments. Otherwise `fname` is
  invoked as a method on `the_ref_param` with the remaining arguments.
  """
  target = _torch_mlir_FUNCTIONS.get(fname) or ctx.get_function(fname)
  if target:
    return '{}({})'.format(target, ', '.join(param_vars))
  method_args = list(param_vars)
  method_args.remove(the_ref_param)
  return '{}.{}({})'.format(the_ref_param, fname, ', '.join(method_args))
|
|
|
|
|
|
|
|
|
|
|
|
def rewrite_tensor_options(fname, pname):
  """Redirects the TensorOptions of factory functions in _CTOR_FUNCTIONS.

  Returns (code, name). For listed functions, `code` declares `o_<pname>` as
  `pname` with the configured suffix appended, and `name` is that variable.
  For other functions, returns ('', pname).
  """
  suffix = _CTOR_FUNCTIONS.get(fname)
  if suffix is None:
    return '', pname
  local_name = 'o_{}'.format(pname)
  decl = ' at::TensorOptions {} = {}{};\n'.format(local_name, pname, suffix)
  return decl, local_name
|
|
|
|
|
|
|
|
|
|
|
|
def get_param_names(params):
  """Returns the names of the given `param` trees, in order."""
  return [param_name(p) for p in params]
|
|
|
|
|
|
|
|
|
|
|
|
def expand_fn_template(tmpl, param_vars):
  """Substitutes `$0`, `$1`, ... in template `tmpl` with the matching
  entries of `param_vars`."""
  return tmpl.substitute({str(i): v for i, v in enumerate(param_vars)})
|
|
|
|
|
|
|
|
|
|
|
|
def create_call(fname, param_vars):
  """Returns the C++ call expression `fname(arg0, arg1, ...)`."""
  arg_list = ', '.join(param_vars)
  return '{0}({1})'.format(fname, arg_list)
|
|
|
|
|
|
|
|
|
|
|
|
def generate_shape_checks(param_vars, shape_check_indices, fname):
  """Returns C++ checking that paired operands have equal shapes.

  NOTE: the checks for the index pairs in `shape_check_indices` are not
  emitted yet, so this always returns an empty string.
  """
  return str()
|
|
|
|
|
|
|
|
|
|
|
|
def generate_aten_remap(ctx, fname, sig, params, fnopts):
  """Generates a wrapper that forwards straight to another function.

  The call comes from fnopts.outfn_template when set. Otherwise
  fnopts.outfn_name is called with all parameters in order. Shape-check code
  is inserted when fnopts.shape_check_indices is set.
  """
  args = get_param_names(params)
  if fnopts.outfn_template is None:
    assert fnopts.outfn_name
    call = create_call(fnopts.outfn_name, args)
  else:
    call = expand_fn_template(fnopts.outfn_template, args)

  pieces = ['{} {{\n'.format(sig)]
  if fnopts.shape_check_indices is not None:
    pieces.append(
        generate_shape_checks(args, fnopts.shape_check_indices, fname))
  pieces.append(' return {};\n'.format(call))
  pieces.append('}')
  return ''.join(pieces)
|
|
|
|
|
|
|
|
|
|
|
|
def generate_outfn_result_copy(dest, src):
  """Returns C++ that shallow-copies the TensorImpl of `src` into `dest`
  (via shallow_copy_from)."""
  line = ' {}.unsafeGetTensorImpl()->shallow_copy_from({}.getIntrusivePtr());\n'
  return line.format(dest, src)
|
|
|
|
|
|
|
|
|
|
|
|
def generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params, fnopts):
  """Generates a wrapper for an out-variant function `<name>_out`.

  Unless fnopts.outfn_template is set, ATenMLIRType::<name> is called with
  the non-output arguments. The output parameters come first: one, or one
  per element of a std::tuple return. The result is shallow-copied into the
  output parameter(s), which are then returned.
  """
  rtype = tree.children[0]
  num_outputs = None
  if type_core(rtype) == 'std::tuple':
    num_outputs = len(tuple_type_list(rtype))

  code = '{} {{\n'.format(sig)
  code += generate_entry_debug_code(tree, fname, params)

  args = get_param_names(params)
  if fnopts.outfn_template is not None:
    call = expand_fn_template(fnopts.outfn_template, args)
  else:
    match = re.match(r'(.*)_out$', fname)
    assert match is not None, fname
    first_input = 1 if num_outputs is None else num_outputs
    call = create_call('ATenMLIRType::{}'.format(match.group(1)),
                       args[first_input:])

  tmp_result = '{}_tmp'.format(fname)
  code += ' auto {} = {};\n'.format(tmp_result, call)
  if num_outputs is None:
    out_name = args[0]
    code += generate_outfn_result_copy(out_name, tmp_result)
    code += generate_exit_debug_code(tree, fname, out_name, params, args)
    code += ' return {};\n'.format(out_name)
  else:
    outputs = [args[i] for i in range(num_outputs)]
    for index, out_name in enumerate(outputs):
      code += generate_outfn_result_copy(
          out_name, 'std::get<{}>({})'.format(index, tmp_result))
    code += generate_exit_debug_code(tree, fname, args[0:num_outputs], params,
                                     args)
    code += ' return {}({});\n'.format(
        get_return_type_str(rwxtree, rwsig), ', '.join(outputs))
  code += '}'
  return code
|
|
|
|
|
|
|
|
|
|
|
|
def generate_aten_to_mlir(ctx, tree, rwxtree, fname, sig, rwsig, params,
                          fnopts):
  """Generates the default wrapper, which forwards a call to ATen.

  Tensor arguments are gathered by a TensorFetcher and passed as elements of
  the list built with bridge::MLIRCreateTensorList. TensorOptions of factory
  functions are redirected to the CPU (rewrite_tensor_options). TensorList
  and non-tensor arguments are passed through unchanged. The result is
  returned through generate_return_stmt.
  """
  ref_param = get_reference_param(params, fnopts=fnopts)

  code = '{} {{\n'.format(sig)
  code += generate_entry_debug_code(tree, fname, params)
  the_ref_param = param_name(ref_param) if ref_param else None
  tfetcher = TensorFetcher('mlirtens')
  param_vars = []
  for p in params:
    ptype = param_type(p)
    cptype = type_core(ptype)
    pname = param_name(p)
    if cptype == 'TensorOptions':
      gcode, arg = rewrite_tensor_options(fname, pname)
      code += gcode
    elif cptype == 'Tensor':
      # Const tensors default to read-only, non-const ones to writeable.
      arg = tfetcher.add(
          pname, is_write_param(fnopts, pname, not type_is_const(ptype)))
    else:
      # TensorList and non-tensor arguments are forwarded as-is.
      arg = pname
    param_vars.append(arg)
    if p == ref_param and not get_optional(fnopts, 'ref_param'):
      # Method calls go through the converted argument, not the original.
      the_ref_param = arg
  code += tfetcher.generate_fetches()

  result_assign = generate_result_assignment(tree, _RESULT_NAME)
  rname = _RESULT_NAME if result_assign else None
  code += ' {}{};\n'.format(
      result_assign,
      get_handling_function(ctx, fname, the_ref_param, param_vars))
  # Tensor write-back (TensorFetcher.generate_updates) is not emitted.
  if result_assign:
    code += (' static_cast<void>({}); // Avoid warnings in case not '
             'used\n'.format(_RESULT_NAME))
  code += generate_exit_debug_code(tree, fname, rname, params, param_vars)
  code += generate_return_stmt(tree, get_return_type_str(rwxtree, rwsig),
                               fname, rname, params, param_vars, ref_param,
                               fnopts)
  code += '}'
  return code
|
|
|
|
|
|
|
|
|
|
|
|
def get_mlir_wrapper(fndef, ctx):
  """Build the FuncGen record describing the generated wrapper for `fndef`.

  `fndef` provides cpp_sig (parsed with both _PARSER and _XPARSER) and
  aten_sig. The wrapper code is produced by generate_aten_out() when out-fn
  options exist, by generate_aten_remap() when remap options exist, and by
  generate_aten_to_mlir() otherwise. Functions rejected by
  is_blacklisted_fn() get code=None.

  Returns:
    A FuncGen carrying the parse trees, names, signatures and generated code.
  """
  cpp_sig = fndef.cpp_sig
  tree = _PARSER.parse(cpp_sig)
  xtree = _XPARSER.parse(cpp_sig)
  mapsig = create_map_sig(xtree, cpp_sig)
  rwsig = rewrite_signature(cpp_sig, _TYPE_NSMAP)
  rwxtree = _XPARSER.parse(rwsig)
  params = get_parameters(tree)
  fnopts = _FUNCTION_OPTIONS.get(mapsig, None)

  # Generated definitions are members of ATenMLIRTypeDefault.
  sig, fname, xfname = get_function_signature(
      rwxtree, rwsig, lambda name: 'ATenMLIRTypeDefault::{}'.format(name))

  code = None
  if not is_blacklisted_fn(fname, mapsig):
    out_opts = get_outfn_options(fname, mapsig)
    remap_opts = get_remapfn_options(fname, mapsig)
    if out_opts is not None:
      code = generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params,
                               out_opts)
    elif remap_opts is not None:
      code = generate_aten_remap(ctx, fname, sig, params, remap_opts)
    else:
      code = generate_aten_to_mlir(ctx, tree, rwxtree, fname, sig, rwsig,
                                   params, fnopts)

  return FuncGen(tree=tree,
                 xtree=xtree,
                 rwxtree=rwxtree,
                 func=fname,
                 xfunc=xfname,
                 code=code,
                 sig=cpp_sig,
                 rwsig=rwsig,
                 cppsig=sig,
                 mapsig=mapsig,
                 funsig=create_stdfunc_sig(rwxtree, rwsig),
                 aten_sig=fndef.aten_sig)
|
|
|
|
|
|
|
|
|
|
|
|
def is_tensor_api(fndef):
  """Normalize a C++ declaration and report whether it mentions `Tensor`.

  The 'at::' namespace prefix is removed and 'c10::Device' is shortened to
  'Device'. The match is whole-word, so e.g. 'TensorList' alone does not count.

  Returns:
    (mentions_tensor, normalized_declaration)
  """
  normalized = fndef.replace('at::', '').replace('c10::Device', 'Device')
  mentions_tensor = re.search(r'\bTensor\b', normalized) is not None
  return mentions_tensor, normalized
|
|
|
|
|
|
|
|
|
|
|
|
def extract_functions(path):
  """Extract function declarations annotated with their ATen schema.

  Each line of the form `<cpp declaration>; // <aten schema>` is a candidate.
  Declarations that _XPARSER can parse become FuncDef(cpp_sig, aten_sig).
  Parse failures are reported on stderr and recorded only when the
  declaration mentions Tensor (per is_tensor_api()); others are ignored.

  Args:
    path: path to the TypeDefault.h-style header.

  Returns:
    (functions, errors): the FuncDef list, and (declaration, error message)
    pairs for tensor declarations that failed to parse.
  """
  functions = []
  errors = []
  # `with` guarantees the header is closed even if parsing raises.
  with open(path, 'r') as f:
    for line in f:
      m = re.match(r'\s*([^\s].*); //\s+(.*)', line)
      if not m:
        continue
      fndef = m.group(1)
      try:
        _XPARSER.parse(fndef)
        functions.append(FuncDef(cpp_sig=fndef, aten_sig=m.group(2)))
      except Exception as e:
        # Best effort: non-tensor declarations the grammar cannot handle are
        # skipped silently; only tensor APIs count as errors.
        if is_tensor_api(fndef)[0]:
          errors.append((fndef, str(e)))
          print('Error parsing "{}": {}'.format(fndef, e), file=sys.stderr)
  return functions, errors
|
|
|
|
|
|
|
|
|
|
|
|
def get_mapsig_key(mapsig):
  """Return `mapsig` with every space removed, for use as a lookup key.

  PyTorch spells std::tuple<> types without spaces between the element types,
  which the string rewriter would otherwise need to special-case. Because the
  result is only used as a key, stripping all spaces is sufficient.
  """
  return ''.join(mapsig.split(' '))
|
|
|
|
|
|
|
|
|
|
|
|
def parse_local_overrides(path):
  """Collect the hand-written overrides declared in a C++ header.

  Every declaration beginning with `static` is gathered. It may fit on one
  line, or continue over following lines until a line makes it end with ';'.
  Only declarations that mention Tensor (per is_tensor_api()) are kept.

  Args:
    path: path to the overrides header.

  Returns:
    dict mapping get_mapsig_key(create_map_sig(...)) of each kept declaration
    to its normalized C++ signature.

  Raises:
    ValueError: if the file ends inside an unterminated `static` declaration.
  """
  functions = []
  fndef = None
  with open(path, 'r') as f:
    for line in f:
      line = line.strip()
      if not fndef:
        m = re.match(r'static\s+(.*);', line)
        if m:
          functions.append(m.group(1))
          continue
        m = re.match(r'static\s+(.*)', line)
        if m:
          # Start of a declaration continued on the following lines.
          fndef = m.group(1)
      else:
        fndef = '{} {}'.format(fndef, line)
        if fndef.endswith(';'):
          functions.append(fndef[:-1])
          fndef = None
  # An explicit check rather than `assert`, which `python -O` strips (that
  # would silently drop the truncated declaration).
  if fndef is not None:
    raise ValueError('Unterminated static declaration in {}: {}'.format(
        path, fndef))

  overrides = {}
  for fndef in functions:
    # Discard static helper functions that are not ATen tensor functions.
    is_tensor, fndef = is_tensor_api(fndef)
    if is_tensor:
      xtree = _XPARSER.parse(fndef)
      mapsig_key = get_mapsig_key(create_map_sig(xtree, fndef))
      overrides[mapsig_key] = fndef
  return overrides
|
|
|
|
|
|
|
|
|
|
|
|
def get_dialect_name(func):
  """Convert an ATen function name into a CamelCase dialect op name.

  Underscores are dropped and the character following each one (as well as
  the first character) is upper-cased: 'max_pool2d' -> 'MaxPool2d'. A
  trailing underscore is spelled 'Under' so that e.g. 'add_' ('AddUnder')
  stays distinct from 'add' ('Add'). An empty name yields ''.
  """
  # Consecutive/leading underscores produce empty pieces, which contribute
  # nothing, matching a character-by-character "upper after '_'" scan.
  name = ''.join(piece[:1].upper() + piece[1:] for piece in func.split('_'))
  if func.endswith('_'):
    name += 'Under'
  return name
|
|
|
|
|
|
|
|
|
|
|
|
def _td_operand_type(cptype):
  """Map a core C++ parameter type to its TableGen operand type.

  Returns None for parameter types that are not modelled as operands
  (c10::optional and std::array). Unknown types fall back to AnyType with a
  diagnostic on stderr.
  """
  if cptype == 'Tensor':
    return 'AnyTensor'
  if cptype in ('Scalar', 'int64_t', 'double', 'bool'):
    return 'AnyScalar'
  if cptype in ('c10::optional', 'std::array'):
    return None
  if cptype != 'IntArrayRef':
    # stderr: stdout may be the generated-file stream (no --output_folder).
    print('unhandled type', cptype, file=sys.stderr)
  return 'AnyType'


def _td_op_definition(fgen):
  """Return the TableGen `def aten_<Name>Op` record for one FuncGen."""
  rtype = fgen.tree.children[0]
  num_outputs = 1
  if type_core(rtype) == 'std::tuple':
    num_outputs = len(tuple_type_list(rtype))
  dialect_name = get_dialect_name(fgen.func)
  has_stats = fgen.func not in _TD_NO_OPSTATS_LIST

  code = 'def aten_{}Op: aten_Op<"{}"'.format(dialect_name, fgen.func)
  code += ', [NoSideEffect'
  if has_stats:
    code += ', StatisticsOpInterface'
  code += ']>,\n'
  # One AnyTensor result per returned value (tuple element).
  code += ' Results<(outs'
  code += ' AnyTensor' + ', AnyTensor' * (num_outputs - 1)
  code += ')> {\n'

  code += ' let arguments = (\n'
  operands = []
  for p in get_parameters(fgen.tree):
    td_type = _td_operand_type(type_core(param_type(p)))
    if td_type is not None:
      operands.append('{}:${}'.format(td_type, param_name(p)))
  # `ins` must lead the dag even when the first C++ parameter was skipped or
  # there are no operands at all; otherwise the TableGen is malformed.
  code += ' ins'
  if operands:
    code += ' ' + ',\n '.join(operands)
  code += '\n );\n'

  code += ' let summary = "aten {} operator";\n'.format(fgen.func)
  code += ' let description = [{\n'
  code += ' {}Op\n'.format(dialect_name)
  code += ' aten {} operator\n'.format(fgen.func)
  code += ' }];\n'
  if has_stats:
    code += ' let extraClassDeclaration = [{\n'
    code += ' std::map<std::string, uint64_t> getStatistics();\n'
    code += ' }];\n'
  code += '}\n\n'
  return code


def generate_td_functions(fgens, overrides):
  """Generate the ATen dialect op definitions (ATenOps.td contents).

  The output is wrapped in an `#ifdef ATEN_OP_DEFS / #else /
  #define ATEN_OP_DEFS ... #endif` guard. Each FuncGen not in _TD_BLACKLIST
  becomes one NoSideEffect op; ops not in _TD_NO_OPSTATS_LIST also get
  StatisticsOpInterface and a getStatistics() declaration.

  Args:
    fgens: FuncGen records produced by get_mlir_wrapper().
    overrides: dict keyed by mapsig key (see parse_local_overrides()).

  Returns:
    (code, overridden): the TableGen text, and the set of `overrides` keys
    that matched a FuncGen (blacklisted ones included).
  """
  overridden = set()
  code = '#ifdef ATEN_OP_DEFS\n'
  code += '#else\n'
  code += '#define ATEN_OP_DEFS\n\n'
  for fgen in fgens:
    mapsig_key = get_mapsig_key(fgen.mapsig)
    if mapsig_key in overrides:
      overridden.add(mapsig_key)
    if fgen.func in _TD_BLACKLIST:
      continue
    code += _td_op_definition(fgen)
  code += '#endif\n'
  return code, overridden
|
|
|
|
|
|
|
|
|
|
|
|
def generate_registrations(fgens, overrides):
  """Generate the C++ RegisterAtenTypeFunctions() operator registrations.

  A FuncGen whose mapsig key is in `overrides` is registered against the
  hand-written ATenMLIRType::<func>. Otherwise it is registered against its
  generated wrapper (fgen.xfunc) when it has generated code, and skipped
  when it has none.

  Args:
    fgens: FuncGen records produced by get_mlir_wrapper().
    overrides: dict keyed by mapsig key (see parse_local_overrides()).

  Returns:
    (code, overridden): the C++ function source, and the set of `overrides`
    keys that matched a FuncGen.
  """
  code = 'void RegisterAtenTypeFunctions() {\n'
  code += ' static auto dispatch = torch::RegisterOperators()\n'
  overridden = set()
  for fgen in fgens:
    mapsig_key = get_mapsig_key(fgen.mapsig)
    if mapsig_key in overrides:
      override_fn = 'ATenMLIRType::{}'.format(fgen.func)
      overridden.add(mapsig_key)
    else:
      override_fn = fgen.xfunc if fgen.code else None
    if override_fn:
      # Exactly three placeholders: schema, unboxed function type, kernel.
      # NOTE(review): kernels are registered under XLATensorId; confirm this
      # is the dispatch key the MLIR device actually uses.
      code += (
          ' .op(torch::RegisterOperators::options().schema("{}")\n '
          '.impl_unboxedOnlyKernel<{}, &{}>(at::TensorTypeId::XLATensorId)\n'
          ' .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))\n'.format(
              fgen.aten_sig, fgen.funsig, override_fn))
  return code + ';\n}\n', overridden
|
|
|
|
|
|
|
|
|
|
|
|
def generate_functions(fgens):
  """Join the generated C++ definitions, each followed by a blank line.

  FuncGen records with no code (falsy `code`) contribute nothing.
  """
  return ''.join('{}\n\n'.format(fgen.code) for fgen in fgens if fgen.code)
|
|
|
|
|
|
|
|
|
|
|
|
def generate_class_functions(fgens):
  """Emit one `static <rwsig>;` class member declaration per generated function.

  Only FuncGen records that carry generated code are declared.
  """
  return ''.join(' static {};\n'.format(fgen.rwsig)
                 for fgen in fgens if fgen.code)
|
|
|
|
|
|
|
|
|
|
|
|
def gen_output_file(args, name):
  """Return the stream that generated file `name` should be written to.

  Without an output folder (falsy args.output_folder) everything goes to
  sys.stdout. Otherwise `name` is opened for writing inside that folder, and
  the caller is responsible for closing it.
  """
  folder = args.output_folder
  if folder:
    return open(os.path.join(folder, name), 'w')
  return sys.stdout
|
|
|
|
|
|
|
|
|
|
|
|
def gen_h_output_file(args):
  """Return the output stream for the generated C++ header."""
  header_name = 'aten_mlir_type_default.h'
  return gen_output_file(args, header_name)
|
|
|
|
|
|
|
|
|
|
|
|
def gen_cpp_output_file(args):
  """Return the output stream for the generated C++ source file."""
  source_name = 'aten_mlir_type_default.cpp'
  return gen_output_file(args, source_name)
|
|
|
|
|
|
|
|
|
|
|
|
def gen_td_output_file(args):
  """Return the output stream for the generated ATen dialect TableGen file."""
  td_name = 'ATenOps.td'
  return gen_output_file(args, td_name)
|
|
|
|
|
|
|
|
|
2020-08-27 03:55:16 +08:00
|
|
|
def check_overrides(availagle_fgens, overrides, overridden):
  """Report hand-written overrides that matched no generated function.

  All diagnostics go to stderr. Generated files may be written to stdout
  (no --output_folder), and diagnostics there would corrupt them.

  Args:
    availagle_fgens: FuncGen records; their mapsig keys are listed to help
      diagnose mismatches.
    overrides: dict of mapsig -> C++ signature from the overrides file.
    overridden: set of mapsig keys that were matched during generation.

  Returns:
    True if every override was matched, False otherwise.
  """
  misses = 0
  for mapsig, cpp_sig in overrides.items():
    mapsig_key = get_mapsig_key(mapsig)
    if mapsig_key in overridden:
      continue
    misses += 1
    print('ERROR: ATenMLIRType function missed override:\n'
          ' CPPSIG: {}\n'
          ' MAPSIG: {}\n'
          ' KEY : {}\n'.format(cpp_sig, mapsig, mapsig_key),
          file=sys.stderr)
  if misses != 0:
    print('Some required overrides were missing (see above).', file=sys.stderr)
    print('Available overrides:', file=sys.stderr)
    for fgen in availagle_fgens:
      print(' ', get_mapsig_key(fgen.mapsig), file=sys.stderr)
  return misses == 0
|
|
|
|
|
|
|
|
|
|
|
|
def _write_output(stream, text):
  """Write `text` to `stream`, then close it unless it is sys.stdout.

  gen_output_file() returns sys.stdout when no --output_folder is given, and
  closing it would break every later write to stdout.
  """
  try:
    stream.write(text)
  finally:
    if stream is not sys.stdout:
      stream.close()


def generate(args):
  """Generate the ATen MLIR type sources and the ATen dialect TableGen file.

  Reads declarations from args.typedef and hand-written overrides from
  args.overridetype, builds a wrapper per declaration, and writes
  aten_mlir_type_default.h, aten_mlir_type_default.cpp and ATenOps.td via
  gen_*_output_file() (into args.output_folder, or to stdout).

  Raises:
    RuntimeError: if tensor declarations fail to parse, or an override does
      not match any generated function.
  """
  fndefs, errors = extract_functions(args.typedef)
  print('Extracted {} functions ({} errors) from {}'.format(
      len(fndefs), len(errors), args.typedef),
        file=sys.stderr)
  # Explicit checks instead of `assert`: under `python -O` asserts disappear,
  # which would also skip the check_overrides() calls below entirely.
  if errors:
    raise RuntimeError('Failed to parse {} function(s) from {}'.format(
        len(errors), args.typedef))

  local_overrides = parse_local_overrides(args.overridetype)
  print('{} function overrides in {}'.format(len(local_overrides),
                                             args.overridetype),
        file=sys.stderr)

  fgens = []
  ctx = Context(args.functions)
  for ts in fndefs:
    try:
      fgen = get_mlir_wrapper(ts, ctx)
      if fgen:
        fgens.append(fgen)
    except Exception as e:
      # Best effort: one bad declaration must not stop the whole generation.
      print('Failed to generate wrapper for {}: {}'.format(ts, e),
            file=sys.stderr)
  print('Generated {} wrappers for {}'.format(len(fgens), args.typedef),
        file=sys.stderr)

  functions = generate_functions(fgens)
  hfunctions = generate_class_functions(fgens)

  tdfunctions, overridden = generate_td_functions(fgens, local_overrides)
  if not check_overrides(fgens, local_overrides, overridden):
    raise RuntimeError('Missing overrides when generating td functions')

  regs, overridden = generate_registrations(fgens, local_overrides)
  if not check_overrides(fgens, local_overrides, overridden):
    raise RuntimeError('Missing local overrides when generating registrations')

  # Create output files. Text is built before each file is opened so a
  # formatting error cannot leave a truncated, unclosed file behind. The
  # trailing '\n' reproduces what print() appended previously.
  gen = os.path.basename(sys.argv[0])
  h_text = _H_HEADER.format(gen=gen, hfuncs=hfunctions) + '\n'
  _write_output(gen_h_output_file(args), h_text)
  cpp_text = _CPP_HEADER.format(gen=gen, funcs=functions, regs=regs) + '\n'
  _write_output(gen_cpp_output_file(args), cpp_text)
  _write_output(gen_td_output_file(args), tdfunctions)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
  # Usage: [--output_folder DIR] OVERRIDE_TYPE_FILE TYPE_DEFAULT_FILE
  #        FUNCTIONS_FILE
  # Unrecognized extra command-line arguments are accepted and ignored.
  cli_parser = argparse.ArgumentParser()
  cli_parser.add_argument('--output_folder', type=str)
  for dest, metavar, help_text in (
      ('overridetype', 'OVERRIDE_TYPE_FILE', 'The path to the overrides file'),
      ('typedef', 'TYPE_DEFAULT_FILE', 'The path to the TypeDefault.h file'),
      ('functions', 'FUNCTIONS_FILE', 'The path to the Functions.h file'),
  ):
    cli_parser.add_argument(dest, type=str, metavar=metavar, help=help_text)
  args, _unknown_args = cli_parser.parse_known_args()
  generate(args)
|