Fix evaluation message reporting and add checks to tests.

pull/1/head
Stella Laurenzo 2020-06-29 17:48:17 -07:00
parent d5e3902b6d
commit 2d4b0843c1
5 changed files with 85 additions and 18 deletions

View File

@ -2,6 +2,7 @@
import numpy as np
from npcomp.compiler import test_config
from npcomp.compiler.frontend import EmittedError
import_global = test_config.create_import_dump_decorator()
@ -26,3 +27,14 @@ def global_add():
# CHECK: %[[R_TENSOR:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[A]], %[[B]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*x!basicpy.UnknownType>
# CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray<?>
return np.add(a, b)
@import_global(expect_error="ufunc call does not currently support keyword args"
)
def keywords_not_supported():
return np.add(a, b, out=b)
@import_global(expect_error="ufunc numpy.add expected 2 inputs but got 1")
def mismatched_arg_count():
return np.add(a)

View File

@ -7,6 +7,7 @@ Frontend to the compiler, allowing various ways to import code.
import ast
import inspect
import textwrap
from typing import Optional
from _npcomp.mlir import ir
@ -76,6 +77,7 @@ class ImportFrontend:
filename = inspect.getsourcefile(f)
source_lines, start_lineno = inspect.getsourcelines(f)
source = "".join(source_lines)
source = textwrap.dedent(source)
ast_root = ast.parse(source, filename=filename)
ast.increment_lineno(ast_root, start_lineno - 1)
ast_fd = ast_root.body[0]

View File

@ -46,10 +46,13 @@ class FunctionContext:
def check_partial_evaluated(self, result: PartialEvalResult):
"""Checks that a PartialEvalResult has evaluated without error."""
if result.type == PartialEvalType.ERROR:
exc_info = result.yields
exc_type, exc_value, tb = result.yields
loc = self.current_loc
message = ("Error while evaluating value from environment:\n" +
"".join(traceback.format_exception(*exc_info)))
if issubclass(exc_type, UserReportableError):
message = exc_value.message
else:
message = ("Error while evaluating value from environment:\n" +
"".join(traceback.format_exception(exc_type, exc_value, tb)))
ir.emit_error(loc, message)
raise EmittedError(loc, message)
if result.type == PartialEvalType.NOT_EVALUATED:
@ -480,15 +483,3 @@ class PartialEvalImporter(BaseNodeVisitor):
partial_eval_result = name_ref.load(self.fctx.environment)
logging.debug("PARTIAL EVAL {} -> {}", name_ref, partial_eval_result)
self.partial_eval_result = partial_eval_result
class EmittedError(Exception):
"""Exception subclass that indicates an error diagnostic has been emitted.
By throwing, this lets us abort and handle at a higher level so as not
to duplicate diagnostics.
"""
def __init__(self, loc, message):
super().__init__("%s (at %r)" % (message, loc))
self.loc = loc

View File

@ -13,6 +13,7 @@ from .target import *
__all__ = [
"Configuration",
"EmittedError",
"Environment",
"NameReference",
"NameResolver",
@ -20,12 +21,52 @@ __all__ = [
"PartialEvalType",
"PartialEvalResult",
"LiveValueRef",
"UserReportableError",
"ValueCoder",
"ValueCoderChain",
]
_NotImplementedType = type(NotImplemented)
################################################################################
# Exceptions
################################################################################
class EmittedError(Exception):
"""Exception subclass that indicates an error diagnostic has been emitted.
By throwing, this lets us abort and handle at a higher level so as not
to duplicate diagnostics.
"""
def __init__(self, loc, message):
super().__init__(loc, message)
@property
def loc(self):
return self.args[0]
@property
def message(self):
return self.args[1]
class UserReportableError(Exception):
"""Used to raise an error with a message that should be reported to the user.
Raising this error indicates that the error message is well formed and
makes sense without a traceback.
"""
def __init__(self, message):
super().__init__(message)
@property
def message(self):
return self.args[0]
################################################################################
# Name resolution
################################################################################
@ -177,8 +218,8 @@ class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")):
@staticmethod
def error_message(message: str) -> "PartialEvalResult":
try:
raise RuntimeError(message)
except RuntimeError:
raise UserReportableError(message)
except UserReportableError:
return PartialEvalResult.error()

View File

@ -4,6 +4,7 @@
"""Various configuration helpers for testing."""
import ast
import functools
from . import logging
from .frontend import *
@ -20,13 +21,33 @@ def create_import_dump_decorator(*,
config = create_test_config(target_factory=target_factory)
logging.debug("Testing with config: {}", config)
def decorator(f):
def do_import(f):
fe = ImportFrontend(config=config)
fe.import_global_function(f)
print("// -----")
print(fe.ir_module.to_asm())
return f
def decorator(*args, expect_error=None):
if len(args) == 0:
# Higher order decorator.
return functools.partial(decorator, expect_error=expect_error)
assert len(args) == 1
try:
return do_import(f=args[0])
except EmittedError as e:
if expect_error and e.message == expect_error:
print("// EXPECTED_ERROR:", repr(e.message))
pass
elif expect_error:
print("// MISMATCHED_ERROR:", repr(e.message))
raise AssertionError("Expected error '{}' but got '{}'".format(
expect_error, e.message))
else:
print("// UNEXPECTED_ERROR:", repr(e.message))
raise e
return decorator