torch-mlir/frontends/pytorch/utils/gen_aten_dialect.py

1245 lines
36 KiB
Python
Raw Normal View History

Add pytorch interface to ATen Dialect (#30) This patch adds a pytorch interface to npcomp. This interface is modeled after pytorch_xla and exposes the MLIR-based flow as a virtual device (similar to a gpu device or the xla backend). Usage is intended to be something like: dev = torch_mlir.mlir_device() t0 = torch.randn((4,4), device=dev) t1 = torch.randn((4,4), device=dev) t2 = t0 + t1 t2_mlir = torch_mlir.get_mlir( t2 ) t2_cpu = t2.to('cpu') In this case t2_cpu would contain the result of the computation, and t2_mlir contains the mlir description of the computation. Note that this also properly returns backward paths synthesized by pytorch. There are several parts of this: 1) A tensor type (implemented by tensor.* and tensor_impl.*) 2) The device modeling (aten_mlir_bridge.*, aten_mlir_device.*, aten_mlir_type*) 3) A temporary IR (implemented by ir.cpp) There is also a reference lowering directly from the ATen dialect to C function calls consisting of two parts: 1) The driver that uses the IR to generate MLIR, run passes and compile the result using mlir::ExecutionEngine (implemented by jit.cpp and mlir_gen.cpp) 2) A runtime library implemented by lib/aten_ops.cpp. Most of the operations are implemented by callbacks into the torch C++ libraries. Some aspects of this are known to be less than optimal, in particular: 1) There are some function definitions that don't live in the file corresponding to their declaration. 2) More aspects of this (e.g. the IR) seem like they should be automatically generated. 3) It's unclear to me how much of the 'IR' is actually necessary, or whether MLIR could be created on the fly. Note that this code is licensed in a way similar to pytorch, with the intention that eventually (when npcomp reaches some maturity) it should be pushed there. (see frontends/pytorch/LICENSE) The code is also structured much closer to the pytorch coding style than the LLVM coding style.
2020-08-22 02:22:47 +08:00
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
# Structured similarly to code from git@github.com:pytorch/xla.git
from __future__ import print_function
import argparse
import collections
import lark
import os
import re
import string
import sys
####
# This file parses the C++ signatures exported by pytorch and generates
# appropriate MLIR operations in a tablegen file. It also generates some of
# the boilerplate parts of the pytorch integration. This may need to be
# re-run when the pytorch version changes. Primarily this reads information
# from pytorch through RegistrationDeclarations.h and Functions.h. It also
# allows some local overrides (specified in aten_mlir_type.h).
# It generates aten_mlir_type_default.{cpp,h} and ATenOps.td, which will need
# to be moved to their appropriate places (without --output_folder, all
# three are written to stdout).
# To run:
# python3 gen_aten_dialect.py --output_folder=. \
# ../csrc/aten_mlir_type.h \
# ${TORCH_INSTALL_PREFIX}/include/ATen/RegistrationDeclarations.h \
# ${TORCH_INSTALL_PREFIX}/include/ATen/Functions.h
def namedtuple_with_defaults(typename, field_names, default_values=()):
  """Create a namedtuple class whose fields all have default values.

  Every field defaults to None unless overridden through `default_values`.

  Args:
    typename: Name of the generated class.
    field_names: Field names in any form accepted by collections.namedtuple
      (e.g. 'a, b' or ['a', 'b']).
    default_values: Either a mapping from field name to default value, or a
      sequence of positional defaults for the leading fields.

  Returns:
    The new namedtuple class.
  """
  # `collections.Mapping` was removed in Python 3.10; the ABC lives in
  # collections.abc.
  from collections.abc import Mapping
  ntuple = collections.namedtuple(typename, field_names)
  # Default everything to None first so a partial prototype can be built.
  ntuple.__new__.__defaults__ = (None,) * len(ntuple._fields)
  if isinstance(default_values, Mapping):
    prototype = ntuple(**default_values)
  else:
    prototype = ntuple(*default_values)
  ntuple.__new__.__defaults__ = tuple(prototype)
  return ntuple
class ArgTemplate(string.Template):
  """string.Template whose placeholders may start with a digit.

  This lets templates refer to wrapper parameters by position, e.g.
  'ATenMLIRType::arange($1, $2, $3, $0.options())' (see expand_fn_template()).
  """
  idpattern = r'[a-z0-9_]+'
# One declaration read by extract_functions(): the C++ signature and the
# ATen schema string that follows it in a trailing comment.
FuncDef = namedtuple_with_defaults('FuncDef', ['cpp_sig', 'aten_sig'])
# Everything get_mlir_wrapper() produces for one declaration.
FuncGen = namedtuple_with_defaults('FuncGen', [
    'tree', 'xtree', 'rwxtree', 'func', 'xfunc', 'code', 'sig', 'rwsig',
    'cppsig', 'funsig', 'mapsig', 'aten_sig'
])
# Per-function generation options:
#   ref_param            - parameter forced as the reference parameter
#   device_param         - C++ expression whose device new results use
#   wparams              - parameter names treated as writeable
#   outfn_template       - ArgTemplate for the forwarded call
#   outfn_name           - function the call is forwarded to
#   shape_check_indices  - (i, j) parameter pairs for shape checks
FuncOpts = namedtuple_with_defaults('FuncOpts', [
    'ref_param', 'device_param', 'wparams', 'outfn_template', 'outfn_name',
    'shape_check_indices'
])
# Lark grammar (used with the LALR parser below) for a single C++ function
# declaration: a return `type`, the function name (`fnname`) and a non-empty
# parenthesized parameter list.  Parameters may carry a simple default value
# (`param_defval`).  Types may be const-qualified, templated, and end in a
# `&` / `*` refspec.
# NOTE: the `rtype` / `rparams` / `rparam` rules are not reachable from
# `start`; they only matter to the (stub) get_rparameters().
_GRAMMAR = r"""
start: type fnname "(" params ")"
rtype: "(" rparams ")"
| TNAME
rparams: rparam
| rparam "," rparams
rparam: type param_name
type: CONST? core_type refspec?
fnname: CNAME
refspec: REF
| PTR
core_type: template
| TNAME
template: TNAME "<" typelist ">"
typelist: type
| type "," typelist
REF: "&"
PTR: "*"
CONST: "const"
TNAME: /[a-zA-Z0-9_:]+/
HEXNUMBER: /0x[0-9a-fA-F]+/
params: param
| param "," params
param: type param_name param_defval?
param_name: CNAME
param_defval: "=" init_value
init_value: "true"
| "false"
| "{}"
| NUMBER
| SIGNED_NUMBER
| HEXNUMBER
| ESCAPED_STRING
%import common.CNAME -> CNAME
%import common.NUMBER -> NUMBER
%import common.SIGNED_NUMBER -> SIGNED_NUMBER
%import common.ESCAPED_STRING -> ESCAPED_STRING
%import common.WS
%ignore WS
"""
# Two parsers over the same grammar.  _PARSER yields compact trees used for
# structural queries (e.g. get_parameters()); _XPARSER keeps every token
# (keep_all_tokens) so a tree can be re-emitted as text via StringEmit.
_PARSER, _XPARSER = (
    lark.Lark(_GRAMMAR,
              parser='lalr',
              propagate_positions=True,
              keep_all_tokens=keep_tokens) for keep_tokens in (False, True))
# Functions that get no TableGen op definition (generate_td_functions()).
_TD_BLACKLIST = {
    'clone', 'to', 'copy_', 'copy', 'copy_from', '_copy_from',
    '_unsafe_view',
}
# TableGen ops emitted without StatisticsOpInterface / getStatistics().
_TD_NO_OPSTATS_LIST = {'_log_softmax', '_log_softmax_backward_data'}
# Functions (by name or map signature) that get no generated C++ wrapper.
_FN_BLACKLIST = {
    'numel', 'ones', 'ones_like', 'result_type',
    # 'zero_',
    'zeros', 'zeros_like',
}
# Functions whose generated wrapper has no entry trace.
_FN_NO_DEBUG_ENTRY_LIST = {'empty', 'fill_', 'zero_'}
# Patterns (applied with re.match to the name and to the map signature) of
# further functions that get no generated C++ wrapper.
_FN_BLACKLIST_REGEX = [
    # ATEN functions
    r'[^(]*cudnn',
    # XLA/TPU functions
]
# Out-variant ('*_out') functions that generate_aten_out() would implement by
# calling the functional ATenMLIRType op and copying the result into the out
# tensor(s).  This table is NOT consulted: _FN_OUT below is empty, which
# disables out-variant generation.  It is kept under its own name so that
# the disabled state is explicit instead of being an immediate overwrite.
_FN_OUT_DISABLED = {
    'add_out':
        FuncOpts(),
    'arange_out(Tensor, Scalar, Scalar, Scalar) -> Tensor':
        FuncOpts(outfn_template=ArgTemplate(
            'ATenMLIRType::arange($1, $2, $3, $0.options())')),
    'bitwise_not_out':
        FuncOpts(),
    'clamp_out':
        FuncOpts(),
    'div_out':
        FuncOpts(),
    'gather_out':
        FuncOpts(),
    'kthvalue_out':
        FuncOpts(),
    'index_select_out':
        FuncOpts(),
    'log_out':
        FuncOpts(),
    'topk_out':
        FuncOpts(),
}
# Out-variant functions (by name or map signature) that get a generated
# out-wrapper; empty, i.e. disabled (see _FN_OUT_DISABLED).
_FN_OUT = {}
# List of tuples with the regex match first, and the corresponding FuncOpts()
# second.
_FN_OUT_REGEX = []
# Functions (by name or map signature) whose wrapper simply forwards to
# another ATenMLIRType function (see generate_aten_remap()).  Each row is
# (map signature, forwarding target, shape_check_indices).
_FN_REMAP = {
    sig: FuncOpts(outfn_name=target, shape_check_indices=checks)
    for sig, target, checks in (
        ('_th_eq(Tensor, Scalar) -> Tensor', 'ATenMLIRType::eq', None),
        ('_th_eq(Tensor, Tensor) -> Tensor', 'ATenMLIRType::eq', None),
        ('_th_ge(Tensor, Scalar) -> Tensor', 'ATenMLIRType::ge', None),
        ('_th_ge(Tensor, Tensor) -> Tensor', 'ATenMLIRType::ge', None),
        ('_th_gt(Tensor, Scalar) -> Tensor', 'ATenMLIRType::gt', None),
        ('_th_gt(Tensor, Tensor) -> Tensor', 'ATenMLIRType::gt', None),
        ('_th_le(Tensor, Scalar) -> Tensor', 'ATenMLIRType::le', None),
        ('_th_le(Tensor, Tensor) -> Tensor', 'ATenMLIRType::le', None),
        ('_th_lt(Tensor, Scalar) -> Tensor', 'ATenMLIRType::lt', None),
        ('_th_lt(Tensor, Tensor) -> Tensor', 'ATenMLIRType::lt', None),
        ('_th_ne(Tensor, Scalar) -> Tensor', 'ATenMLIRType::ne', None),
        ('_th_ne(Tensor, Tensor) -> Tensor', 'ATenMLIRType::ne', None),
        ('s__th_and(Tensor, Tensor) -> Tensor', 'ATenMLIRType::__and__',
         ((0, 1),)),
        ('s__th_or(Tensor, Tensor) -> Tensor', 'ATenMLIRType::__or__',
         ((0, 1),)),
        ('s__th_xor(Tensor, Tensor) -> Tensor', 'ATenMLIRType::__xor__',
         ((0, 1),)),
        # ('_s_where(Tensor, Tensor, Tensor) -> Tensor',
        #  'ATenMLIRType::where', ((0, 1), (0, 2))),
        ('s__th_eq(Tensor, Tensor) -> Tensor', 'ATenMLIRType::eq',
         ((0, 1),)),
    )
}
# Unqualified type names from the parsed signatures mapped to their fully
# qualified C++ spellings (applied by rewrite_signature()).
_TYPE_NSMAP = {
    name: '{}::{}'.format(namespace, name)
    for namespace, names in (
        ('at', ('Tensor', 'TensorList', 'Scalar', 'Storage', 'IntList',
                'IntArrayRef', 'Generator', 'ScalarType', 'TensorOptions',
                'SparseTensorRef')),
        ('c10', ('Device', 'optional')),
        ('at', ('MemoryFormat', 'QScheme', 'ConstQuantizerPtr',
                # namedtensor-only
                'Dimname', 'DimnameList')),
    )
    for name in names
}
# Template of aten_mlir_type_default.h; {hfuncs} receives the static member
# declarations from generate_class_functions().
_H_HEADER = """// Autogenerated file by {gen}. Do not edit directly!
#include <ATen/Tensor.h>
namespace torch_mlir {{
class ATenMLIRTypeDefault {{
public:
{hfuncs}
}};
void RegisterAtenTypeFunctions();
}} // namespace torch_mlir
"""
# Template of aten_mlir_type_default.cpp; {funcs} receives the wrapper
# definitions and {regs} the RegisterAtenTypeFunctions() body.  <iostream> is
# included explicitly because the wrappers' entry traces use std::cout
# (generate_entry_debug_code()).
_CPP_HEADER = """// Autogenerated file by {gen}. Do not edit directly!
#include "aten_mlir_type_default.h"
#include <ATen/Context.h>
#include <ATen/Functions.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/CPUGenerator.h>
#include <iostream>
#include "aten_mlir_bridge.h"
#include "aten_mlir_type.h"
namespace torch_mlir {{
{funcs}
{regs}
}} // namespace torch_mlir
"""
# Explicit C++ call targets by function name; consulted before the
# Functions.h lookup in get_handling_function().  Currently empty.
_torch_mlir_FUNCTIONS = {}
# Factory functions whose TensorOptions argument is rewritten by appending
# this suffix (see rewrite_tensor_options()).
_CTOR_FUNCTIONS = dict.fromkeys(
    ('empty', 'linspace', 'logspace', 'rand', 'rand_like', 'randn',
     'randn_like', 'randint', 'randint_like', 'randperm', 'scalar_tensor'),
    '.device(at::DeviceType::CPU)')
# Per-function FuncOpts keyed by map signature (create_map_sig()); looked up
# in get_mlir_wrapper().  For slice(), `self` is listed in wparams, so
# is_write_param() reports it as writeable.
_FUNCTION_OPTIONS = {
'slice(Tensor, int64_t, int64_t, int64_t, int64_t) -> Tensor':
FuncOpts(wparams=['self']),
}
# Name of the local C++ variable that holds the call result in wrappers
# generated by generate_aten_to_mlir().
_RESULT_NAME = 'x_result'
class Context(object):
  """Text lookups against ATen's Functions.h.

  Args:
    functions: Path of Functions.h; the whole file is read once here.
  """

  def __init__(self, functions):
    with open(functions, 'r') as handle:
      self.functions_data = handle.read()

  def get_function(self, name):
    """Return 'at::<name>' if ' <name>(' occurs in Functions.h, else None."""
    needle = ' {}('.format(name)
    if needle not in self.functions_data:
      return None
    return 'at::{}'.format(name)
class StringEmit(object):
  """Accumulates text re-emitted from the tokens of a parse tree.

  `sref` is the source string the tokens were parsed from.  When a token is
  emitted right after another one, the source text between them is copied
  too, so the original spacing survives.  `pos` is the 0-based offset in
  `sref` just past the last emitted/skipped token, or -1 when there is no
  such anchor (initially, and after append()).
  """

  def __init__(self, sref):
    self.sref = sref
    self.sval = ''
    self.pos = -1

  def __repr__(self):
    return self.sval

  def advance(self, t):
    """Emit token `t`, preceded by the source gap since the anchor."""
    begin = t.column - 1
    if 0 <= self.pos < begin:
      self.sval += self.sref[self.pos:begin]
    self.sval += t.value
    self.pos = t.end_column - 1

  def skip(self, t):
    """Drop subtree `t`; if anchored, move the anchor past its end."""
    self.pos = -1 if self.pos < 0 else last_match(t)

  def append(self, s):
    """Emit literal text `s` and clear the anchor."""
    self.sval = self.sval + s
    self.pos = -1
class TensorFetcher(object):
  """Gathers a wrapper's Tensor arguments into one C++ tensor list.

  Args:
    var_name: Name of the C++ list variable built by
      bridge::MLIRCreateTensorList(); the original arguments are first
      collected into a std::vector named '<var_name>_tensors'.
  """

  def __init__(self, var_name):
    self.var_name = var_name
    self.tvar_name = '{}_tensors'.format(self.var_name)
    self.tensors = []  # C++ expressions of the collected tensors, in order.
    self.writeable = []  # Indices into `tensors` the op may write to.

  def add(self, name, writeable):
    """Register tensor expression `name`; return its list-element access."""
    if writeable:
      self.writeable.append(len(self.tensors))
    self.tensors.append(name)
    return '{}[{}]'.format(self.var_name, len(self.tensors) - 1)

  def generate_fetches(self):
    """Return the C++ statements that build the tensor list."""
    code = ''
    code += ' std::vector<at::Tensor> {} = {{{}}};\n'.format(
        self.tvar_name, ', '.join(self.tensors))
    code += (' auto {} = bridge::MLIRCreateTensorList({});\n').format(
        self.var_name, self.tvar_name)
    return code

  def generate_updates(self):
    """Write-back of modified tensors; not supported for the MLIR bridge.

    Raises:
      NotImplementedError: always.  An explicit raise (instead of
        `assert (0)`) keeps this path disabled under `python -O`, where the
        assert would be stripped and XLA-specific update code emitted.
    """
    raise NotImplementedError(
        'TensorFetcher.generate_updates is not supported for the MLIR bridge')
def list_get(l, n):
  """Return l[n], or None when n is at or past the end of `l`."""
  if n >= len(l):
    return None
  return l[n]
def is_blacklisted_fn(fname, mapsig):
  """Return True if no C++ wrapper should be generated for a function.

  The function is blacklisted when its name or map signature is in
  _FN_BLACKLIST, or when either matches (via re.match, i.e. anchored at the
  start) a pattern in _FN_BLACKLIST_REGEX.
  """
  candidates = (fname, mapsig)
  if any(c in _FN_BLACKLIST for c in candidates):
    return True
  return any(
      re.match(pattern, c)
      for pattern in _FN_BLACKLIST_REGEX
      for c in candidates)
def get_outfn_options(fname, mapsig):
  """Return the out-variant FuncOpts for a function, or None.

  Exact _FN_OUT entries (by name, then by map signature) take precedence
  over _FN_OUT_REGEX patterns, which are tried with re.match on both.
  """
  for key in (fname, mapsig):
    opts = _FN_OUT.get(key)
    if opts is not None:
      return opts
  by_pattern = (
      opts for pattern, opts in _FN_OUT_REGEX
      if re.match(pattern, fname) or re.match(pattern, mapsig))
  return next(by_pattern, None)
def get_remapfn_options(fname, mapsig):
  """Return the _FN_REMAP FuncOpts for a function, or None.

  The name is looked up first, then the map signature.
  """
  return next((_FN_REMAP[key] for key in (fname, mapsig)
               if _FN_REMAP.get(key) is not None), None)
def is_write_param(fnopts, pname, defval):
  """Return True if `pname` is listed in fnopts.wparams, else `defval`."""
  wparams = fnopts.wparams if fnopts else None
  if wparams and pname in wparams:
    return True
  return defval
def first_match(t):
  """Return the 0-based source offset where the first token of `t` starts."""
  node = t
  while not isinstance(node, lark.lexer.Token):
    assert isinstance(node, lark.tree.Tree)
    node = node.children[0]
  return node.column - 1


def last_match(t):
  """Return the 0-based source offset just past the last token of `t`."""
  node = t
  while not isinstance(node, lark.lexer.Token):
    assert isinstance(node, lark.tree.Tree)
    node = node.children[-1]
  return node.end_column - 1


def for_every_token(t, fn):
  """Call `fn` on every token under `t`, in left-to-right order."""
  pending = [t]
  while pending:
    node = pending.pop()
    if isinstance(node, lark.lexer.Token):
      fn(node)
      continue
    assert isinstance(node, lark.tree.Tree)
    # Reversed so that the leftmost child is popped first.
    pending.extend(reversed(node.children))
def emit_string(t, emit, emit_fn):
  """Re-emit the text of subtree `t` into the StringEmit `emit`.

  `emit_fn(node)` decides what happens to each visited node:
    > 0: emit every token below it without consulting emit_fn again;
    == 0: emit it if it is a token, otherwise recurse into its children;
    < 0: skip the whole node (StringEmit.skip).
  """
  verdict = emit_fn(t)
  if verdict < 0:
    emit.skip(t)
    return
  if verdict > 0:
    for_every_token(t, emit.advance)
    return
  if isinstance(t, lark.lexer.Token):
    emit.advance(t)
    return
  assert isinstance(t, lark.tree.Tree)
  for child in t.children:
    emit_string(child, emit, emit_fn)
def typed_child(t, n, ttype):
  """Return child `n` of tree `t`, asserting that it is a `ttype` subtree."""
  assert isinstance(t, lark.tree.Tree)
  assert n < len(t.children)
  child = t.children[n]
  assert isinstance(child, lark.tree.Tree)
  assert child.data == ttype, t.pretty()
  return child
def rewrite_sig(tree, orig_sig, emit_fn=lambda x: 0):
  """Return the text of `tree` re-emitted against `orig_sig`.

  `emit_fn` filters nodes as described in emit_string(); the default keeps
  everything.
  """
  out = StringEmit(orig_sig)
  emit_string(tree, out, emit_fn)
  return str(out)
def rewrite_signature(sig, tmap):
  """Return declaration `sig` with type names remapped and defaults removed.

  Every TNAME token whose value is a key of `tmap` is replaced by the mapped
  value (e.g. 'Tensor' -> 'at::Tensor' with _TYPE_NSMAP), and every
  `param_defval` subtree is dropped.
  """

  def qualify(tok):
    if tok.type != 'TNAME':
      return
    replacement = tmap.get(tok.value)
    if replacement is not None:
      tok.value = replacement

  def drop_defaults(node):
    if isinstance(node, lark.lexer.Token):
      return 0
    return -1 if node.data == 'param_defval' else 0

  xtree = _XPARSER.parse(sig)
  for_every_token(xtree, qualify)
  return rewrite_sig(xtree, sig, emit_fn=drop_defaults)
def create_stdfunc_sig(tree, orig_sig):
  """Return the C++ function type of a declaration, e.g. 'R(A, B)'.

  Parameter names are removed.  `tree` must keep punctuation tokens (parsed
  with _XPARSER), since `params` is expected at child index 3.  The result
  is the function-type argument of the generated kernel registration.
  """

  def drop_names(node):
    if isinstance(node, lark.lexer.Token):
      return 0
    return -1 if node.data == 'param_name' else 0

  out = StringEmit(orig_sig)
  emit_string(typed_child(tree, 0, 'type'), out, drop_names)  # Return type.
  out.append('(')
  emit_string(typed_child(tree, 3, 'params'), out, drop_names)
  out.append(')')
  return str(out)
def create_map_sig(tree, orig_sig):
  """Return the normalized 'name(T1, T2) -> R' form of a declaration.

  const / & / * tokens, parameter names and default values are removed.
  The result keys the option tables (_FUNCTION_OPTIONS, _FN_OUT, ...) and,
  via get_mapsig_key(), the local overrides.  `tree` must come from
  _XPARSER (punctuation kept).
  """

  def strip(node):
    if isinstance(node, lark.lexer.Token):
      return -1 if node.type in ('CONST', 'REF', 'PTR') else 0
    return -1 if node.data in ('param_name', 'param_defval') else 0

  out = StringEmit(orig_sig)
  emit_string(typed_child(tree, 1, 'fnname'), out, strip)  # Function name.
  out.append('(')
  emit_string(typed_child(tree, 3, 'params'), out, strip)  # Parameter types.
  out.append(') -> ')
  emit_string(typed_child(tree, 0, 'type'), out, strip)  # Return type.
  return str(out)
def type_core(t):
  """Return the core name of a `type` subtree, ignoring const/&/*.

  For a template type such as std::tuple<...> this is the template name.

  Raises:
    RuntimeError: if `t` has no `core_type` child.
  """
  assert isinstance(t, lark.tree.Tree)
  core = next((c for c in t.children
               if isinstance(c, lark.tree.Tree) and c.data == 'core_type'),
              None)
  if core is None:
    raise RuntimeError('Not a type tree: {}'.format(t))
  inner = core.children[0]
  if isinstance(inner, lark.lexer.Token):
    return inner.value
  assert isinstance(inner, lark.tree.Tree) and inner.data == 'template'
  return inner.children[0].value
def type_is_const(t):
  """Return True if `type` subtree `t` begins with a 'const' token."""
  assert isinstance(t, lark.tree.Tree)
  head = t.children[0]
  if not isinstance(head, lark.lexer.Token):
    return False
  return head.value == 'const'
def type_is_refptr(t, kind):
  """Return True if `type` subtree `t` ends in a refspec equal to `kind`.

  `kind` is '&' or '*'.
  """
  assert isinstance(t, lark.tree.Tree)
  tail = t.children[-1]
  if isinstance(tail, lark.tree.Tree) and tail.data == 'refspec':
    marker = tail.children[0]
    return isinstance(marker, lark.lexer.Token) and marker.value == kind
  return False
def extract_list(t, l):
  """Flatten a right-recursive list rule (`x: item | item "," x`) into `l`.

  Returns `l` with the item of every level appended.  Descending requires
  exactly two children, i.e. a compact (_PARSER) tree without ',' tokens.
  """
  assert isinstance(t, lark.tree.Tree)
  node = t
  while True:
    l.append(node.children[0])
    if len(node.children) != 2:
      break
    rest = node.children[1]
    if not (isinstance(rest, lark.tree.Tree) and rest.data == t.data):
      break
    node = rest
  return l
def tuple_type_list(t):
  """Return the element `type` subtrees of a template type (e.g. std::tuple).

  `t` must be a compact (_PARSER) `type` subtree without a leading const.
  """
  assert isinstance(t, lark.tree.Tree)
  core = t.children[0]
  assert isinstance(core, lark.tree.Tree) and core.data == 'core_type'
  template = core.children[0]
  assert isinstance(template, lark.tree.Tree) and template.data == 'template'
  return extract_list(template.children[1], [])
def get_function_name(t):
  """Return the function name of a declaration parse tree."""
  assert isinstance(t, lark.tree.Tree)
  name_node = t.children[1]
  assert isinstance(name_node, lark.tree.Tree)
  assert name_node.data == 'fnname'
  return name_node.children[0].value
def get_function_signature(t, orig_sig, namefn):
  """Build the definition header of a generated wrapper.

  Args:
    t: _XPARSER tree of `orig_sig`.
    orig_sig: Declaration text that `t` was parsed from.
    namefn: Maps the original function name to the wrapper's name.

  Returns:
    (signature, fname, xfname): '<return type> <xfname>(<params>)' with the
    parameter list re-emitted unchanged (names included), the original
    function name, and the wrapper name.
  """

  def keep_everything(node):
    return 0

  out = StringEmit(orig_sig)
  emit_string(typed_child(t, 0, 'type'), out, keep_everything)
  fname = typed_child(t, 1, 'fnname').children[0].value
  xfname = namefn(fname)
  out.append(' {}('.format(xfname))
  emit_string(typed_child(t, 3, 'params'), out, keep_everything)
  out.append(')')
  return str(out), fname, xfname
def get_parameters(t):
  """Return the `param` subtrees of a compact (_PARSER) declaration tree."""
  assert isinstance(t, lark.tree.Tree)
  plist = t.children[2]
  assert isinstance(plist, lark.tree.Tree)
  assert plist.data == 'params'
  return extract_list(plist, [])
def get_rparameters(t):
  """Return the named return values of a declaration; currently always [].

  `start` in _GRAMMAR never produces an `rtype` subtree, so there is nothing
  to extract yet.  Nothing is printed: stdout may carry the generated files
  when no --output_folder is given.
  """
  assert isinstance(t, lark.tree.Tree)
  return []
def param_name(t):
  """Return the name of a `param` subtree."""
  assert isinstance(t, lark.tree.Tree)
  name_node = t.children[1]
  assert isinstance(name_node, lark.tree.Tree)
  assert name_node.data == 'param_name'
  tok = name_node.children[0]
  assert isinstance(tok, lark.lexer.Token)
  return tok.value


def param_type(t):
  """Return the `type` subtree (first child) of a `param` subtree."""
  assert isinstance(t, lark.tree.Tree)
  ptype = t.children[0]
  assert isinstance(ptype, lark.tree.Tree)
  return ptype
def get_optional(fnopts, name, defval=None):
  """Return attribute `name` of `fnopts`, or `defval` if missing or falsy."""
  value = getattr(fnopts, name, None) if fnopts is not None else None
  return value or defval
def get_return_value(rtype, rname, param, var, ref_param, fnopts):
  """Return the C++ expression a wrapper returns for one result value.

  Args:
    rtype: `type` subtree of the result.
    rname: C++ expression holding the computed result.
    param: Parameter matched to this result; required for const/& results.
    var: Wrapper-local variable for `param` (currently unused).
    ref_param: Reference parameter whose device a new Tensor result uses.
    fnopts: FuncOpts of the function, or None.
  """
  crtype = type_core(rtype)
  if type_is_const(rtype) or type_is_refptr(rtype, '&'):
    # Const / reference results are the input parameter itself.
    assert param
    return param_name(param)
  if crtype != 'Tensor':
    return rname
  # A Tensor returned by value is wrapped as a new MLIR tensor on the device
  # of fnopts.device_param or, if unset, of the reference parameter.  The
  # fallback is resolved only when needed, so a device_param override also
  # works for functions that have no reference parameter.
  device = get_optional(fnopts, 'device_param')
  if device is None:
    device = param_name(ref_param)
  return 'bridge::CreateMLIRTensor({}, bridge::GetMLIRDevice({}))'.format(
      rname, device)
def get_reference_param(params, fnopts=None):
  """Pick the parameter whose device a new Tensor result should use.

  Selection order:
    1. the parameter named by fnopts.ref_param (returned as soon as seen);
    2. the first Tensor parameter named 'self' or declared const;
    3. the last Tensor parameter, or if there is none, the first
       TensorOptions / TensorList parameter.

  Returns None when nothing qualifies.
  """
  pinned = get_optional(fnopts, 'ref_param')
  chosen = None
  fallback = None
  for p in params:
    ptype = param_type(p)
    kind = type_core(ptype)
    pname = param_name(p)
    if pinned == pname:
      return p
    if kind in ('TensorOptions', 'TensorList'):
      if not fallback:
        fallback = p
      continue
    if kind != 'Tensor':
      continue
    if not chosen and (pname == 'self' or type_is_const(ptype)):
      chosen = p
    fallback = p
  return chosen or fallback
def get_tuple_return(rtype, rtype_str, rname, params, param_vars, ref_param,
                     fnopts):
  """Return '<rtype_str>(e0, e1, ...)' rebuilding a std::tuple result.

  Element i is std::get<i>(rname) converted by get_return_value(), with the
  i-th parameter (if any) as its aliasing candidate.
  """
  elements = [
      get_return_value(elem_type, 'std::get<{}>({})'.format(i, rname),
                       list_get(params, i), list_get(param_vars, i),
                       ref_param, fnopts)
      for i, elem_type in enumerate(tuple_type_list(rtype))
  ]
  return '{}('.format(rtype_str) + ', '.join(elements) + ')'
def get_return_type_str(t, orig_sig):
  """Return the return-type text of `orig_sig` (everything before the name).

  The single character separating the type from the name is dropped.
  """
  assert isinstance(t, lark.tree.Tree)
  name_node = t.children[1]
  assert isinstance(name_node, lark.tree.Tree)
  assert name_node.data == 'fnname'
  name_tok = name_node.children[0]
  assert isinstance(name_tok, lark.lexer.Token)
  # column is 1-based: column - 1 is the name's offset, one more drops the
  # separator.
  return orig_sig[:name_tok.column - 2]
def generate_entry_debug_code(t, fname, params, fname_ns='aten'):
  """Return C++ that prints '<fname_ns>::<fname>' when a wrapper is entered.

  Functions in _FN_NO_DEBUG_ENTRY_LIST get no trace.  `t` and `params` are
  unused; the earlier per-argument TF_VLOG tracing is disabled.
  """
  if fname in _FN_NO_DEBUG_ENTRY_LIST:
    return ''
  return ' std::cout << "{}::{}" << std::endl;\n'.format(fname_ns, fname)
def generate_exit_debug_code(t, fname, rname, params, param_vars):
  """Return C++ tracing a wrapper's exit; currently always '' (disabled)."""
  return ''
def generate_return_stmt(t, rtype_str, fname, rname, params, param_vars,
                         ref_param, fnopts):
  """Return the wrapper's C++ return statement ('' for a plain void result).

  Tuples are rebuilt element-wise (get_tuple_return()), Tensor results go
  through get_return_value(), and all other results (including std::vector)
  return `rname` unchanged.
  """
  assert isinstance(t, lark.tree.Tree)
  rtype = t.children[0]
  ctype = type_core(rtype)
  if ctype == 'void' and not type_is_refptr(rtype, '*'):
    return ''
  if ctype == 'std::tuple':
    value = get_tuple_return(rtype, rtype_str, rname, params, param_vars,
                             ref_param, fnopts)
  elif ctype == 'Tensor':
    value = get_return_value(rtype, rname, params[0], param_vars[0],
                             ref_param, fnopts)
  else:
    value = rname
  return ' return {};\n'.format(value)
def generate_result_assignment(t, rname):
  """Return 'auto&& <rname> = ', or '' when the function returns plain void."""
  assert isinstance(t, lark.tree.Tree)
  rtype = t.children[0]
  returns_void = (type_core(rtype) == 'void' and
                  not type_is_refptr(rtype, '*'))
  return '' if returns_void else 'auto&& {} = '.format(rname)
def get_handling_function(ctx, fname, the_ref_param, param_vars):
  """Return the C++ call expression that performs the wrapped operation.

  A free function from _torch_mlir_FUNCTIONS or Functions.h ('at::<fname>')
  is called with all arguments; otherwise `fname` is called as a method of
  the reference argument with the remaining arguments.
  """
  target = _torch_mlir_FUNCTIONS.get(fname) or ctx.get_function(fname)
  if target:
    return '{}({})'.format(target, ', '.join(param_vars))
  remaining = list(param_vars)
  remaining.remove(the_ref_param)
  return '{}.{}({})'.format(the_ref_param, fname, ', '.join(remaining))
def rewrite_tensor_options(fname, pname):
  """Rewrite the TensorOptions argument of a _CTOR_FUNCTIONS factory.

  Returns (code, name): a C++ declaration of a copy of `pname` with the
  function's _CTOR_FUNCTIONS suffix applied, plus the copy's name; or
  ('', pname) when `fname` needs no rewrite.
  """
  suffix = _CTOR_FUNCTIONS.get(fname)
  if suffix is None:
    return '', pname
  local_name = 'o_' + pname
  decl = ' at::TensorOptions {} = {}{};\n'.format(local_name, pname, suffix)
  return decl, local_name
def get_param_names(params):
  """Return the names of the `param` subtrees, in order."""
  return [param_name(p) for p in params]
def expand_fn_template(tmpl, param_vars):
  """Substitute $0, $1, ... in template `tmpl` with the given parameters."""
  positional = {str(i): var for i, var in enumerate(param_vars)}
  return tmpl.substitute(positional)
def create_call(fname, param_vars):
  """Return the C++ call expression 'fname(a, b, ...)'."""
  arguments = ', '.join(param_vars)
  return '{}({})'.format(fname, arguments)
def generate_shape_checks(param_vars, shape_check_indices, fname):
  """Return C++ operand-shape checks; currently always '' (disabled).

  The disabled XLA version emitted, for every (i, j) in shape_check_indices,
  an XLA_CHECK that param_vars[i] and param_vars[j] have identical sizes.
  """
  return ''
def generate_aten_remap(ctx, fname, sig, params, fnopts):
  """Emit a wrapper that forwards to the target of a _FN_REMAP entry.

  The call comes from fnopts.outfn_template when set, otherwise from
  fnopts.outfn_name applied to all parameters.  `ctx` is unused.
  """
  args = get_param_names(params)
  template = fnopts.outfn_template
  if template is None:
    assert fnopts.outfn_name
    call = create_call(fnopts.outfn_name, args)
  else:
    call = expand_fn_template(template, args)
  pieces = ['{} {{\n'.format(sig)]
  if fnopts.shape_check_indices is not None:
    pieces.append(
        generate_shape_checks(args, fnopts.shape_check_indices, fname))
  pieces.append(' return {};\n'.format(call))
  pieces.append('}')
  return ''.join(pieces)
def generate_outfn_result_copy(dest, src):
  """Return C++ that makes tensor `dest` a shallow copy of tensor `src`."""
  stmt = ' {}.unsafeGetTensorImpl()->shallow_copy_from({}.getIntrusivePtr());\n'
  return stmt.format(dest, src)
def generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params, fnopts):
  """Emit a wrapper for an out-variant ('<op>_out') function.

  The functional op (fnopts.outfn_template, or ATenMLIRType::<op> applied to
  the non-out parameters) is called, and its result is shallow-copied into
  the leading out parameter(s), which are then returned.  `ctx` is unused.
  """
  rtype = tree.children[0]
  if type_core(rtype) == 'std::tuple':
    num_outputs = len(tuple_type_list(rtype))
  else:
    num_outputs = None
  pieces = ['{} {{\n'.format(sig), generate_entry_debug_code(tree, fname, params)]
  param_vars = get_param_names(params)
  if fnopts.outfn_template is None:
    match = re.match(r'(.*)_out$', fname)
    assert match is not None, fname
    # The leading parameter(s) are the out tensors; the op takes the rest.
    n_out = 1 if num_outputs is None else num_outputs
    fcall = create_call('ATenMLIRType::{}'.format(match.group(1)),
                        param_vars[n_out:])
  else:
    fcall = expand_fn_template(fnopts.outfn_template, param_vars)
  tmp_result = '{}_tmp'.format(fname)
  pieces.append(' auto {} = {};\n'.format(tmp_result, fcall))
  if num_outputs is None:
    out = param_vars[0]
    pieces.append(generate_outfn_result_copy(out, tmp_result))
    pieces.append(
        generate_exit_debug_code(tree, fname, out, params, param_vars))
    pieces.append(' return {};\n'.format(out))
  else:
    outs = [param_vars[i] for i in range(num_outputs)]
    pieces.extend(
        generate_outfn_result_copy(var, 'std::get<{}>({})'.format(i, tmp_result))
        for i, var in enumerate(outs))
    pieces.append(
        generate_exit_debug_code(tree, fname, param_vars[0:num_outputs],
                                 params, param_vars))
    pieces.append(' return {}('.format(get_return_type_str(rwxtree, rwsig)))
    pieces.append(', '.join(outs))
    pieces.append(');\n')
  pieces.append('}')
  return ''.join(pieces)
