Fix evaluation message reporting and add checks to tests.

pull/1/head
Stella Laurenzo 2020-06-29 17:48:17 -07:00
parent d5e3902b6d
commit 2d4b0843c1
5 changed files with 85 additions and 18 deletions

View File

@ -2,6 +2,7 @@
import numpy as np import numpy as np
from npcomp.compiler import test_config from npcomp.compiler import test_config
from npcomp.compiler.frontend import EmittedError
import_global = test_config.create_import_dump_decorator() import_global = test_config.create_import_dump_decorator()
@ -26,3 +27,14 @@ def global_add():
# CHECK: %[[R_TENSOR:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[A]], %[[B]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*x!basicpy.UnknownType> # CHECK: %[[R_TENSOR:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[A]], %[[B]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*x!basicpy.UnknownType>
# CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray<?> # CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray<?>
return np.add(a, b) return np.add(a, b)
@import_global(expect_error="ufunc call does not currently support keyword args"
)
def keywords_not_supported():
return np.add(a, b, out=b)
@import_global(expect_error="ufunc numpy.add expected 2 inputs but got 1")
def mismatched_arg_count():
return np.add(a)

View File

@ -7,6 +7,7 @@ Frontend to the compiler, allowing various ways to import code.
import ast import ast
import inspect import inspect
import textwrap
from typing import Optional from typing import Optional
from _npcomp.mlir import ir from _npcomp.mlir import ir
@ -76,6 +77,7 @@ class ImportFrontend:
filename = inspect.getsourcefile(f) filename = inspect.getsourcefile(f)
source_lines, start_lineno = inspect.getsourcelines(f) source_lines, start_lineno = inspect.getsourcelines(f)
source = "".join(source_lines) source = "".join(source_lines)
source = textwrap.dedent(source)
ast_root = ast.parse(source, filename=filename) ast_root = ast.parse(source, filename=filename)
ast.increment_lineno(ast_root, start_lineno - 1) ast.increment_lineno(ast_root, start_lineno - 1)
ast_fd = ast_root.body[0] ast_fd = ast_root.body[0]

View File

@ -46,10 +46,13 @@ class FunctionContext:
def check_partial_evaluated(self, result: PartialEvalResult): def check_partial_evaluated(self, result: PartialEvalResult):
"""Checks that a PartialEvalResult has evaluated without error.""" """Checks that a PartialEvalResult has evaluated without error."""
if result.type == PartialEvalType.ERROR: if result.type == PartialEvalType.ERROR:
exc_info = result.yields exc_type, exc_value, tb = result.yields
loc = self.current_loc loc = self.current_loc
message = ("Error while evaluating value from environment:\n" + if issubclass(exc_type, UserReportableError):
"".join(traceback.format_exception(*exc_info))) message = exc_value.message
else:
message = ("Error while evaluating value from environment:\n" +
"".join(traceback.format_exception(exc_type, exc_value, tb)))
ir.emit_error(loc, message) ir.emit_error(loc, message)
raise EmittedError(loc, message) raise EmittedError(loc, message)
if result.type == PartialEvalType.NOT_EVALUATED: if result.type == PartialEvalType.NOT_EVALUATED:
@ -480,15 +483,3 @@ class PartialEvalImporter(BaseNodeVisitor):
partial_eval_result = name_ref.load(self.fctx.environment) partial_eval_result = name_ref.load(self.fctx.environment)
logging.debug("PARTIAL EVAL {} -> {}", name_ref, partial_eval_result) logging.debug("PARTIAL EVAL {} -> {}", name_ref, partial_eval_result)
self.partial_eval_result = partial_eval_result self.partial_eval_result = partial_eval_result
class EmittedError(Exception):
"""Exception subclass that indicates an error diagnostic has been emitted.
By throwing, this lets us abort and handle at a higher level so as not
to duplicate diagnostics.
"""
def __init__(self, loc, message):
super().__init__("%s (at %r)" % (message, loc))
self.loc = loc

View File

@ -13,6 +13,7 @@ from .target import *
__all__ = [ __all__ = [
"Configuration", "Configuration",
"EmittedError",
"Environment", "Environment",
"NameReference", "NameReference",
"NameResolver", "NameResolver",
@ -20,12 +21,52 @@ __all__ = [
"PartialEvalType", "PartialEvalType",
"PartialEvalResult", "PartialEvalResult",
"LiveValueRef", "LiveValueRef",
"UserReportableError",
"ValueCoder", "ValueCoder",
"ValueCoderChain", "ValueCoderChain",
] ]
_NotImplementedType = type(NotImplemented) _NotImplementedType = type(NotImplemented)
################################################################################
# Exceptions
################################################################################
class EmittedError(Exception):
"""Exception subclass that indicates an error diagnostic has been emitted.
By throwing, this lets us abort and handle at a higher level so as not
to duplicate diagnostics.
"""
def __init__(self, loc, message):
super().__init__(loc, message)
@property
def loc(self):
return self.args[0]
@property
def message(self):
return self.args[1]
class UserReportableError(Exception):
"""Used to raise an error with a message that should be reported to the user.
Raising this error indicates that the error message is well formed and
makes sense without a traceback.
"""
def __init__(self, message):
super().__init__(message)
@property
def message(self):
return self.args[0]
################################################################################ ################################################################################
# Name resolution # Name resolution
################################################################################ ################################################################################
@ -177,8 +218,8 @@ class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")):
@staticmethod @staticmethod
def error_message(message: str) -> "PartialEvalResult": def error_message(message: str) -> "PartialEvalResult":
try: try:
raise RuntimeError(message) raise UserReportableError(message)
except RuntimeError: except UserReportableError:
return PartialEvalResult.error() return PartialEvalResult.error()

View File

@ -4,6 +4,7 @@
"""Various configuration helpers for testing.""" """Various configuration helpers for testing."""
import ast import ast
import functools
from . import logging from . import logging
from .frontend import * from .frontend import *
@ -20,13 +21,33 @@ def create_import_dump_decorator(*,
config = create_test_config(target_factory=target_factory) config = create_test_config(target_factory=target_factory)
logging.debug("Testing with config: {}", config) logging.debug("Testing with config: {}", config)
def decorator(f): def do_import(f):
fe = ImportFrontend(config=config) fe = ImportFrontend(config=config)
fe.import_global_function(f) fe.import_global_function(f)
print("// -----") print("// -----")
print(fe.ir_module.to_asm()) print(fe.ir_module.to_asm())
return f return f
def decorator(*args, expect_error=None):
if len(args) == 0:
# Higher order decorator.
return functools.partial(decorator, expect_error=expect_error)
assert len(args) == 1
try:
return do_import(f=args[0])
except EmittedError as e:
if expect_error and e.message == expect_error:
print("// EXPECTED_ERROR:", repr(e.message))
pass
elif expect_error:
print("// MISMATCHED_ERROR:", repr(e.message))
raise AssertionError("Expected error '{}' but got '{}'".format(
expect_error, e.message))
else:
print("// UNEXPECTED_ERROR:", repr(e.message))
raise e
return decorator return decorator