mirror of https://github.com/llvm/torch-mlir
Yapf Format `refbacked.py`.
parent
564403e3a1
commit
8ba77ae2a5
|
@ -21,7 +21,10 @@ __all__ = [
|
|||
|
||||
|
||||
def assert_arg_type_is_supported(ty):
|
||||
SUPPORTED = [np.float16, np.float32, np.float64, np.uint8, np.int8, np.int32, np.int64, np.bool_]
|
||||
SUPPORTED = [
|
||||
np.float16, np.float32, np.float64, np.uint8, np.int8, np.int32,
|
||||
np.int64, np.bool_
|
||||
]
|
||||
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}"
|
||||
|
||||
|
||||
|
@ -86,8 +89,9 @@ class RefBackendInvoker:
|
|||
|
||||
def consume_return_funcs(*args):
|
||||
self.result = tuple([
|
||||
arg if type in elemental_type_to_ctype else
|
||||
unranked_memref_to_numpy(arg, memref_type_to_np_dtype[type])
|
||||
arg if type in elemental_type_to_ctype
|
||||
else unranked_memref_to_numpy(
|
||||
arg, memref_type_to_np_dtype[type])
|
||||
for arg, type in zip(args, ret_types)
|
||||
])
|
||||
if len(self.result) == 1:
|
||||
|
|
Loading…
Reference in New Issue