def generate_aten_to_mlir(ctx, tree, rwxtree, fname, sig, rwsig, params,
                          fnopts):
  """Emit the default wrapper that forwards an ATen call.

  Tensor arguments are gathered into the `mlirtens` list (TensorFetcher) and
  passed as its elements.  TensorOptions of _CTOR_FUNCTIONS are rewritten,
  and all other arguments (including TensorList) pass through unchanged.
  The call target comes from get_handling_function(); the result is turned
  into the return value by generate_return_stmt().
  """
  ref_param = get_reference_param(params, fnopts=fnopts)
  pieces = ['{} {{\n'.format(sig),
            generate_entry_debug_code(tree, fname, params)]
  the_ref_param = param_name(ref_param) if ref_param else None
  pinned_ref = get_optional(fnopts, 'ref_param')
  tfetcher = TensorFetcher('mlirtens')
  param_vars = []
  for p in params:
    ptype = param_type(p)
    cptype = type_core(ptype)
    pname = param_name(p)
    if cptype == 'TensorOptions':
      decl, var = rewrite_tensor_options(fname, pname)
      pieces.append(decl)
    elif cptype == 'Tensor':
      # Non-const Tensor parameters default to writeable.
      writeable = is_write_param(fnopts, pname, not type_is_const(ptype))
      var = tfetcher.add(pname, writeable)
    else:
      var = pname
    param_vars.append(var)
    # Unless ref_param is pinned, call methods on the fetched variable.
    if p == ref_param and not pinned_ref:
      the_ref_param = var
  pieces.append(tfetcher.generate_fetches())
  result_assign = generate_result_assignment(tree, _RESULT_NAME)
  result = _RESULT_NAME if result_assign else None
  pieces.append(' {}{};\n'.format(
      result_assign,
      get_handling_function(ctx, fname, the_ref_param, param_vars)))
  if result_assign:
    pieces.append((' static_cast<void>({}); // Avoid warnings in case not '
                   'used\n'.format(_RESULT_NAME)))
  pieces.append(
      generate_exit_debug_code(tree, fname, result, params, param_vars))
  pieces.append(
      generate_return_stmt(tree, get_return_type_str(rwxtree, rwsig), fname,
                           result, params, param_vars, ref_param, fnopts))
  pieces.append('}')
  return ''.join(pieces)
