# Web source-view residue (not part of the test itself), kept for reference:
#   file: torch-mlir/pytest/NumpyCompiler/ndarray_inference.py
#   size: 22 lines, 705 B (Python)
#   history timestamp: 2020-07-05 08:43:04 +08:00
# RUN: %PYTHON %s | npcomp-opt -split-input-file -npcomp-cpa-type-inference | FileCheck %s --dump-input=fail
# history timestamp: 2020-07-04 07:38:10 +08:00
import numpy as np
from npcomp.compiler import test_config
from npcomp.compiler.frontend import EmittedError
# Decorator applied to the test function(s) in this file.
# NOTE(review): presumably it imports the decorated function through the
# npcomp frontend and prints the resulting IR to stdout for the lit pipeline;
# confirm in npcomp.compiler.test_config.
import_global = test_config.create_import_dump_decorator()
# Module-level (2, 3) float array: the row [1.0, 2.0, 3.0] is broadcast
# against the column [[1.0], [2.0]], giving [[1, 2, 3], [2, 4, 6]].
global_data = np.zeros((2, 3)) + [1.0, 2.0, 3.0] * np.reshape(
    [1.0, 2.0], (2, 1)
)

# One-dimensional float arrays; global_add reads these as module globals.
a = np.asarray([1.0, 2.0])
b = np.asarray([3.0, 4.0])
# Test the basic flow of invoking a ufunc call with constants captured from
# a global using explicit function syntax (np.add(a, b)).
# CHECK-LABEL: func @global_add
@import_global
def global_add():
    # Deliberately no docstring: this function object is handed to the
    # import_global decorator (presumably compiled by the npcomp frontend;
    # confirm), so the body is kept to the single expression under test.
    # `a` and `b` are the module-level arrays defined above, read as globals.
    return np.add(a, b)