# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

# Structured similarly to code from git@github.com:pytorch/xla.git

from __future__ import print_function

import argparse
import collections
import collections.abc
import lark
import os
import re
import string
import sys

####
# This file parses the C++ signatures exported by pytorch and generates
# appropriate MLIR operations in a tablegen file. It also generates some of
# the more boilerplate parts of the pytorch integration. This may need to be
# re-run if pytorch versions change. Primarily this reads information from
# pytorch through RegistrationDeclarations.h and Functions.h. It also allows
# some local overrides (specified in aten_mlir_type.h).
# It generates aten_mlir_type_default.{cpp,h} and ATenOps.td, which will need
# to be moved to their appropriate places.

# To run:
# python3 gen_aten_dialect.py --output_folder=. \
#   ../csrc/aten_mlir_type.h \
#   ${TORCH_INSTALL_PREFIX}/include/ATen/RegistrationDeclarations.h \
#   ${TORCH_INSTALL_PREFIX}/include/ATen/Functions.h


def namedtuple_with_defaults(typename, field_names, default_values=()):
  ntuple = collections.namedtuple(typename, field_names)
  ntuple.__new__.__defaults__ = (None,) * len(ntuple._fields)
  # collections.Mapping was removed in Python 3.10; collections.abc.Mapping
  # is the correct spelling.
  if isinstance(default_values, collections.abc.Mapping):
    prototype = ntuple(**default_values)
  else:
    prototype = ntuple(*default_values)
  ntuple.__new__.__defaults__ = tuple(prototype)
  return ntuple
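
# Illustrative use of namedtuple_with_defaults (this example is not part of
# the generator): any field without a supplied default becomes None.
#   Pair = namedtuple_with_defaults('Pair', 'first, second')
#   Pair(first=1)  # -> Pair(first=1, second=None)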


class ArgTemplate(string.Template):
  idpattern = r'[a-z0-9_]+'
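
# string.Template's default idpattern rejects purely numeric placeholder
# names; the override above accepts them, so $0, $1, ... act as positional
# slots. Illustrative:
#   ArgTemplate('$0.options()').substitute({'0': 'self'})  # -> 'self.options()'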


FuncDef = namedtuple_with_defaults('FuncDef', 'cpp_sig, aten_sig')

FuncGen = namedtuple_with_defaults(
    'FuncGen',
    'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig, aten_sig'
)

FuncOpts = namedtuple_with_defaults(
    'FuncOpts',
    'ref_param, device_param, wparams, outfn_template, outfn_name, shape_check_indices'
)

_GRAMMAR = r"""
    start: type fnname "(" params ")"
    rtype: "(" rparams ")"
         | TNAME
    rparams: rparam
           | rparam "," rparams
    rparam: type param_name
    type: CONST? core_type refspec?
    fnname: CNAME
    refspec: REF
           | PTR
    core_type: template
             | TNAME
    template: TNAME "<" typelist ">"
    typelist: type
            | type "," typelist
    REF: "&"
    PTR: "*"
    CONST: "const"
    TNAME: /[a-zA-Z0-9_:]+/
    HEXNUMBER: /0x[0-9a-fA-F]+/
    params: param
          | param "," params
    param: type param_name param_defval?
    param_name: CNAME

    param_defval: "=" init_value
    init_value: "true"
              | "false"
              | "{}"
              | NUMBER
              | SIGNED_NUMBER
              | HEXNUMBER
              | ESCAPED_STRING

    %import common.CNAME -> CNAME
    %import common.NUMBER -> NUMBER
    %import common.SIGNED_NUMBER -> SIGNED_NUMBER
    %import common.ESCAPED_STRING -> ESCAPED_STRING
    %import common.WS
    %ignore WS
"""

_PARSER = lark.Lark(_GRAMMAR, parser='lalr', propagate_positions=True)

_XPARSER = lark.Lark(_GRAMMAR,
                     parser='lalr',
                     propagate_positions=True,
                     keep_all_tokens=True)
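
# Both parsers share the grammar above; _XPARSER additionally keeps literal
# tokens such as "(" and ",", which the string rewriters below rely on to
# reconstruct signatures faithfully. Illustrative input, in the shape
# RegistrationDeclarations.h exports:
#   _PARSER.parse('Tensor add(const Tensor & self, Scalar other)')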


_TD_BLACKLIST = set([
    'clone',
    'to',
    'copy_',
    'copy',
    'copy_from',
    '_copy_from',
    '_unsafe_view',
])

_TD_NO_OPSTATS_LIST = set([
    '_log_softmax',
    '_log_softmax_backward_data',
])

_FN_BLACKLIST = set([
    'numel',
    'ones',
    'ones_like',
    'result_type',
    # 'zero_',
    'zeros',
    'zeros_like',
])

_FN_NO_DEBUG_ENTRY_LIST = set([
    'empty',
    'fill_',
    'zero_',
])

_FN_BLACKLIST_REGEX = [
    # ATEN functions
    r'[^(]*cudnn',
    # XLA/TPU functions
]

_FN_OUT = {
    'add_out':
        FuncOpts(),
    'arange_out(Tensor, Scalar, Scalar, Scalar) -> Tensor':
        FuncOpts(outfn_template=ArgTemplate(
            'ATenMLIRType::arange($1, $2, $3, $0.options())')),
    'bitwise_not_out':
        FuncOpts(),
    'clamp_out':
        FuncOpts(),
    'div_out':
        FuncOpts(),
    'gather_out':
        FuncOpts(),
    'kthvalue_out':
        FuncOpts(),
    'index_select_out':
        FuncOpts(),
    'log_out':
        FuncOpts(),
    'topk_out':
        FuncOpts(),
}
# NOTE: the reassignment below clears the table above, so the special
# out-of-place handling it configures is currently disabled.
_FN_OUT = {}

# List of tuples with the regex match first, and the corresponding FuncOpts()
# second.
_FN_OUT_REGEX = []

_FN_REMAP = {
    '_th_eq(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::eq'),
    '_th_eq(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::eq'),
    '_th_ge(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::ge'),
    '_th_ge(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::ge'),
    '_th_gt(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::gt'),
    '_th_gt(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::gt'),
    '_th_le(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::le'),
    '_th_le(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::le'),
    '_th_lt(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::lt'),
    '_th_lt(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::lt'),
    '_th_ne(Tensor, Scalar) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::ne'),
    '_th_ne(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::ne'),
    's__th_and(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::__and__',
                 shape_check_indices=((0, 1),)),
    's__th_or(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::__or__',
                 shape_check_indices=((0, 1),)),
    's__th_xor(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::__xor__',
                 shape_check_indices=((0, 1),)),
    # '_s_where(Tensor, Tensor, Tensor) -> Tensor':
    #     FuncOpts(
    #         outfn_name='ATenMLIRType::where',
    #         shape_check_indices=(
    #             (0, 1),
    #             (0, 2),
    #         )),
    's__th_eq(Tensor, Tensor) -> Tensor':
        FuncOpts(outfn_name='ATenMLIRType::eq', shape_check_indices=((0, 1),)),
}

_TYPE_NSMAP = {
    'Tensor': 'at::Tensor',
    'TensorList': 'at::TensorList',
    'Scalar': 'at::Scalar',
    'Storage': 'at::Storage',
    'IntList': 'at::IntList',
    'IntArrayRef': 'at::IntArrayRef',
    'Generator': 'at::Generator',
    'ScalarType': 'at::ScalarType',
    'TensorOptions': 'at::TensorOptions',
    'SparseTensorRef': 'at::SparseTensorRef',
    'Device': 'c10::Device',
    'optional': 'c10::optional',
    'MemoryFormat': 'at::MemoryFormat',
    'QScheme': 'at::QScheme',
    'ConstQuantizerPtr': 'at::ConstQuantizerPtr',
    'Dimname': 'at::Dimname',  # namedtensor-only
    'DimnameList': 'at::DimnameList',  # namedtensor-only
}

_H_HEADER = """// Autogenerated file by {gen}. Do not edit directly!

#include <ATen/Tensor.h>

namespace torch_mlir {{

class ATenMLIRTypeDefault {{
 public:
{hfuncs}
}};

void RegisterAtenTypeFunctions();

}} // namespace torch_mlir
"""

_CPP_HEADER = """// Autogenerated file by {gen}. Do not edit directly!
#include "aten_mlir_type_default.h"

#include <ATen/Context.h>
#include <ATen/Functions.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/CPUGenerator.h>

#include "aten_mlir_bridge.h"
#include "aten_mlir_type.h"

namespace torch_mlir {{

{funcs}

{regs}
}} // namespace torch_mlir
"""

_torch_mlir_FUNCTIONS = {}

_CTOR_FUNCTIONS = {
    'empty': '.device(at::DeviceType::CPU)',
    'linspace': '.device(at::DeviceType::CPU)',
    'logspace': '.device(at::DeviceType::CPU)',
    'rand': '.device(at::DeviceType::CPU)',
    'rand_like': '.device(at::DeviceType::CPU)',
    'randn': '.device(at::DeviceType::CPU)',
    'randn_like': '.device(at::DeviceType::CPU)',
    'randint': '.device(at::DeviceType::CPU)',
    'randint_like': '.device(at::DeviceType::CPU)',
    'randperm': '.device(at::DeviceType::CPU)',
    'scalar_tensor': '.device(at::DeviceType::CPU)',
}

_FUNCTION_OPTIONS = {
    'slice(Tensor, int64_t, int64_t, int64_t, int64_t) -> Tensor':
        FuncOpts(wparams=['self']),
}

_RESULT_NAME = 'x_result'


class Context(object):

  def __init__(self, functions):
    with open(functions, 'r') as ff:
      self.functions_data = ff.read()

  def get_function(self, name):
    if self.functions_data.find(' {}('.format(name)) >= 0:
      return 'at::{}'.format(name)


class StringEmit(object):

  def __init__(self, sref):
    self.sref = sref
    self.sval = ''
    self.pos = -1

  def __repr__(self):
    return self.sval

  def advance(self, t):
    start = t.column - 1
    end = t.end_column - 1
    pos = self.pos if self.pos >= 0 else start
    if start > pos:
      self.sval += self.sref[pos:start]
    self.sval += t.value
    self.pos = end

  def skip(self, t):
    self.pos = last_match(t) if self.pos >= 0 else -1

  def append(self, s):
    self.sval += s
    self.pos = -1
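
# StringEmit (above) rebuilds a string by replaying parse-tree tokens against
# the original signature text: advance() copies a token (plus any source
# characters skipped over since the last emit), skip() moves past a subtree
# without emitting it, and append() splices in literal text. The emit_fn
# callbacks used with emit_string() below return >0 / 0 / <0 to emit a
# subtree verbatim, recurse into it, or drop it.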


class TensorFetcher(object):

  def __init__(self, var_name):
    self.var_name = var_name
    self.tvar_name = '{}_tensors'.format(self.var_name)
    self.tensors = []
    self.writeable = []

  def add(self, name, writeable):
    if writeable:
      self.writeable.append(len(self.tensors))
    self.tensors.append(name)
    return '{}[{}]'.format(self.var_name, len(self.tensors) - 1)

  def generate_fetches(self):
    code = ''
    code += ' std::vector<at::Tensor> {} = {{{}}};\n'.format(
        self.tvar_name, ', '.join(self.tensors))
    code += (' auto {} = bridge::MLIRCreateTensorList({});\n').format(
        self.var_name, self.tvar_name)
    return code

  def generate_updates(self):
    # Currently unused (the only call site below is commented out); the
    # assert guards against accidental use.
    assert (0)
    code = ''
    if self.writeable:
      ivar_name = '{}_update_indices'.format(self.var_name)
      code += ' std::vector<size_t> {} = {{{}}};\n'.format(
          ivar_name, ', '.join(str(x) for x in self.writeable))
      code += ' bridge::XlaUpdateTensors({}, {}, {});\n'.format(
          self.tvar_name, self.var_name, ivar_name)
    return code


def list_get(l, n):
  return l[n] if n < len(l) else None


def is_blacklisted_fn(fname, mapsig):
  if fname in _FN_BLACKLIST or mapsig in _FN_BLACKLIST:
    return True
  for frx in _FN_BLACKLIST_REGEX:
    if re.match(frx, fname) or re.match(frx, mapsig):
      return True
  return False


def get_outfn_options(fname, mapsig):
  for name in [fname, mapsig]:
    fnopts = _FN_OUT.get(name, None)
    if fnopts is not None:
      return fnopts
  for frx, fnopts in _FN_OUT_REGEX:
    if re.match(frx, fname) or re.match(frx, mapsig):
      return fnopts


def get_remapfn_options(fname, mapsig):
  for name in [fname, mapsig]:
    fnopts = _FN_REMAP.get(name, None)
    if fnopts is not None:
      return fnopts


def is_write_param(fnopts, pname, defval):
  if fnopts and fnopts.wparams:
    if pname in fnopts.wparams:
      return True
  return defval


def first_match(t):
  if isinstance(t, lark.lexer.Token):
    return t.column - 1
  assert isinstance(t, lark.tree.Tree)
  return first_match(t.children[0])


def last_match(t):
  if isinstance(t, lark.lexer.Token):
    return t.end_column - 1
  assert isinstance(t, lark.tree.Tree)
  return last_match(t.children[-1])


def for_every_token(t, fn):
  if isinstance(t, lark.lexer.Token):
    fn(t)
  else:
    assert isinstance(t, lark.tree.Tree)
    for c in t.children:
      for_every_token(c, fn)


def emit_string(t, emit, emit_fn):
  status = emit_fn(t)
  if status > 0:

    def do_emit(tok):
      emit.advance(tok)

    for_every_token(t, do_emit)
  elif status == 0:
    if isinstance(t, lark.lexer.Token):
      emit.advance(t)
    else:
      assert isinstance(t, lark.tree.Tree)
      for c in t.children:
        emit_string(c, emit, emit_fn)
  else:
    emit.skip(t)


def typed_child(t, n, ttype):
  assert isinstance(t, lark.tree.Tree)
  assert n < len(t.children)
  c = t.children[n]
  assert isinstance(c, lark.tree.Tree)
  assert c.data == ttype, t.pretty()
  return c


def rewrite_sig(tree, orig_sig, emit_fn=lambda x: 0):
  emit = StringEmit(orig_sig)
  emit_string(tree, emit, emit_fn)
  return str(emit)


def rewrite_signature(sig, tmap):

  def rewrite(t):
    if t.type == 'TNAME':
      new_type = tmap.get(t.value, None)
      if new_type is not None:
        t.value = new_type

  def emit_fn(t):
    if isinstance(t, lark.lexer.Token):
      return 0
    return -1 if t.data == 'param_defval' else 0

  xtree = _XPARSER.parse(sig)
  for_every_token(xtree, rewrite)
  return rewrite_sig(xtree, sig, emit_fn=emit_fn)
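
# Illustrative effect of rewrite_signature with _TYPE_NSMAP: type names gain
# their C++ namespaces while parameter names and structure are preserved
# (default values are dropped by the emit_fn above).
#   'Tensor add(const Tensor & self, Scalar other)'
#     -> 'at::Tensor add(const at::Tensor & self, at::Scalar other)'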


def create_stdfunc_sig(tree, orig_sig):

  def emit_fn(t):
    if isinstance(t, lark.lexer.Token):
      return 0
    return -1 if t.data == 'param_name' else 0

  emit = StringEmit(orig_sig)
  # Emit full function return type.
  emit_string(typed_child(tree, 0, 'type'), emit, emit_fn)
  emit.append('(')
  # Emit parameter list w/out parameter names.
  emit_string(typed_child(tree, 3, 'params'), emit, emit_fn)
  emit.append(')')
  return str(emit)


def create_map_sig(tree, orig_sig):

  def emit_fn(t):
    if isinstance(t, lark.lexer.Token):
      return -1 if t.type in ['CONST', 'REF', 'PTR'] else 0
    return -1 if t.data in ['param_name', 'param_defval'] else 0

  emit = StringEmit(orig_sig)
  # Emit the function name.
  emit_string(typed_child(tree, 1, 'fnname'), emit, emit_fn)
  emit.append('(')
  # Emit parameter list w/out parameter names, qualifiers, or defaults.
  emit_string(typed_child(tree, 3, 'params'), emit, emit_fn)
  emit.append(') -> ')
  emit_string(typed_child(tree, 0, 'type'), emit, emit_fn)
  return str(emit)
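
# A "map signature" is the normalized key format used by the override tables
# above (_FN_OUT, _FN_REMAP, _FUNCTION_OPTIONS). Illustrative:
#   'Tensor add(const Tensor & self, Scalar other)'
#     -> 'add(Tensor, Scalar) -> Tensor'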


def type_core(t):
  assert isinstance(t, lark.tree.Tree)
  for c in t.children:
    if isinstance(c, lark.tree.Tree) and c.data == 'core_type':
      c = c.children[0]
      if isinstance(c, lark.lexer.Token):
        return c.value
      assert isinstance(c, lark.tree.Tree) and c.data == 'template'
      return c.children[0].value
  raise RuntimeError('Not a type tree: {}'.format(t))


def type_is_const(t):
  assert isinstance(t, lark.tree.Tree)
  c = t.children[0]
  return isinstance(c, lark.lexer.Token) and c.value == 'const'


def type_is_refptr(t, kind):
  assert isinstance(t, lark.tree.Tree)
  c = t.children[-1]
  if not isinstance(c, lark.tree.Tree) or c.data != 'refspec':
    return False
  c = c.children[0]
  return isinstance(c, lark.lexer.Token) and c.value == kind


def extract_list(t, l):
  assert isinstance(t, lark.tree.Tree)
  l.append(t.children[0])
  if len(t.children) == 2:
    c = t.children[1]
    if isinstance(c, lark.tree.Tree) and c.data == t.data:
      extract_list(c, l)
  return l


def tuple_type_list(t):
  assert isinstance(t, lark.tree.Tree)
  c = t.children[0]
  assert isinstance(c, lark.tree.Tree) and c.data == 'core_type'
  c = c.children[0]
  assert isinstance(c, lark.tree.Tree) and c.data == 'template'
  types = []
  return extract_list(c.children[1], types)


def get_function_name(t):
  assert isinstance(t, lark.tree.Tree)
  fname = t.children[1]
  assert isinstance(fname, lark.tree.Tree)
  assert fname.data == 'fnname'
  return fname.children[0].value


def get_function_signature(t, orig_sig, namefn):
  emit = StringEmit(orig_sig)
  # Emit full function return type.
  emit_string(typed_child(t, 0, 'type'), emit, lambda t: 0)
  fnname = typed_child(t, 1, 'fnname').children[0]
  xfname = namefn(fnname.value)
  emit.append(' {}('.format(xfname))
  # Emit parameter list w/out parameter names.
  emit_string(typed_child(t, 3, 'params'), emit, lambda t: 0)
  emit.append(')')
  return str(emit), fnname.value, xfname


def get_parameters(t):
  assert isinstance(t, lark.tree.Tree)
  # Trees from _PARSER drop literal tokens such as "(", so 'params' is child
  # 2 here; _XPARSER trees keep them, which is why the signature emitters
  # above index child 3.
  c = t.children[2]
  assert isinstance(c, lark.tree.Tree)
  assert c.data == 'params'
  params = []
  extract_list(c, params)
  return params


def get_rparameters(t):
  # Stub: extraction of result ('rparams') parameters is not implemented yet;
  # the only call site (in generate_td_functions) is commented out.
  assert isinstance(t, lark.tree.Tree)
  params = []
  # print(len(t.children))
  # c = t.children[3]
  # assert isinstance(c, lark.tree.Tree)
  # assert c.data == 'rparams'

  # extract_list(c, params)
  return params


def param_name(t):
  assert isinstance(t, lark.tree.Tree)
  c = t.children[1]
  assert isinstance(c, lark.tree.Tree)
  assert c.data == 'param_name'
  token = c.children[0]
  assert isinstance(token, lark.lexer.Token)
  return token.value


def param_type(t):
  assert isinstance(t, lark.tree.Tree)
  c = t.children[0]
  assert isinstance(c, lark.tree.Tree)
  return c


def get_optional(fnopts, name, defval=None):
  if fnopts is None or not hasattr(fnopts, name):
    return defval
  return getattr(fnopts, name, defval) or defval


def get_return_value(rtype, rname, param, var, ref_param, fnopts):
  crtype = type_core(rtype)
  if type_is_const(rtype) or type_is_refptr(rtype, '&'):
    # If the return type is a const or a reference, return the matching
    # parameter. In these cases we operated on the underlying ATen tensor
    # data, but the returned references are the input parameters.
    assert param
    return param_name(param)
  elif crtype != 'Tensor':
    return rname
  else:
    # If instead the return type is a value Tensor, we create a new one by
    # wrapping the proper local variable which has been created by calling
    # into the CPU tensor implementation.
    return 'bridge::CreateMLIRTensor({}, bridge::GetMLIRDevice({}))'.format(
        rname, get_optional(fnopts, 'device_param', param_name(ref_param)))


def get_reference_param(params, fnopts=None):
  # The reference parameter is the Tensor object which we use to extract the
  # result Tensor device, if any.
  ref_param = None
  other = None
  for p in params:
    ptype = param_type(p)
    cptype = type_core(ptype)
    pname = param_name(p)
    if get_optional(fnopts, 'ref_param') == pname:
      return p
    if not other and (cptype == 'TensorOptions' or cptype == 'TensorList'):
      other = p
    if cptype != 'Tensor':
      continue
    if not ref_param and (pname == 'self' or type_is_const(ptype)):
      ref_param = p
    other = p
  return ref_param or other


def get_tuple_return(rtype, rtype_str, rname, params, param_vars, ref_param,
                     fnopts):
  types = tuple_type_list(rtype)
  retstr = '{}('.format(rtype_str)
  for i, ttype in enumerate(types):
    if i > 0:
      retstr += ', '
    tuple_var = 'std::get<{}>({})'.format(i, rname)
    retstr += get_return_value(ttype, tuple_var, list_get(params, i),
                               list_get(param_vars, i), ref_param, fnopts)
  return retstr + ')'


def get_return_type_str(t, orig_sig):
  assert isinstance(t, lark.tree.Tree)
  fname = t.children[1]
  assert isinstance(fname, lark.tree.Tree)
  assert fname.data == 'fnname'
  token = fname.children[0]
  assert isinstance(token, lark.lexer.Token)
  return orig_sig[0:token.column - 2]


def generate_entry_debug_code(t, fname, params, fname_ns='aten'):
  code = ''
  if fname in _FN_NO_DEBUG_ENTRY_LIST:
    return code
  code += ' std::cout << "{}::{}" << std::endl;\n'.format(fname_ns, fname)
  # Emits debug code for a given intercepted ATEN type function. For now we
  # just print the op name above; the commented-out block below would also
  # log the Tensor parameters as VLOG info. Use the following to see that
  # debug output:
  #   export TF_CPP_VMODULE=aten_mlir_type_default=3
  #code += ' TF_VLOG(3) << "XLA {} :"'.format(fname)
  #for p in params:
  #  ptype = param_type(p)
  #  cptype = type_core(ptype)
  #  pname = param_name(p)
  #  if cptype == 'Tensor':
  #    code += ' << " {}=" << {}.toString()'.format(pname, pname)
  #code += ';\n'
  return code


def generate_exit_debug_code(t, fname, rname, params, param_vars):
  # Exit-side debug logging is currently a no-op.
  code = ''
  return code


def generate_return_stmt(t, rtype_str, fname, rname, params, param_vars,
                         ref_param, fnopts):
  assert isinstance(t, lark.tree.Tree)
  rtype = t.children[0]
  ctype = type_core(rtype)
  if ctype == 'std::tuple':
    retstr = get_tuple_return(rtype, rtype_str, rname, params, param_vars,
                              ref_param, fnopts)
  elif ctype == 'std::vector':
    #retstr = 'bridge::CreateXlaTensors({}, bridge::GetXlaDevice({}))'.format(
    #    rname, get_optional(fnopts, 'device_param', param_name(ref_param)))
    retstr = rname
  elif ctype == 'Tensor':
    retstr = get_return_value(rtype, rname, params[0], param_vars[0], ref_param,
                              fnopts)
  elif ctype == 'void' and not type_is_refptr(rtype, '*'):
    return ''
  else:
    retstr = rname
  return ' return {};\n'.format(retstr)


def generate_result_assignment(t, rname):
  assert isinstance(t, lark.tree.Tree)
  rtype = t.children[0]
  ctype = type_core(rtype)
  if ctype == 'void' and not type_is_refptr(rtype, '*'):
    return ''
  return 'auto&& {} = '.format(rname)


def get_handling_function(ctx, fname, the_ref_param, param_vars):
  function = _torch_mlir_FUNCTIONS.get(fname, None) or ctx.get_function(fname)
  if function:
    code = '{}({})'.format(function, ', '.join(param_vars))
  else:
    other_params = list(param_vars)
    other_params.remove(the_ref_param)
    code = '{}.{}({})'.format(the_ref_param, fname, ', '.join(other_params))
  return code


def rewrite_tensor_options(fname, pname):
  rw = _CTOR_FUNCTIONS.get(fname, None)
  if rw is None:
    return '', pname
  xname = 'o_{}'.format(pname)
  code = ' at::TensorOptions {} = {}{};\n'.format(xname, pname, rw)
  return code, xname
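
# Illustrative output of rewrite_tensor_options for a ctor-style op: with
# fname='empty' and pname='options' it emits
#   at::TensorOptions o_options = options.device(at::DeviceType::CPU);
# and returns 'o_options' as the variable to pass along instead.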


def get_param_names(params):
  param_vars = []
  for p in params:
    pname = param_name(p)
    param_vars.append(pname)
  return param_vars


def expand_fn_template(tmpl, param_vars):
  mdict = {}
  for i, pname in enumerate(param_vars):
    mdict[str(i)] = pname
  return tmpl.substitute(mdict)
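
# Illustrative expansion: with the 'arange_out' template defined earlier and
# param_vars=['out', 'start', 'end', 'step'],
#   expand_fn_template(ArgTemplate(
#       'ATenMLIRType::arange($1, $2, $3, $0.options())'), param_vars)
# yields 'ATenMLIRType::arange(start, end, step, out.options())'.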


def create_call(fname, param_vars):
  return '{}({})'.format(fname, ', '.join(param_vars))


def generate_shape_checks(param_vars, shape_check_indices, fname):
  code = ''
  #for i, j in shape_check_indices:
  #  code += (' XLA_CHECK({}.sizes() == {}.sizes()) << "Operand shapes must be '
  #           'identical for {}, mismatch for arguments {} and {}";\n').format(
  #               param_vars[i], param_vars[j], fname, i + 1, j + 1)
  return code


def generate_aten_remap(ctx, fname, sig, params, fnopts):
  code = '{} {{\n'.format(sig)

  param_vars = get_param_names(params)
  if fnopts.outfn_template is not None:
    fcall = expand_fn_template(fnopts.outfn_template, param_vars)
  else:
    assert fnopts.outfn_name
    fcall = create_call(fnopts.outfn_name, param_vars)

  if fnopts.shape_check_indices is not None:
    code += generate_shape_checks(param_vars, fnopts.shape_check_indices, fname)
  code += ' return {};\n'.format(fcall)
  code += '}'
  return code


def generate_outfn_result_copy(dest, src):
  return ' {}.unsafeGetTensorImpl()->shallow_copy_from({}.getIntrusivePtr());\n'.format(
      dest, src)


def generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params, fnopts):
  rtype = tree.children[0]
  num_outputs = None
  if type_core(rtype) == 'std::tuple':
    num_outputs = len(tuple_type_list(rtype))

  code = '{} {{\n'.format(sig)
  code += generate_entry_debug_code(tree, fname, params)

  param_vars = get_param_names(params)
  if fnopts.outfn_template is not None:
    fcall = expand_fn_template(fnopts.outfn_template, param_vars)
  else:
    m = re.match(r'(.*)_out$', fname)
    assert m is not None, fname
    out_count = num_outputs if num_outputs is not None else 1
    fcall = create_call('ATenMLIRType::{}'.format(m.group(1)),
                        param_vars[out_count:])

  tmp_result = '{}_tmp'.format(fname)
  code += ' auto {} = {};\n'.format(tmp_result, fcall)
  if num_outputs is None:
    code += generate_outfn_result_copy(param_vars[0], tmp_result)
    code += generate_exit_debug_code(tree, fname, param_vars[0], params,
                                     param_vars)
    code += ' return {};\n'.format(param_vars[0])
  else:
    for i in range(0, num_outputs):
      code += generate_outfn_result_copy(
          param_vars[i], 'std::get<{}>({})'.format(i, tmp_result))
    code += generate_exit_debug_code(tree, fname, param_vars[0:num_outputs],
                                     params, param_vars)
    code += ' return {}('.format(get_return_type_str(rwxtree, rwsig))
    for i in range(0, num_outputs):
      if i > 0:
        code += ', '
      code += param_vars[i]
    code += ');\n'
  code += '}'
  return code
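
# Roughly, when enabled via _FN_OUT, for an out-variant such as
# 'add_out(Tensor out, Tensor self, Tensor other)' the generated wrapper
# calls ATenMLIRType::add(self, other) into a temporary, shallow-copies the
# temporary into 'out', and returns 'out' (or a tuple of outs for
# multi-result ops).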


def generate_aten_to_mlir(ctx, tree, rwxtree, fname, sig, rwsig, params,
                          fnopts):
  ref_param = get_reference_param(params, fnopts=fnopts)

  code = '{} {{\n'.format(sig)
  code += generate_entry_debug_code(tree, fname, params)
  the_ref_param = param_name(ref_param) if ref_param else None
  tfetcher = TensorFetcher('mlirtens')
  param_vars = []
  for p in params:
    ptype = param_type(p)
    cptype = type_core(ptype)
    pname = param_name(p)
    if cptype == 'TensorList':
      #xname = 'l_{}'.format(pname)
      #code += (' auto {} = bridge::XlaCreateTensorList({});\n').format(
      #    xname, pname)
      xname = pname
      param_vars.append(xname)
    elif cptype == 'TensorOptions':
      gcode, xname = rewrite_tensor_options(fname, pname)
      code += gcode
      param_vars.append(xname)
    elif cptype != 'Tensor':
      param_vars.append(pname)
    elif type_is_const(ptype):
      xname = tfetcher.add(pname, is_write_param(fnopts, pname, False))
      param_vars.append(xname)
    else:
      xname = tfetcher.add(pname, is_write_param(fnopts, pname, True))
      param_vars.append(xname)
    if p == ref_param and not get_optional(fnopts, 'ref_param'):
      the_ref_param = param_vars[-1]
  code += tfetcher.generate_fetches()
  result_assign = generate_result_assignment(tree, _RESULT_NAME)
  code += ' {}{};\n'.format(
      result_assign, get_handling_function(ctx, fname, the_ref_param,
                                           param_vars))
  #code += tfetcher.generate_updates()
  if result_assign:
    code += (' static_cast<void>({}); // Avoid warnings in case not '
             'used\n'.format(_RESULT_NAME))
  code += generate_exit_debug_code(tree, fname,
                                   _RESULT_NAME if result_assign else None,
                                   params, param_vars)
  code += generate_return_stmt(tree, get_return_type_str(rwxtree, rwsig), fname,
                               _RESULT_NAME if result_assign else None, params,
                               param_vars, ref_param, fnopts)
  code += '}'
  return code


def get_mlir_wrapper(fndef, ctx):
  tree = _PARSER.parse(fndef.cpp_sig)
  xtree = _XPARSER.parse(fndef.cpp_sig)
  mapsig = create_map_sig(xtree, fndef.cpp_sig)
  rwsig = rewrite_signature(fndef.cpp_sig, _TYPE_NSMAP)
  rwxtree = _XPARSER.parse(rwsig)
  params = get_parameters(tree)
  fnopts = _FUNCTION_OPTIONS.get(mapsig, None)

  def gen_fnname(x):
    return 'ATenMLIRTypeDefault::{}'.format(x)

  sig, fname, xfname = get_function_signature(rwxtree, rwsig, gen_fnname)
  if not is_blacklisted_fn(fname, mapsig):
    ofnopts = get_outfn_options(fname, mapsig)
    rfnopts = get_remapfn_options(fname, mapsig)
    if ofnopts is not None:
      #print ("gen_aten_out:", fname)
      code = generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params,
                               ofnopts)
    elif rfnopts is not None:
      #print ("gen_aten_remap", fname)
      code = generate_aten_remap(ctx, fname, sig, params, rfnopts)
    else:
      code = generate_aten_to_mlir(ctx, tree, rwxtree, fname, sig, rwsig,
                                   params, fnopts)
  else:
    code = None
  return FuncGen(tree=tree,
                 xtree=xtree,
                 rwxtree=rwxtree,
                 func=fname,
                 xfunc=xfname,
                 code=code,
                 sig=fndef.cpp_sig,
                 rwsig=rwsig,
                 cppsig=sig,
                 mapsig=mapsig,
                 funsig=create_stdfunc_sig(rwxtree, rwsig),
                 aten_sig=fndef.aten_sig)


def is_tensor_api(fndef):
  fndef = fndef.replace('at::', '')
  fndef = fndef.replace('c10::Device', 'Device')
  m = re.search(r'\bTensor\b', fndef)
  return m is not None, fndef


def extract_functions(path):
  functions = []
  errors = []
  for line in open(path, 'r'):
    m = re.match(r'\s*([^\s].*); //\s+(.*)', line)
    if not m:
      continue
    fndef = m.group(1)
    try:
      _XPARSER.parse(fndef)
      functions.append(FuncDef(cpp_sig=fndef, aten_sig=m.group(2)))
    except Exception as e:
      if is_tensor_api(fndef)[0]:
        errors.append((fndef, str(e)))
        print('Error parsing "{}": {}'.format(fndef, e), file=sys.stderr)
  return functions, errors
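
# RegistrationDeclarations.h pairs each C++ declaration with its ATen schema
# in a trailing comment, which the regex above splits apart. Illustrative
# line shape:
#   Tensor add(const Tensor & self, const Tensor & other, Scalar alpha); // aten::add(Tensor self, Tensor other, Scalar alpha) -> Tensor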


def get_mapsig_key(mapsig):
  # PyTorch generates std::tuple<> without space among the tuple types,
  # which would require special understanding in the string rewriter.
  # Since we are using this as a simple key, we can just strip the spaces.
  return mapsig.replace(' ', '')


def parse_local_overrides(path):
  functions = []
  fndef = None
  for line in open(path, 'r'):
    line = line.strip()
    if not fndef:
      m = re.match(r'static\s+(.*);', line)
      if m:
        functions.append(m.group(1))
        continue
      m = re.match(r'static\s+(.*)', line)
      if m:
        fndef = m.group(1)
    else:
      fndef = '{} {}'.format(fndef, line)
      if fndef.endswith(';'):
        functions.append(fndef[:-1])
        fndef = None
  assert fndef is None

  overrides = {}
  for fndef in functions:
    # Discard static XLA type functions which are not ATEN.
    is_tensor, fndef = is_tensor_api(fndef)
    if is_tensor:
      xtree = _XPARSER.parse(fndef)
      mapsig_key = get_mapsig_key(create_map_sig(xtree, fndef))
      overrides[mapsig_key] = fndef
  return overrides


def get_dialect_name(func):
  name = ''
  upper = True
  cs = list(func)
  for c in cs:
    if c == '_':
      upper = True
    elif upper:
      name += str(c).upper()
      upper = False
    else:
      name += c
  if cs[-1] == "_":
    name += "Under"
  return name
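
# get_dialect_name camel-cases an ATen op name for the tablegen op, with a
# trailing underscore (in-place variants) spelled out. Illustrative:
#   get_dialect_name('native_batch_norm')  # -> 'NativeBatchNorm'
#   get_dialect_name('add_')               # -> 'AddUnder'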


def generate_td_functions(fgens, overrides):
  code = ''
  overridden = set()

  code += "#ifdef ATEN_OP_DEFS\n"
  code += "#else\n"
  code += "#define ATEN_OP_DEFS\n\n"

  for fgen in fgens:
    mapsig_key = get_mapsig_key(fgen.mapsig)
    if mapsig_key in overrides:
      overridden.add(mapsig_key)
    if fgen.func in _TD_BLACKLIST:
      continue

    rtype = fgen.tree.children[0]
    num_outputs = 1
    if type_core(rtype) == 'std::tuple':
      num_outputs = len(tuple_type_list(rtype))
    #print(num_outputs, rtype)

    dialect_name = get_dialect_name(fgen.func)
    #print ('"{}"'.format(dialect_name))
    code += 'def aten_{}Op: aten_Op<"{}"'.format(dialect_name, fgen.func)
    code += ', [NoSideEffect'
    if fgen.func not in _TD_NO_OPSTATS_LIST:
      code += ', StatisticsOpInterface'
    code += ']>,\n'
    code += ' Results<(outs'
    # foreach output
    # rparams = get_rparameters(fgen.tree)
    # for p in rparams:
    #   pname = param_name(p)
    #   ptype = param_type(p)
    #   cptype = type_core(ptype)
    #   print(pname)
    code += ' AnyTensor'
    for i in range(num_outputs - 1):
      code += ', AnyTensor'
    code += ')> {\n'
    code += ' let arguments = (\n'
    params = get_parameters(fgen.tree)
    for p in params:
      pname = param_name(p)
      ptype = param_type(p)
      cptype = type_core(ptype)
      if (cptype == 'Tensor'):
        td_type = "AnyTensor"
      elif (cptype == 'Scalar' or cptype == 'int64_t' or cptype == 'double' or
            cptype == 'bool'):
        td_type = "AnyScalar"
      elif (cptype == 'c10::optional' or cptype == 'std::array'):
        continue
      elif (cptype == 'IntArrayRef'):
        td_type = "AnyType"
      else:
        print('unhandled type', cptype)
        td_type = "AnyType"
      if p == params[0]:
        code += ' ins {}:${}'.format(td_type, pname)
      else:
        code += ',\n {}:${}'.format(td_type, pname)
    code += '\n );\n'
    code += ' let summary = "aten {} operator";\n'.format(fgen.func)
    code += ' let description = [{\n'
    code += ' {}Op\n'.format(dialect_name)
    code += ' aten {} operator\n'.format(fgen.func)
    code += ' }];\n'
    if fgen.func not in _TD_NO_OPSTATS_LIST:
      code += ' let extraClassDeclaration = [{\n'
      code += ' std::map<std::string, uint64_t> getStatistics();\n'
      code += ' }];\n'
    code += '}\n\n'

  code += "#endif\n"
  return code, overridden


def generate_registrations(fgens, overrides):
  code = 'void RegisterAtenTypeFunctions() {\n'
  code += ' static auto dispatch = torch::RegisterOperators()\n'
  overridden = set()
  for fgen in fgens:
    mapsig_key = get_mapsig_key(fgen.mapsig)
    if mapsig_key in overrides:
      override_fn = 'ATenMLIRType::{}'.format(fgen.func)
      overridden.add(mapsig_key)
    else:
      override_fn = fgen.xfunc if fgen.code else None
    if override_fn:
      code += (
          ' .op(torch::RegisterOperators::options().schema("{}")\n '
          '.impl_unboxedOnlyKernel<{}, &{}>(at::TensorTypeId::XLATensorId)\n'
          ' .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))\n'.format(
              fgen.aten_sig, fgen.funsig, override_fn))
  return code + ';\n}\n', overridden


def generate_functions(fgens):
  code = ''
  for fgen in fgens:
    if fgen.code:
      code += '{}\n\n'.format(fgen.code)
  return code


def generate_class_functions(fgens):
  code = ''
  for fgen in fgens:
    if fgen.code:
      code += ' static {};\n'.format(fgen.rwsig)
  return code


def gen_output_file(args, name):
  if not args.output_folder:
    return sys.stdout
  return open(os.path.join(args.output_folder, name), 'w')


def gen_h_output_file(args):
  return gen_output_file(args, 'aten_mlir_type_default.h')


def gen_cpp_output_file(args):
  return gen_output_file(args, 'aten_mlir_type_default.cpp')


def gen_td_output_file(args):
  return gen_output_file(args, 'ATenOps.td')


def check_overrides(overrides, overridden):
  misses = 0
  for mapsig, cpp_sig in overrides.items():
    mapsig_key = get_mapsig_key(mapsig)
    if mapsig_key not in overridden:
      misses += 1
      print('ATenMLIRType function missed override: {}; // {}'.format(
          cpp_sig, mapsig),
            file=sys.stderr)
  return misses == 0


def generate(args):
  fndefs, errors = extract_functions(args.typedef)
  print('Extracted {} functions ({} errors) from {}'.format(
      len(fndefs), len(errors), args.typedef),
        file=sys.stderr)
  assert len(errors) == 0

  overrides = parse_local_overrides(args.overridetype)
  print('{} function overrides in {}'.format(len(overrides), args.overridetype),
        file=sys.stderr)

  fgens = []
  ctx = Context(args.functions)
  for ts in fndefs:
    try:
      fgen = get_mlir_wrapper(ts, ctx)
      if fgen:
        fgens.append(fgen)
    except Exception as e:
      print('Failed to generate wrapper for {}: {}'.format(ts, e),
            file=sys.stderr)
  print('Generated {} wrappers for {}'.format(len(fgens), args.typedef),
        file=sys.stderr)

  functions = generate_functions(fgens)
  hfunctions = generate_class_functions(fgens)

  tdfunctions, overridden = generate_td_functions(fgens, overrides)
  assert check_overrides(overrides, overridden)
  #print(tdfunctions)

  regs, overridden = generate_registrations(fgens, overrides)
  #print (len(overrides), len(overridden))
  assert check_overrides(overrides, overridden)
  # Create output files ...
  print(_H_HEADER.format(gen=os.path.basename(sys.argv[0]), hfuncs=hfunctions),
        file=gen_h_output_file(args))
  print(_CPP_HEADER.format(gen=os.path.basename(sys.argv[0]),
                           funcs=functions,
                           regs=regs),
        file=gen_cpp_output_file(args))

  with gen_td_output_file(args) as f:
    f.write(tdfunctions)


if __name__ == '__main__':
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument('--output_folder', type=str)
  arg_parser.add_argument('overridetype',
                          type=str,
                          metavar='OVERRIDE_TYPE_FILE',
                          help='The path to the overrides file (aten_mlir_type.h)')
  arg_parser.add_argument('typedef',
                          type=str,
                          metavar='TYPE_DEFAULT_FILE',
                          help='The path to the RegistrationDeclarations.h file')
  arg_parser.add_argument('functions',
                          type=str,
                          metavar='FUNCTIONS_FILE',
                          help='The path to the Functions.h file')
  args, files = arg_parser.parse_known_args()
  generate(args)