def get_mlir_wrapper(fndef, ctx):
  """Parse one FuncDef and generate its ATenMLIRTypeDefault wrapper.

  Returns a FuncGen whose `code` is None for blacklisted functions.  Out
  variants (get_outfn_options()) and remapped functions
  (get_remapfn_options()) get specialized wrappers; everything else gets
  the generic generate_aten_to_mlir() wrapper.
  """
  cpp_sig = fndef.cpp_sig
  tree = _PARSER.parse(cpp_sig)
  xtree = _XPARSER.parse(cpp_sig)
  mapsig = create_map_sig(xtree, cpp_sig)
  rwsig = rewrite_signature(cpp_sig, _TYPE_NSMAP)
  rwxtree = _XPARSER.parse(rwsig)
  params = get_parameters(tree)
  fnopts = _FUNCTION_OPTIONS.get(mapsig)
  sig, fname, xfname = get_function_signature(
      rwxtree, rwsig, 'ATenMLIRTypeDefault::{}'.format)
  if is_blacklisted_fn(fname, mapsig):
    code = None
  else:
    ofnopts = get_outfn_options(fname, mapsig)
    rfnopts = get_remapfn_options(fname, mapsig)
    if ofnopts is not None:
      code = generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params,
                               ofnopts)
    elif rfnopts is not None:
      code = generate_aten_remap(ctx, fname, sig, params, rfnopts)
    else:
      code = generate_aten_to_mlir(ctx, tree, rwxtree, fname, sig, rwsig,
                                   params, fnopts)
  return FuncGen(tree=tree,
                 xtree=xtree,
                 rwxtree=rwxtree,
                 func=fname,
                 xfunc=xfname,
                 code=code,
                 sig=cpp_sig,
                 rwsig=rwsig,
                 cppsig=sig,
                 mapsig=mapsig,
                 funsig=create_stdfunc_sig(rwxtree, rwsig),
                 aten_sig=fndef.aten_sig)
