mirror of https://github.com/llvm/torch-mlir
Add limited support for function arguments.
parent
e3fd22a035
commit
917fd94f94
|
@ -10,6 +10,13 @@ def import_global(f):
|
|||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
|
||||
# CHECK-LABEL: func @positional_args
|
||||
# CHECK-SAME: (%arg0: !basicpy.UnknownType, %arg1: !basicpy.UnknownType) -> !basicpy.UnknownType
|
||||
@import_global
|
||||
def positional_args(a, b):
|
||||
# CHECK: basicpy.binary_expr %arg0 "Add" %arg1
|
||||
return a + b
|
||||
|
||||
# CHECK-LABEL: func @pass_no_return
|
||||
@import_global
|
||||
def pass_no_return():
|
||||
|
|
|
@ -79,17 +79,25 @@ class ImportFrontend:
|
|||
logging.debug(":::::::")
|
||||
logging.debug("::: Importing global function {}:\n{}", ast_fd.name,
|
||||
ast.dump(ast_fd, include_attributes=True))
|
||||
|
||||
# TODO: VERY BAD: Assumes all positional params.
|
||||
unknown_type = h.basicpy_UnknownType
|
||||
f_params = inspect.signature(f).parameters
|
||||
arg_count = len(f_params)
|
||||
ir_f_type = h.function_type([unknown_type for _ in range(arg_count)],
|
||||
[unknown_type])
|
||||
|
||||
h.builder.set_file_line_col(filename_ident, ast_fd.lineno,
|
||||
ast_fd.col_offset)
|
||||
h.builder.insert_before_terminator(ir_m.first_block)
|
||||
ir_f_type = h.function_type([], [h.basicpy_UnknownType])
|
||||
ir_f = h.func_op(ast_fd.name, ir_f_type, create_entry_block=True)
|
||||
|
||||
fctx = FunctionContext(ir_c=ir_c,
|
||||
ir_f=ir_f,
|
||||
ir_h=h,
|
||||
filename_ident=filename_ident)
|
||||
for f_arg, ir_arg in zip(f_params, ir_f.first_block.args):
|
||||
fctx.map_local_name(f_arg, ir_arg)
|
||||
|
||||
fdimport = FunctionDefImporter(fctx, ast_fd)
|
||||
fdimport.import_body()
|
||||
return ir_f
|
||||
|
||||
|
|
Loading…
Reference in New Issue