# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
# Structured similarly to code from git@github.com:pytorch/xla.git

from __future__ import print_function

import argparse
import collections
import collections.abc
import lark
import os
import re
import string
import sys

####
# This file parses the C++ signatures exported by pytorch and generates
# appropriate MLIR operations in a tablegen file. It also generates some of
# the more boilerplate parts of the pytorch integration. It may need to be
# re-run when the pytorch version changes. Primarily, it reads information
# from pytorch through RegistrationDeclarations.h and Functions.h. It also
# allows some local overrides (specified in aten_mlir_type.h).
# It generates aten_mlir_type_default.{cpp,h} and ATenOps.td, which will need
# to be moved to their appropriate places.

# To run:
# python3 gen_aten_dialect.py --output_folder=. \
#   ../csrc/aten_mlir_type.h \
#   ${TORCH_INSTALL_PREFIX}/include/ATen/RegistrationDeclarations.h \
#   ${TORCH_INSTALL_PREFIX}/include/ATen/Functions.h


def namedtuple_with_defaults(typename, field_names, default_values=()):
  ntuple = collections.namedtuple(typename, field_names)
  ntuple.__new__.__defaults__ = (None,) * len(ntuple._fields)
  if isinstance(default_values, collections.abc.Mapping):
    prototype = ntuple(**default_values)
  else:
    prototype = ntuple(*default_values)
  ntuple.__new__.__defaults__ = tuple(prototype)
  return ntuple


class ArgTemplate(string.Template):
  idpattern = r'[a-z0-9_]+'


FuncDef = namedtuple_with_defaults('FuncDef', 'cpp_sig, aten_sig')

FuncGen = namedtuple_with_defaults(
    'FuncGen',
    'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig, aten_sig'
)

FuncOpts = namedtuple_with_defaults(
    'FuncOpts',
    'ref_param, device_param, wparams, outfn_template, outfn_name, shape_check_indices'
)

_GRAMMAR = r"""
    start: type fnname "(" params ")"
    rtype: "(" rparams ")"
          | TNAME
    rparams: rparam
          | rparam "," rparams
    rparam: type param_name
    type: CONST? core_type refspec?
    fnname: CNAME
    refspec: REF
          | PTR
    core_type: template
          | TNAME
    template: TNAME "<" typelist ">"
    typelist: type
          | type "," typelist
    REF: "&"
    PTR: "*"
    CONST: "const"
    TNAME: /[a-zA-Z0-9_:]+/
    HEXNUMBER: /0x[0-9a-fA-F]+/
    params: param
          | param "," params
    param: type param_name param_defval?
    param_name: CNAME

    param_defval: "=" init_value
    init_value: "true"
          | "false"
          | "{}"
          | NUMBER
          | SIGNED_NUMBER
          | HEXNUMBER
          | ESCAPED_STRING

    %import common.CNAME -> CNAME
    %import common.NUMBER -> NUMBER
    %import common.SIGNED_NUMBER -> SIGNED_NUMBER
    %import common.ESCAPED_STRING -> ESCAPED_STRING
    %import common.WS
    %ignore WS
    """

_PARSER = lark.Lark(_GRAMMAR, parser='lalr', propagate_positions=True)

_XPARSER = lark.Lark(_GRAMMAR,
                     parser='lalr',
                     propagate_positions=True,
                     keep_all_tokens=True)
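# Illustrative only (a sketch, not used by the generator): parsing a
# hypothetical RegistrationDeclarations.h-style declaration. With _PARSER the
# anonymous "(" / ")" tokens are filtered out of the tree, so the children of
# 'start' are [type, fnname, params], which is the layout the helpers below
# rely on.
def _example_parse_signature():
  tree = _PARSER.parse(
      'Tensor add(const Tensor & self, const Tensor & other, Scalar alpha)')
  # tree.children[0] is the return type, tree.children[1] the function name,
  # tree.children[2] the parameter list.
  return tree.pretty()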
_TD_BLACKLIST = set([
    'clone',
    'to',
    'copy_',
    'copy',
    'copy_from',
    '_copy_from',
    '_unsafe_view',
])

_TD_NO_OPSTATS_LIST = set([
    '_log_softmax',
    '_log_softmax_backward_data',
])

_FN_BLACKLIST = set([
    'numel',
    'ones',
    'ones_like',
    'result_type',
    # 'zero_',
    'zeros',
    'zeros_like',
])

_FN_NO_DEBUG_ENTRY_LIST = set([
    'empty',
    'fill_',
    'zero_',
])

_FN_BLACKLIST_REGEX = [
    # ATEN functions
    r'[^(]*cudnn',
    # XLA/TPU functions
]

_FN_OUT = {
    'add_out': FuncOpts(),
    'arange_out(Tensor, Scalar, Scalar, Scalar) -> Tensor':
        FuncOpts(outfn_template=ArgTemplate(
            'ATenMLIRType::arange($1, $2, $3, $0.options())')),
    'bitwise_not_out': FuncOpts(),
    'clamp_out': FuncOpts(),
    'div_out': FuncOpts(),
    'gather_out': FuncOpts(),
    'kthvalue_out': FuncOpts(),
    'index_select_out': FuncOpts(),
    'log_out': FuncOpts(),
    'topk_out': FuncOpts(),
}
# Out-variant handling is currently disabled; the table above is kept for
# reference and can be re-enabled by removing this reassignment.
_FN_OUT = {}

# List of tuples with the regex match first, and the corresponding FuncOpts()
# second.
_FN_OUT_REGEX = []

_FN_REMAP = {
    '_th_eq(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::eq'),
    '_th_eq(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::eq'),
    '_th_ge(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::ge'),
    '_th_ge(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::ge'),
    '_th_gt(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::gt'),
    '_th_gt(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::gt'),
    '_th_le(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::le'),
    '_th_le(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::le'),
    '_th_lt(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::lt'),
    '_th_lt(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::lt'),
    '_th_ne(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::ne'),
    '_th_ne(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::ne'),
    's__th_and(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::__and__',
                 shape_check_indices=((0, 1),)),
    's__th_or(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::__or__',
                 shape_check_indices=((0, 1),)),
    's__th_xor(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::__xor__',
                 shape_check_indices=((0, 1),)),
    # '_s_where(Tensor, Tensor, Tensor) -> Tensor':
    #     FuncOpts(outfn_name='ATenMLIRType::where',
    #              shape_check_indices=(
    #                  (0, 1),
    #                  (0, 2),
    #              )),
    's__th_eq(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::eq',
                 shape_check_indices=((0, 1),)),
}

_TYPE_NSMAP = {
    'Tensor': 'at::Tensor',
    'TensorList': 'at::TensorList',
    'Scalar': 'at::Scalar',
    'Storage': 'at::Storage',
    'IntList': 'at::IntList',
    'IntArrayRef': 'at::IntArrayRef',
    'Generator': 'at::Generator',
    'ScalarType': 'at::ScalarType',
    'TensorOptions': 'at::TensorOptions',
    'SparseTensorRef': 'at::SparseTensorRef',
    'Device': 'c10::Device',
    'optional': 'c10::optional',
    'MemoryFormat': 'at::MemoryFormat',
    'QScheme': 'at::QScheme',
    'ConstQuantizerPtr': 'at::ConstQuantizerPtr',
    'Dimname': 'at::Dimname',  # namedtensor-only
    'DimnameList': 'at::DimnameList',  # namedtensor-only
}
_H_HEADER = """// Autogenerated file by {gen}. Do not edit directly! #include namespace torch_mlir {{ class ATenMLIRTypeDefault {{ public: {hfuncs} }}; void RegisterAtenTypeFunctions(); }} // namespace torch_mlir """ _CPP_HEADER = """// Autogenerated file by {gen}. Do not edit directly! #include "aten_mlir_type_default.h" #include #include #include #include #include "aten_mlir_bridge.h" #include "aten_mlir_type.h" namespace torch_mlir {{ {funcs} {regs} }} // namespace torch_mlir """ _torch_mlir_FUNCTIONS = {} _CTOR_FUNCTIONS = { 'empty': '.device(at::DeviceType::CPU)', 'linspace': '.device(at::DeviceType::CPU)', 'logspace': '.device(at::DeviceType::CPU)', 'rand': '.device(at::DeviceType::CPU)', 'rand_like': '.device(at::DeviceType::CPU)', 'randn': '.device(at::DeviceType::CPU)', 'randn_like': '.device(at::DeviceType::CPU)', 'randint': '.device(at::DeviceType::CPU)', 'randint_like': '.device(at::DeviceType::CPU)', 'randperm': '.device(at::DeviceType::CPU)', 'scalar_tensor': '.device(at::DeviceType::CPU)', } _FUNCTION_OPTIONS = { 'slice(Tensor, int64_t, int64_t, int64_t, int64_t) -> Tensor': FuncOpts(wparams=['self']), } _RESULT_NAME = 'x_result' class Context(object): def __init__(self, functions): with open(functions, 'r') as ff: self.functions_data = ff.read() def get_function(self, name): if self.functions_data.find(' {}('.format(name)) >= 0: return 'at::{}'.format(name) class StringEmit(object): def __init__(self, sref): self.sref = sref self.sval = '' self.pos = -1 def __repr__(self): return self.sval def advance(self, t): start = t.column - 1 end = t.end_column - 1 pos = self.pos if self.pos >= 0 else start if start > pos: self.sval += self.sref[pos:start] self.sval += t.value self.pos = end def skip(self, t): self.pos = last_match(t) if self.pos >= 0 else -1 def append(self, s): self.sval += s self.pos = -1 class TensorFetcher(object): def __init__(self, var_name): self.var_name = var_name self.tvar_name = '{}_tensors'.format(self.var_name) self.tensors = [] self.writeable = [] def add(self, name, writeable): if writeable: self.writeable.append(len(self.tensors)) self.tensors.append(name) return '{}[{}]'.format(self.var_name, len(self.tensors) - 1) def generate_fetches(self): code = '' code += ' std::vector {} = {{{}}};\n'.format( self.tvar_name, ', '.join(self.tensors)) code += (' auto {} = bridge::MLIRCreateTensorList({});\n').format( self.var_name, self.tvar_name) return code def generate_updates(self): assert (0) code = '' if self.writeable: ivar_name = '{}_update_indices'.format(self.var_name) code += ' std::vector {} = {{{}}};\n'.format( ivar_name, ', '.join(str(x) for x in self.writeable)) code += ' bridge::XlaUpdateTensors({}, {}, {});\n'.format( self.tvar_name, self.var_name, ivar_name) return code def list_get(l, n): return l[n] if n < len(l) else None def is_blacklisted_fn(fname, mapsig): if fname in _FN_BLACKLIST or mapsig in _FN_BLACKLIST: return True for frx in _FN_BLACKLIST_REGEX: if re.match(frx, fname) or re.match(frx, mapsig): return True return False def get_outfn_options(fname, mapsig): for name in [fname, mapsig]: fnopts = _FN_OUT.get(name, None) if fnopts is not None: return fnopts for frx, fnopts in _FN_OUT_REGEX: if re.match(frx, fname) or re.match(frx, mapsig): return fnopts def get_remapfn_options(fname, mapsig): for name in [fname, mapsig]: fnopts = _FN_REMAP.get(name, None) if fnopts is not None: return fnopts def is_write_param(fnopts, pname, defval): if fnopts and fnopts.wparams: if pname in fnopts.wparams: return True return defval def first_match(t): 
def list_get(l, n):
  return l[n] if n < len(l) else None


def is_blacklisted_fn(fname, mapsig):
  if fname in _FN_BLACKLIST or mapsig in _FN_BLACKLIST:
    return True
  for frx in _FN_BLACKLIST_REGEX:
    if re.match(frx, fname) or re.match(frx, mapsig):
      return True
  return False


def get_outfn_options(fname, mapsig):
  for name in [fname, mapsig]:
    fnopts = _FN_OUT.get(name, None)
    if fnopts is not None:
      return fnopts
  for frx, fnopts in _FN_OUT_REGEX:
    if re.match(frx, fname) or re.match(frx, mapsig):
      return fnopts


def get_remapfn_options(fname, mapsig):
  for name in [fname, mapsig]:
    fnopts = _FN_REMAP.get(name, None)
    if fnopts is not None:
      return fnopts


def is_write_param(fnopts, pname, defval):
  if fnopts and fnopts.wparams:
    if pname in fnopts.wparams:
      return True
  return defval


def first_match(t):
  if isinstance(t, lark.lexer.Token):
    return t.column - 1
  assert isinstance(t, lark.tree.Tree)
  return first_match(t.children[0])


def last_match(t):
  if isinstance(t, lark.lexer.Token):
    return t.end_column - 1
  assert isinstance(t, lark.tree.Tree)
  return last_match(t.children[-1])


def for_every_token(t, fn):
  if isinstance(t, lark.lexer.Token):
    fn(t)
  else:
    assert isinstance(t, lark.tree.Tree)
    for c in t.children:
      for_every_token(c, fn)


def emit_string(t, emit, emit_fn):
  status = emit_fn(t)
  if status > 0:

    def do_emit(tok):
      emit.advance(tok)

    for_every_token(t, do_emit)
  elif status == 0:
    if isinstance(t, lark.lexer.Token):
      emit.advance(t)
    else:
      assert isinstance(t, lark.tree.Tree)
      for c in t.children:
        emit_string(c, emit, emit_fn)
  else:
    emit.skip(t)


def typed_child(t, n, ttype):
  assert isinstance(t, lark.tree.Tree)
  assert n < len(t.children)
  c = t.children[n]
  assert isinstance(c, lark.tree.Tree)
  assert c.data == ttype, t.pretty()
  return c


def rewrite_sig(tree, orig_sig, emit_fn=lambda x: 0):
  emit = StringEmit(orig_sig)
  emit_string(tree, emit, emit_fn)
  return str(emit)


def rewrite_signature(sig, tmap):

  def rewrite(t):
    if t.type == 'TNAME':
      new_type = tmap.get(t.value, None)
      if new_type is not None:
        t.value = new_type

  def emit_fn(t):
    if isinstance(t, lark.lexer.Token):
      return 0
    return -1 if t.data == 'param_defval' else 0

  xtree = _XPARSER.parse(sig)
  for_every_token(xtree, rewrite)
  return rewrite_sig(xtree, sig, emit_fn=emit_fn)


def create_stdfunc_sig(tree, orig_sig):

  def emit_fn(t):
    if isinstance(t, lark.lexer.Token):
      return 0
    return -1 if t.data == 'param_name' else 0

  emit = StringEmit(orig_sig)
  # Emit the full function return type.
  emit_string(typed_child(tree, 0, 'type'), emit, emit_fn)
  emit.append('(')
  # Emit the parameter list without parameter names.
  emit_string(typed_child(tree, 3, 'params'), emit, emit_fn)
  emit.append(')')
  return str(emit)


def create_map_sig(tree, orig_sig):

  def emit_fn(t):
    if isinstance(t, lark.lexer.Token):
      return -1 if t.type in ['CONST', 'REF', 'PTR'] else 0
    return -1 if t.data in ['param_name', 'param_defval'] else 0

  emit = StringEmit(orig_sig)
  # Emit the function name.
  emit_string(typed_child(tree, 1, 'fnname'), emit, emit_fn)
  emit.append('(')
  # Emit the parameter list without parameter names or qualifiers.
  emit_string(typed_child(tree, 3, 'params'), emit, emit_fn)
  emit.append(') -> ')
  emit_string(typed_child(tree, 0, 'type'), emit, emit_fn)
  return str(emit)
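# Illustrative only: for a hypothetical declaration
#   Tensor add(const Tensor & self, Scalar other)
# create_stdfunc_sig produces roughly 'Tensor(const Tensor &, Scalar)' (the
# std::function-style type used at registration time), while create_map_sig
# produces 'add(Tensor, Scalar) -> Tensor', the normalized key format used by
# tables such as _FN_REMAP and _FUNCTION_OPTIONS.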
def type_core(t):
  assert isinstance(t, lark.tree.Tree)
  for c in t.children:
    if isinstance(c, lark.tree.Tree) and c.data == 'core_type':
      c = c.children[0]
      if isinstance(c, lark.lexer.Token):
        return c.value
      assert isinstance(c, lark.tree.Tree) and c.data == 'template'
      return c.children[0].value
  raise RuntimeError('Not a type tree: {}'.format(t))


def type_is_const(t):
  assert isinstance(t, lark.tree.Tree)
  c = t.children[0]
  return isinstance(c, lark.lexer.Token) and c.value == 'const'


def type_is_refptr(t, kind):
  assert isinstance(t, lark.tree.Tree)
  c = t.children[-1]
  if not isinstance(c, lark.tree.Tree) or c.data != 'refspec':
    return False
  c = c.children[0]
  return isinstance(c, lark.lexer.Token) and c.value == kind


def extract_list(t, l):
  assert isinstance(t, lark.tree.Tree)
  l.append(t.children[0])
  if len(t.children) == 2:
    c = t.children[1]
    if isinstance(c, lark.tree.Tree) and c.data == t.data:
      extract_list(c, l)
  return l


def tuple_type_list(t):
  assert isinstance(t, lark.tree.Tree)
  c = t.children[0]
  assert isinstance(c, lark.tree.Tree) and c.data == 'core_type'
  c = c.children[0]
  assert isinstance(c, lark.tree.Tree) and c.data == 'template'
  types = []
  return extract_list(c.children[1], types)


def get_function_name(t):
  assert isinstance(t, lark.tree.Tree)
  fname = t.children[1]
  assert isinstance(fname, lark.tree.Tree)
  assert fname.data == 'fnname'
  return fname.children[0].value


def get_function_signature(t, orig_sig, namefn):
  emit = StringEmit(orig_sig)
  # Emit the full function return type.
  emit_string(typed_child(t, 0, 'type'), emit, lambda t: 0)
  fnname = typed_child(t, 1, 'fnname').children[0]
  xfname = namefn(fnname.value)
  emit.append(' {}('.format(xfname))
  # Emit the parameter list, including parameter names.
  emit_string(typed_child(t, 3, 'params'), emit, lambda t: 0)
  emit.append(')')
  return str(emit), fnname.value, xfname


def get_parameters(t):
  assert isinstance(t, lark.tree.Tree)
  c = t.children[2]
  assert isinstance(c, lark.tree.Tree)
  assert c.data == 'params'
  params = []
  extract_list(c, params)
  return params


def get_rparameters(t):
  assert isinstance(t, lark.tree.Tree)
  params = []
  # The rtype/rparams grammar rules are not reachable from 'start' yet, so
  # there are no result parameters to extract; see the commented-out caller
  # in generate_td_functions.
  # c = t.children[3]
  # assert isinstance(c, lark.tree.Tree)
  # assert c.data == 'rparams'
  # extract_list(c, params)
  return params


def param_name(t):
  assert isinstance(t, lark.tree.Tree)
  c = t.children[1]
  assert isinstance(c, lark.tree.Tree)
  assert c.data == 'param_name'
  token = c.children[0]
  assert isinstance(token, lark.lexer.Token)
  return token.value


def param_type(t):
  assert isinstance(t, lark.tree.Tree)
  c = t.children[0]
  assert isinstance(c, lark.tree.Tree)
  return c


def get_optional(fnopts, name, defval=None):
  if fnopts is None or not hasattr(fnopts, name):
    return defval
  return getattr(fnopts, name, defval) or defval


def get_return_value(rtype, rname, param, var, ref_param, fnopts):
  crtype = type_core(rtype)
  if type_is_const(rtype) or type_is_refptr(rtype, '&'):
    # If the return type is a const or a reference, return the matching
    # parameter. In these cases we operated on the tensor data (the ATEN
    # one), but the returned references are the input parameters.
    assert param
    return param_name(param)
  elif crtype != 'Tensor':
    return rname
  else:
    # If instead the return type is a value Tensor, we create a new one by
    # wrapping the proper local variable which has been created by calling
    # into the CPU tensor implementation.
    return 'bridge::CreateMLIRTensor({}, bridge::GetMLIRDevice({}))'.format(
        rname, get_optional(fnopts, 'device_param', param_name(ref_param)))
def get_reference_param(params, fnopts=None):
  # The reference parameter is the Tensor object which we use to extract the
  # result Tensor device, if any.
  ref_param = None
  other = None
  for p in params:
    ptype = param_type(p)
    cptype = type_core(ptype)
    pname = param_name(p)
    if get_optional(fnopts, 'ref_param') == pname:
      return p
    if not other and (cptype == 'TensorOptions' or cptype == 'TensorList'):
      other = p
    if cptype != 'Tensor':
      continue
    if not ref_param and (pname == 'self' or type_is_const(ptype)):
      ref_param = p
    other = p
  return ref_param or other


def get_tuple_return(rtype, rtype_str, rname, params, param_vars, ref_param,
                     fnopts):
  types = tuple_type_list(rtype)
  retstr = '{}('.format(rtype_str)
  for i, ttype in enumerate(types):
    if i > 0:
      retstr += ', '
    tuple_var = 'std::get<{}>({})'.format(i, rname)
    retstr += get_return_value(ttype, tuple_var, list_get(params, i),
                               list_get(param_vars, i), ref_param, fnopts)
  return retstr + ')'


def get_return_type_str(t, orig_sig):
  assert isinstance(t, lark.tree.Tree)
  fname = t.children[1]
  assert isinstance(fname, lark.tree.Tree)
  assert fname.data == 'fnname'
  token = fname.children[0]
  assert isinstance(token, lark.lexer.Token)
  return orig_sig[0:token.column - 2]


def generate_entry_debug_code(t, fname, params, fname_ns='aten'):
  # Emits debug code for a given intercepted ATEN type function.
  code = ''
  if fname in _FN_NO_DEBUG_ENTRY_LIST:
    return code
  code += '  std::cout << "{}::{}" << std::endl;\n'.format(fname_ns, fname)
  # VLOG-style tracing could also be emitted here. Use the following to see
  # the debug output:
  # export TF_CPP_VMODULE=aten_mlir_type_default=3
  #code += '  TF_VLOG(3) << "XLA {} :"'.format(fname)
  #for p in params:
  #  ptype = param_type(p)
  #  cptype = type_core(ptype)
  #  pname = param_name(p)
  #  if cptype == 'Tensor':
  #    code += ' << " {}=" << {}.toString()'.format(pname, pname)
  #code += ';\n'
  return code


def generate_exit_debug_code(t, fname, rname, params, param_vars):
  code = ''
  return code


def generate_return_stmt(t, rtype_str, fname, rname, params, param_vars,
                         ref_param, fnopts):
  assert isinstance(t, lark.tree.Tree)
  rtype = t.children[0]
  ctype = type_core(rtype)
  if ctype == 'std::tuple':
    retstr = get_tuple_return(rtype, rtype_str, rname, params, param_vars,
                              ref_param, fnopts)
  elif ctype == 'std::vector':
    #retstr = 'bridge::CreateXlaTensors({}, bridge::GetXlaDevice({}))'.format(
    #    rname, get_optional(fnopts, 'device_param', param_name(ref_param)))
    retstr = rname
  elif ctype == 'Tensor':
    retstr = get_return_value(rtype, rname, params[0], param_vars[0],
                              ref_param, fnopts)
  elif ctype == 'void' and not type_is_refptr(rtype, '*'):
    return ''
  else:
    retstr = rname
  return '  return {};\n'.format(retstr)


def generate_result_assignment(t, rname):
  assert isinstance(t, lark.tree.Tree)
  rtype = t.children[0]
  ctype = type_core(rtype)
  if ctype == 'void' and not type_is_refptr(rtype, '*'):
    return ''
  return 'auto&& {} = '.format(rname)


def get_handling_function(ctx, fname, the_ref_param, param_vars):
  function = _torch_mlir_FUNCTIONS.get(fname, None) or ctx.get_function(fname)
  if function:
    code = '{}({})'.format(function, ', '.join(param_vars))
  else:
    other_params = list(param_vars)
    other_params.remove(the_ref_param)
    code = '{}.{}({})'.format(the_ref_param, fname, ', '.join(other_params))
  return code


def rewrite_tensor_options(fname, pname):
  rw = _CTOR_FUNCTIONS.get(fname, None)
  if rw is None:
    return '', pname
  xname = 'o_{}'.format(pname)
  code = '  at::TensorOptions {} = {}{};\n'.format(xname, pname, rw)
  return code, xname


def get_param_names(params):
  param_vars = []
  for p in params:
    pname = param_name(p)
    param_vars.append(pname)
  return param_vars


def expand_fn_template(tmpl, param_vars):
  mdict = {}
  for i, pname in enumerate(param_vars):
    mdict[str(i)] = pname
  return tmpl.substitute(mdict)
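# Illustrative only: expand_fn_template maps positional $N placeholders to
# parameter names. With the (currently disabled) _FN_OUT entry for arange_out,
#   expand_fn_template(
#       ArgTemplate('ATenMLIRType::arange($1, $2, $3, $0.options())'),
#       ['out', 'start', 'end', 'step'])
# would yield 'ATenMLIRType::arange(start, end, step, out.options())'.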
def create_call(fname, param_vars):
  return '{}({})'.format(fname, ', '.join(param_vars))


def generate_shape_checks(param_vars, shape_check_indices, fname):
  code = ''
  # Shape checks are currently disabled for the MLIR bridge.
  #for i, j in shape_check_indices:
  #  code += ('  XLA_CHECK({}.sizes() == {}.sizes()) << "Operand shapes must '
  #           'be identical for {}, mismatch for arguments {} and {}";\n'
  #          ).format(param_vars[i], param_vars[j], fname, i + 1, j + 1)
  return code


def generate_aten_remap(ctx, fname, sig, params, fnopts):
  code = '{} {{\n'.format(sig)
  param_vars = get_param_names(params)
  if fnopts.outfn_template is not None:
    fcall = expand_fn_template(fnopts.outfn_template, param_vars)
  else:
    assert fnopts.outfn_name
    fcall = create_call(fnopts.outfn_name, param_vars)
  if fnopts.shape_check_indices is not None:
    code += generate_shape_checks(param_vars, fnopts.shape_check_indices,
                                  fname)
  code += '  return {};\n'.format(fcall)
  code += '}'
  return code


def generate_outfn_result_copy(dest, src):
  return '  {}.unsafeGetTensorImpl()->shallow_copy_from({}.getIntrusivePtr());\n'.format(
      dest, src)


def generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params, fnopts):
  rtype = tree.children[0]
  num_outputs = None
  if type_core(rtype) == 'std::tuple':
    num_outputs = len(tuple_type_list(rtype))

  code = '{} {{\n'.format(sig)
  code += generate_entry_debug_code(tree, fname, params)
  param_vars = get_param_names(params)
  if fnopts.outfn_template is not None:
    fcall = expand_fn_template(fnopts.outfn_template, param_vars)
  else:
    m = re.match(r'(.*)_out$', fname)
    assert m is not None, fname
    out_count = num_outputs if num_outputs is not None else 1
    fcall = create_call('ATenMLIRType::{}'.format(m.group(1)),
                        param_vars[out_count:])
  tmp_result = '{}_tmp'.format(fname)
  code += '  auto {} = {};\n'.format(tmp_result, fcall)
  if num_outputs is None:
    code += generate_outfn_result_copy(param_vars[0], tmp_result)
    code += generate_exit_debug_code(tree, fname, param_vars[0], params,
                                     param_vars)
    code += '  return {};\n'.format(param_vars[0])
  else:
    for i in range(0, num_outputs):
      code += generate_outfn_result_copy(
          param_vars[i], 'std::get<{}>({})'.format(i, tmp_result))
    code += generate_exit_debug_code(tree, fname, param_vars[0:num_outputs],
                                     params, param_vars)
    code += '  return {}('.format(get_return_type_str(rwxtree, rwsig))
    for i in range(0, num_outputs):
      if i > 0:
        code += ', '
      code += param_vars[i]
    code += ');\n'
  code += '}'
  return code


def generate_aten_to_mlir(ctx, tree, rwxtree, fname, sig, rwsig, params,
                          fnopts):
  ref_param = get_reference_param(params, fnopts=fnopts)

  code = '{} {{\n'.format(sig)
  code += generate_entry_debug_code(tree, fname, params)
  the_ref_param = param_name(ref_param) if ref_param else None
  tfetcher = TensorFetcher('mlirtens')
  param_vars = []
  for p in params:
    ptype = param_type(p)
    cptype = type_core(ptype)
    pname = param_name(p)
    if cptype == 'TensorList':
      #xname = 'l_{}'.format(pname)
      #code += ('  auto {} = bridge::XlaCreateTensorList({});\n').format(
      #    xname, pname)
      xname = pname
      param_vars.append(xname)
    elif cptype == 'TensorOptions':
      gcode, xname = rewrite_tensor_options(fname, pname)
      code += gcode
      param_vars.append(xname)
    elif cptype != 'Tensor':
      param_vars.append(pname)
    elif type_is_const(ptype):
      xname = tfetcher.add(pname, is_write_param(fnopts, pname, False))
      param_vars.append(xname)
    else:
      xname = tfetcher.add(pname, is_write_param(fnopts, pname, True))
      param_vars.append(xname)
    if p == ref_param and not get_optional(fnopts, 'ref_param'):
      the_ref_param = param_vars[-1]
  code += tfetcher.generate_fetches()
  result_assign = generate_result_assignment(tree, _RESULT_NAME)
  code += '  {}{};\n'.format(
      result_assign,
      get_handling_function(ctx, fname, the_ref_param, param_vars))
  #code += tfetcher.generate_updates()
  if result_assign:
    code += ('  static_cast<void>({}); // Avoid warnings in case not '
             'used\n'.format(_RESULT_NAME))
  code += generate_exit_debug_code(tree, fname,
                                   _RESULT_NAME if result_assign else None,
                                   params, param_vars)
  code += generate_return_stmt(tree, get_return_type_str(rwxtree, rwsig),
                               fname, _RESULT_NAME if result_assign else None,
                               params, param_vars, ref_param, fnopts)
  code += '}'
  return code
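# Illustrative only: assuming 'relu' appears in Functions.h, a hypothetical
# declaration
#   Tensor relu(const Tensor & self)
# would make generate_aten_to_mlir emit roughly:
#
#   at::Tensor ATenMLIRTypeDefault::relu(const at::Tensor & self) {
#     std::cout << "aten::relu" << std::endl;
#     std::vector<at::Tensor> mlirtens_tensors = {self};
#     auto mlirtens = bridge::MLIRCreateTensorList(mlirtens_tensors);
#     auto&& x_result = at::relu(mlirtens[0]);
#     static_cast<void>(x_result); // Avoid warnings in case not used
#     return bridge::CreateMLIRTensor(x_result, bridge::GetMLIRDevice(self));
#   }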
def get_mlir_wrapper(fndef, ctx):
  tree = _PARSER.parse(fndef.cpp_sig)
  xtree = _XPARSER.parse(fndef.cpp_sig)
  mapsig = create_map_sig(xtree, fndef.cpp_sig)
  rwsig = rewrite_signature(fndef.cpp_sig, _TYPE_NSMAP)
  rwxtree = _XPARSER.parse(rwsig)
  params = get_parameters(tree)
  fnopts = _FUNCTION_OPTIONS.get(mapsig, None)

  def gen_fnname(x):
    return 'ATenMLIRTypeDefault::{}'.format(x)

  sig, fname, xfname = get_function_signature(rwxtree, rwsig, gen_fnname)
  if not is_blacklisted_fn(fname, mapsig):
    ofnopts = get_outfn_options(fname, mapsig)
    rfnopts = get_remapfn_options(fname, mapsig)
    if ofnopts is not None:
      #print('gen_aten_out:', fname)
      code = generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params,
                               ofnopts)
    elif rfnopts is not None:
      #print('gen_aten_remap:', fname)
      code = generate_aten_remap(ctx, fname, sig, params, rfnopts)
    else:
      code = generate_aten_to_mlir(ctx, tree, rwxtree, fname, sig, rwsig,
                                   params, fnopts)
  else:
    code = None
  return FuncGen(tree=tree,
                 xtree=xtree,
                 rwxtree=rwxtree,
                 func=fname,
                 xfunc=xfname,
                 code=code,
                 sig=fndef.cpp_sig,
                 rwsig=rwsig,
                 cppsig=sig,
                 mapsig=mapsig,
                 funsig=create_stdfunc_sig(rwxtree, rwsig),
                 aten_sig=fndef.aten_sig)


def is_tensor_api(fndef):
  fndef = fndef.replace('at::', '')
  fndef = fndef.replace('c10::Device', 'Device')
  m = re.search(r'\bTensor\b', fndef)
  return m is not None, fndef


def extract_functions(path):
  functions = []
  errors = []
  for line in open(path, 'r'):
    m = re.match(r'\s*([^\s].*); //\s+(.*)', line)
    if not m:
      continue
    fndef = m.group(1)
    try:
      _XPARSER.parse(fndef)
      functions.append(FuncDef(cpp_sig=fndef, aten_sig=m.group(2)))
    except Exception as e:
      if is_tensor_api(fndef)[0]:
        errors.append((fndef, str(e)))
        print('Error parsing "{}": {}'.format(fndef, e), file=sys.stderr)
  return functions, errors


def get_mapsig_key(mapsig):
  # PyTorch generates std::tuple<> without spaces among the tuple types,
  # which would require special understanding in the string rewriter. Since
  # we only use this as a simple key, we can just strip the spaces.
  return mapsig.replace(' ', '')


def parse_local_overrides(path):
  functions = []
  fndef = None
  for line in open(path, 'r'):
    line = line.strip()
    if not fndef:
      m = re.match(r'static\s+(.*);', line)
      if m:
        functions.append(m.group(1))
        continue
      m = re.match(r'static\s+(.*)', line)
      if m:
        fndef = m.group(1)
    else:
      fndef = '{} {}'.format(fndef, line)
      if fndef.endswith(';'):
        functions.append(fndef[:-1])
        fndef = None
  assert fndef is None

  overrides = {}
  for fndef in functions:
    # Discard the static type functions which are not ATEN-related.
    is_tensor, fndef = is_tensor_api(fndef)
    if is_tensor:
      xtree = _XPARSER.parse(fndef)
      mapsig_key = get_mapsig_key(create_map_sig(xtree, fndef))
      overrides[mapsig_key] = fndef
  return overrides


def get_dialect_name(func):
  name = ''
  upper = True
  cs = list(func)
  for c in cs:
    if c == '_':
      upper = True
    elif upper:
      name += str(c).upper()
      upper = False
    else:
      name += c
  if cs[-1] == '_':
    name += 'Under'
  return name
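# Illustrative only: get_dialect_name camel-cases an ATen function name for
# use in the TableGen op definitions below, e.g. 'native_batch_norm' ->
# 'NativeBatchNorm'. A trailing underscore (in-place ops) is spelled out:
# 'add_' -> 'AddUnder'.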
def generate_td_functions(fgens, overrides):
  code = ''
  overridden = set()
  code += '#ifdef ATEN_OP_DEFS\n'
  code += '#else\n'
  code += '#define ATEN_OP_DEFS\n\n'
  for fgen in fgens:
    mapsig_key = get_mapsig_key(fgen.mapsig)
    if mapsig_key in overrides:
      overridden.add(mapsig_key)
    if fgen.func in _TD_BLACKLIST:
      continue
    rtype = fgen.tree.children[0]
    num_outputs = 1
    if type_core(rtype) == 'std::tuple':
      num_outputs = len(tuple_type_list(rtype))
    dialect_name = get_dialect_name(fgen.func)
    code += 'def aten_{}Op: aten_Op<"{}"'.format(dialect_name, fgen.func)
    code += ', [NoSideEffect'
    if not fgen.func in _TD_NO_OPSTATS_LIST:
      code += ', StatisticsOpInterface'
    code += ']>,\n'
    code += '    Results<(outs'
    # One result per output.
    # rparams = get_rparameters(fgen.tree)
    # for p in rparams:
    #   pname = param_name(p)
    #   ptype = param_type(p)
    #   cptype = type_core(ptype)
    code += ' AnyTensor'
    for i in range(num_outputs - 1):
      code += ', AnyTensor'
    code += ')> {\n'
    code += '  let arguments = (\n'
    params = get_parameters(fgen.tree)
    for p in params:
      pname = param_name(p)
      ptype = param_type(p)
      cptype = type_core(ptype)
      if cptype == 'Tensor':
        td_type = 'AnyTensor'
      elif (cptype == 'Scalar' or cptype == 'int64_t' or cptype == 'double' or
            cptype == 'bool'):
        td_type = 'AnyScalar'
      elif cptype == 'c10::optional' or cptype == 'std::array':
        continue
      elif cptype == 'IntArrayRef':
        td_type = 'AnyType'
      else:
        print('unhandled type', cptype)
        td_type = 'AnyType'
      if p == params[0]:
        code += '    ins {}:${}'.format(td_type, pname)
      else:
        code += ',\n    {}:${}'.format(td_type, pname)
    code += '\n  );\n'
    code += '  let summary = "aten {} operator";\n'.format(fgen.func)
    code += '  let description = [{\n'
    code += '    {}Op\n'.format(dialect_name)
    code += '    aten {} operator\n'.format(fgen.func)
    code += '  }];\n'
    if not fgen.func in _TD_NO_OPSTATS_LIST:
      code += '  let extraClassDeclaration = [{\n'
      code += '    std::map<std::string, uint64_t> getStatistics();\n'
      code += '  }];\n'
    code += '}\n\n'
  code += '#endif\n'
  return code, overridden


def generate_registrations(fgens, overrides):
  code = 'void RegisterAtenTypeFunctions() {\n'
  code += '  static auto dispatch = torch::RegisterOperators()\n'
  overridden = set()
  for fgen in fgens:
    mapsig_key = get_mapsig_key(fgen.mapsig)
    if mapsig_key in overrides:
      override_fn = 'ATenMLIRType::{}'.format(fgen.func)
      overridden.add(mapsig_key)
    else:
      override_fn = fgen.xfunc if fgen.code else None
    if override_fn:
      code += (
          '  .op(torch::RegisterOperators::options().schema("{}")\n      '
          '.impl_unboxedOnlyKernel<{}, &{}>(at::TensorTypeId::XLATensorId)\n'
          '      .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))\n'
      ).format(fgen.aten_sig, fgen.funsig, override_fn)
  return code + ';\n}\n', overridden
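# Illustrative only: for a hypothetical 'relu' wrapper, the loop above appends
# a chained registration of roughly this shape (schema from fgen.aten_sig,
# template arguments from fgen.funsig and the override function name):
#
#   .op(torch::RegisterOperators::options().schema("aten::relu(Tensor self) -> Tensor")
#       .impl_unboxedOnlyKernel<at::Tensor(const at::Tensor &),
#                               &ATenMLIRTypeDefault::relu>(at::TensorTypeId::XLATensorId)
#       .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))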
def generate_functions(fgens):
  code = ''
  for fgen in fgens:
    if fgen.code:
      code += '{}\n\n'.format(fgen.code)
  return code


def generate_class_functions(fgens):
  code = ''
  for fgen in fgens:
    if fgen.code:
      code += '  static {};\n'.format(fgen.rwsig)
  return code


def gen_output_file(args, name):
  if not args.output_folder:
    return sys.stdout
  return open(os.path.join(args.output_folder, name), 'w')


def gen_h_output_file(args):
  return gen_output_file(args, 'aten_mlir_type_default.h')


def gen_cpp_output_file(args):
  return gen_output_file(args, 'aten_mlir_type_default.cpp')


def gen_td_output_file(args):
  return gen_output_file(args, 'ATenOps.td')


def check_overrides(available_fgens, overrides, overridden):
  misses = 0
  for mapsig, cpp_sig in overrides.items():
    mapsig_key = get_mapsig_key(mapsig)
    if not mapsig_key in overridden:
      misses += 1
      print('ERROR: ATenMLIRType function missed override:\n'
            '  CPPSIG: {}\n'
            '  MAPSIG: {}\n'
            '  KEY   : {}\n'.format(cpp_sig, mapsig, mapsig_key),
            file=sys.stderr)
  if misses != 0:
    print('Some required overrides were missing (see above).')
    print('Available overrides:')
    for fgen in available_fgens:
      print('  ', get_mapsig_key(fgen.mapsig))
  return misses == 0


def generate(args):
  fndefs, errors = extract_functions(args.typedef)
  print('Extracted {} functions ({} errors) from {}'.format(
      len(fndefs), len(errors), args.typedef),
        file=sys.stderr)
  assert len(errors) == 0

  local_overrides = parse_local_overrides(args.overridetype)
  print('{} function overrides in {}'.format(len(local_overrides),
                                             args.overridetype),
        file=sys.stderr)

  fgens = []
  ctx = Context(args.functions)
  for ts in fndefs:
    try:
      fgen = get_mlir_wrapper(ts, ctx)
      if fgen:
        fgens.append(fgen)
    except Exception as e:
      print('Failed to generate wrapper for {}: {}'.format(ts, e),
            file=sys.stderr)
  print('Generated {} wrappers for {}'.format(len(fgens), args.typedef),
        file=sys.stderr)

  functions = generate_functions(fgens)
  hfunctions = generate_class_functions(fgens)
  tdfunctions, overridden = generate_td_functions(fgens, local_overrides)
  assert check_overrides(fgens, local_overrides, overridden), (
      'Missing overrides when generating td functions')

  regs, overridden = generate_registrations(fgens, local_overrides)
  assert check_overrides(fgens, local_overrides, overridden), (
      'Missing local overrides when generating registrations')

  # Create the output files.
  print(_H_HEADER.format(gen=os.path.basename(sys.argv[0]),
                         hfuncs=hfunctions),
        file=gen_h_output_file(args))
  print(_CPP_HEADER.format(gen=os.path.basename(sys.argv[0]),
                           funcs=functions,
                           regs=regs),
        file=gen_cpp_output_file(args))
  with gen_td_output_file(args) as f:
    f.write(tdfunctions)


if __name__ == '__main__':
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument('--output_folder', type=str)
  arg_parser.add_argument('overridetype',
                          type=str,
                          metavar='OVERRIDE_TYPE_FILE',
                          help='The path to the overrides file')
  arg_parser.add_argument('typedef',
                          type=str,
                          metavar='TYPE_DEFAULT_FILE',
                          help='The path to the RegistrationDeclarations.h '
                          'file')
  arg_parser.add_argument('functions',
                          type=str,
                          metavar='FUNCTIONS_FILE',
                          help='The path to the Functions.h file')
  args, files = arg_parser.parse_known_args()
  generate(args)