def is_tensor_api(fndef):
  """Normalize a declaration and report whether it mentions Tensor.

  'at::' qualifiers are removed and 'c10::Device' becomes 'Device'.

  Returns:
    (mentions_tensor, normalized): mentions_tensor is True when the word
    'Tensor' (not e.g. 'TensorList') occurs in the normalized text.
  """
  normalized = fndef.replace('at::', '').replace('c10::Device', 'Device')
  mentions_tensor = re.search(r'\bTensor\b', normalized) is not None
  return mentions_tensor, normalized
def extract_functions(path):
  """Read the function declarations of a RegistrationDeclarations.h file.

  Candidate lines look like '<C++ declaration>; // <schema>'; other lines
  are ignored.

  Args:
    path: Path of the declarations header.

  Returns:
    (functions, errors): FuncDef entries for every declaration the grammar
    accepts, and (declaration, message) pairs for rejected declarations
    that mention Tensor (those are also reported on stderr).
  """
  line_re = re.compile(r'\s*([^\s].*); //\s+(.*)')
  functions = []
  errors = []
  # `with` so the file is closed deterministically.
  with open(path, 'r') as f:
    for line in f:
      m = line_re.match(line)
      if not m:
        continue
      fndef = m.group(1)
      try:
        _XPARSER.parse(fndef)
      except Exception as e:  # Any lark parse failure.
        if is_tensor_api(fndef)[0]:
          errors.append((fndef, str(e)))
          print('Error parsing "{}": {}'.format(fndef, e), file=sys.stderr)
        continue
      functions.append(FuncDef(cpp_sig=fndef, aten_sig=m.group(2)))
  return functions, errors
def get_mapsig_key(mapsig):
  """Turn a map signature into a lookup key by deleting every space.

  PyTorch writes std::tuple<> types without spaces between the element
  types; stripping all spaces lets differently spaced spellings of the same
  signature compare equal without special handling in the string rewriter.
  """
  return ''.join(mapsig.split(' '))
def parse_local_overrides(path):
  """Collect the functions overridden locally in aten_mlir_type.h.

  Every `static ...;` declaration is gathered, including ones spanning
  several lines.  Declarations that mention Tensor are parsed and keyed by
  their map signature.

  Args:
    path: Path of the override header.

  Returns:
    dict mapping get_mapsig_key(<map signature>) to the declaration as
    normalized by is_tensor_api().

  Raises:
    ValueError: if the file ends in the middle of a declaration.
  """
  single_line_re = re.compile(r'static\s+(.*);')
  decl_start_re = re.compile(r'static\s+(.*)')
  functions = []
  fndef = None
  with open(path, 'r') as f:
    for line in f:
      line = line.strip()
      if not fndef:
        m = single_line_re.match(line)
        if m:
          functions.append(m.group(1))
          continue
        m = decl_start_re.match(line)
        if m:
          fndef = m.group(1)
      else:
        # Continuation of a multi-line declaration.
        fndef = '{} {}'.format(fndef, line)
        if fndef.endswith(';'):
          functions.append(fndef[:-1])
          fndef = None
  # An explicit check (not an assert) so it also runs under `python -O`.
  if fndef is not None:
    raise ValueError('{}: unterminated declaration: {}'.format(path, fndef))
  overrides = {}
  for fndef in functions:
    # Discard static functions which are not ATen (Tensor) APIs.
    is_tensor, fndef = is_tensor_api(fndef)
    if is_tensor:
      xtree = _XPARSER.parse(fndef)
      mapsig_key = get_mapsig_key(create_map_sig(xtree, fndef))
      overrides[mapsig_key] = fndef
  return overrides
def get_dialect_name(func):
  """Convert an ATen function name into the CamelCase stem of its op def.

  Underscores are dropped and the first character of every underscore-
  separated word is upper-cased.  A trailing underscore adds the suffix
  'Under', e.g. 'add_' -> 'AddUnder', '_log_softmax' -> 'LogSoftmax'.
  """
  words = [word for word in func.split('_') if word]
  name = ''.join(word[0].upper() + word[1:] for word in words)
  if func[-1] == '_':
    name += 'Under'
  return name
def _td_operand_type(cptype):
  """Return the TableGen operand constraint for a core C++ parameter type.

  Returns None for parameter types left out of an op's arguments
  (c10::optional and std::array).  Types without a dedicated mapping become
  AnyType; all of those except IntArrayRef are reported on stderr.
  """
  if cptype == 'Tensor':
    return 'AnyTensor'
  if cptype in ('Scalar', 'int64_t', 'double', 'bool'):
    return 'AnyScalar'
  if cptype in ('c10::optional', 'std::array'):
    return None
  if cptype != 'IntArrayRef':
    # stderr, not stdout: without --output_folder the generated files are
    # written to stdout.
    print('unhandled type', cptype, file=sys.stderr)
  return 'AnyType'


def generate_td_functions(fgens, overrides):
  """Emit the ATenOps.td TableGen op definitions.

  One `aten_<Name>Op` def is emitted per function not in _TD_BLACKLIST, with
  one AnyTensor result per returned tensor and one operand per supported
  parameter.

  Args:
    fgens: FuncGen records produced by get_mlir_wrapper().
    overrides: Local overrides keyed by get_mapsig_key().

  Returns:
    (code, overridden): the TableGen text and the set of `overrides` keys
    that matched a function.
  """
  code = "#ifdef ATEN_OP_DEFS\n"
  code += "#else\n"
  code += "#define ATEN_OP_DEFS\n\n"
  overridden = set()
  for fgen in fgens:
    mapsig_key = get_mapsig_key(fgen.mapsig)
    if mapsig_key in overrides:
      overridden.add(mapsig_key)
    if fgen.func in _TD_BLACKLIST:
      continue
    rtype = fgen.tree.children[0]
    num_outputs = 1
    if type_core(rtype) == 'std::tuple':
      num_outputs = len(tuple_type_list(rtype))
    dialect_name = get_dialect_name(fgen.func)
    with_stats = fgen.func not in _TD_NO_OPSTATS_LIST
    code += 'def aten_{}Op: aten_Op<"{}"'.format(dialect_name, fgen.func)
    code += ', [NoSideEffect'
    if with_stats:
      code += ', StatisticsOpInterface'
    code += ']>,\n'
    code += ' Results<(outs'
    code += ' AnyTensor'
    for _ in range(num_outputs - 1):
      code += ', AnyTensor'
    code += ')> {\n'
    code += ' let arguments = (\n'
    first = True
    for p in get_parameters(fgen.tree):
      pname = param_name(p)
      td_type = _td_operand_type(type_core(param_type(p)))
      if td_type is None:
        continue
      # `ins` goes on the first operand actually emitted.  Keying it on
      # params[0] produced a malformed dag whenever the first parameter was
      # one of the skipped types.
      code += '{}{}:${}'.format(' ins ' if first else ',\n ', td_type, pname)
      first = False
    if first:
      # No operands at all: emit an empty `(ins)` dag.
      code += ' ins'
    code += '\n );\n'
    code += ' let summary = "aten {} operator";\n'.format(fgen.func)
    code += ' let description = [{\n'
    code += ' {}Op\n'.format(dialect_name)
    code += ' aten {} operator\n'.format(fgen.func)
    code += ' }];\n'
    if with_stats:
      code += ' let extraClassDeclaration = [{\n'
      code += ' std::map<std::string, uint64_t> getStatistics();\n'
      code += ' }];\n'
    code += '}\n\n'
  code += "#endif\n"
  return code, overridden
def generate_registrations(fgens, overrides):
  """Emit RegisterAtenTypeFunctions(), registering one kernel per function.

  A function present in `overrides` is registered to ATenMLIRType::<func>.
  Otherwise its generated default wrapper is registered, if one exists.

  Returns:
    (code, overridden): the C++ text and the set of `overrides` keys that
    matched a function.
  """
  chunks = ['void RegisterAtenTypeFunctions() {\n',
            ' static auto dispatch = torch::RegisterOperators()\n']
  overridden = set()
  for fgen in fgens:
    key = get_mapsig_key(fgen.mapsig)
    if key in overrides:
      kernel = 'ATenMLIRType::{}'.format(fgen.func)
      overridden.add(key)
    else:
      kernel = fgen.xfunc if fgen.code else None
    if not kernel:
      continue
    chunks.append((
        ' .op(torch::RegisterOperators::options().schema("{}")\n '
        '.impl_unboxedOnlyKernel<{}, &{}>(at::TensorTypeId::XLATensorId)\n'
        ' .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))\n').format(
            fgen.aten_sig, fgen.funsig, kernel))
  chunks.append(';\n}\n')
  return ''.join(chunks), overridden
def generate_functions(fgens):
  """Concatenate the C++ definitions of all generated wrappers."""
  return ''.join('{}\n\n'.format(g.code) for g in fgens if g.code)


def generate_class_functions(fgens):
  """Return the static member declarations for every generated wrapper."""
  return ''.join(' static {};\n'.format(g.rwsig) for g in fgens if g.code)
def gen_output_file(args, name):
  """Open `name` for writing in args.output_folder, or return sys.stdout."""
  folder = args.output_folder
  if folder:
    return open(os.path.join(folder, name), 'w')
  return sys.stdout


def gen_h_output_file(args):
  """Output stream for the generated header."""
  return gen_output_file(args, 'aten_mlir_type_default.h')


def gen_cpp_output_file(args):
  """Output stream for the generated C++ source."""
  return gen_output_file(args, 'aten_mlir_type_default.cpp')


def gen_td_output_file(args):
  """Output stream for the generated TableGen file."""
  return gen_output_file(args, 'ATenOps.td')
def check_overrides(overrides, overridden):
  """Report overrides that matched no generated function.

  Each miss is printed to stderr.  Returns True when every override in
  `overrides` is present in `overridden`.
  """
  missed = [(mapsig, cpp_sig) for mapsig, cpp_sig in overrides.items()
            if get_mapsig_key(mapsig) not in overridden]
  for mapsig, cpp_sig in missed:
    print('ATenMLIRType function missed override: {}; // {}'.format(
        cpp_sig, mapsig),
          file=sys.stderr)
  return not missed
def _write_output(opener, args, text):
  """Write `text` to the stream returned by `opener(args)`.

  Real files are closed afterwards; sys.stdout (used when no output folder
  is given) is left open.
  """
  out = opener(args)
  try:
    out.write(text)
  finally:
    if out is not sys.stdout:
      out.close()


def generate(args):
  """Run the generator: parse the inputs and write the three outputs.

  Args:
    args: Parsed command line with `typedef`, `overridetype`, `functions`
      and `output_folder`.

  Raises:
    RuntimeError: if a Tensor declaration failed to parse, or an
      ATenMLIRType override matched no extracted function.
  """
  fndefs, errors = extract_functions(args.typedef)
  print('Extracted {} functions ({} errors) from {}'.format(
      len(fndefs), len(errors), args.typedef),
        file=sys.stderr)
  # Explicit checks instead of asserts: under `python -O` an assert would be
  # stripped, and `assert check_overrides(...)` would not even run the check.
  if errors:
    raise RuntimeError('{} declaration(s) in {} failed to parse'.format(
        len(errors), args.typedef))
  overrides = parse_local_overrides(args.overridetype)
  print('{} function overrides in {}'.format(len(overrides), args.overridetype),
        file=sys.stderr)
  fgens = []
  ctx = Context(args.functions)
  for ts in fndefs:
    try:
      fgen = get_mlir_wrapper(ts, ctx)
    except Exception as e:
      # Best effort: report and skip functions the generator cannot handle.
      print('Failed to generate wrapper for {}: {}'.format(ts, e),
            file=sys.stderr)
      continue
    if fgen:
      fgens.append(fgen)
  print('Generated {} wrappers for {}'.format(len(fgens), args.typedef),
        file=sys.stderr)
  functions = generate_functions(fgens)
  hfunctions = generate_class_functions(fgens)
  tdfunctions, overridden = generate_td_functions(fgens, overrides)
  if not check_overrides(overrides, overridden):
    raise RuntimeError('unmatched ATenMLIRType overrides (TableGen)')
  regs, overridden = generate_registrations(fgens, overrides)
  if not check_overrides(overrides, overridden):
    raise RuntimeError('unmatched ATenMLIRType overrides (registrations)')
  # Create output files.  The trailing '\n' keeps the output identical to
  # the previous print()-based writes of the C++ files.
  gen = os.path.basename(sys.argv[0])
  _write_output(gen_h_output_file, args,
                _H_HEADER.format(gen=gen, hfuncs=hfunctions) + '\n')
  _write_output(gen_cpp_output_file, args,
                _CPP_HEADER.format(gen=gen, funcs=functions, regs=regs) + '\n')
  _write_output(gen_td_output_file, args, tdfunctions)
if __name__ == '__main__':
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument('--output_folder', type=str)
  # Positional inputs, in command-line order: (dest, metavar, help).
  for dest, metavar, help_text in (
      ('overridetype', 'OVERRIDE_TYPE_FILE', 'The path to the overrides file'),
      ('typedef', 'TYPE_DEFAULT_FILE', 'The path to the TypeDefault.h file'),
      ('functions', 'FUNCTIONS_FILE', 'The path to the Functions.h file'),
  ):
    arg_parser.add_argument(dest, type=str, metavar=metavar, help=help_text)
  # parse_known_args(): unrecognized extra arguments are tolerated/ignored.
  args, files = arg_parser.parse_known_args()
  generate(args)