diff --git a/pytest/NumpyCompiler/ufunc.py b/pytest/NumpyCompiler/ufunc.py index f91773a6c..24a302a8c 100644 --- a/pytest/NumpyCompiler/ufunc.py +++ b/pytest/NumpyCompiler/ufunc.py @@ -2,6 +2,7 @@ import numpy as np from npcomp.compiler import test_config +from npcomp.compiler.frontend import EmittedError import_global = test_config.create_import_dump_decorator() @@ -26,3 +27,14 @@ def global_add(): # CHECK: %[[R_TENSOR:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[A]], %[[B]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*x!basicpy.UnknownType> # CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray return np.add(a, b) + + +@import_global(expect_error="ufunc call does not currently support keyword args" + ) +def keywords_not_supported(): + return np.add(a, b, out=b) + + +@import_global(expect_error="ufunc numpy.add expected 2 inputs but got 1") +def mismatched_arg_count(): + return np.add(a) diff --git a/python/npcomp/compiler/frontend.py b/python/npcomp/compiler/frontend.py index 7f830bcaf..93acfd960 100644 --- a/python/npcomp/compiler/frontend.py +++ b/python/npcomp/compiler/frontend.py @@ -7,6 +7,7 @@ Frontend to the compiler, allowing various ways to import code. import ast import inspect +import textwrap from typing import Optional from _npcomp.mlir import ir @@ -76,6 +77,7 @@ class ImportFrontend: filename = inspect.getsourcefile(f) source_lines, start_lineno = inspect.getsourcelines(f) source = "".join(source_lines) + source = textwrap.dedent(source) ast_root = ast.parse(source, filename=filename) ast.increment_lineno(ast_root, start_lineno - 1) ast_fd = ast_root.body[0] diff --git a/python/npcomp/compiler/importer.py b/python/npcomp/compiler/importer.py index c5065f0e5..70552c0b5 100644 --- a/python/npcomp/compiler/importer.py +++ b/python/npcomp/compiler/importer.py @@ -46,10 +46,13 @@ class FunctionContext: def check_partial_evaluated(self, result: PartialEvalResult): """Checks that a PartialEvalResult has evaluated without error.""" if result.type == PartialEvalType.ERROR: - exc_info = result.yields + exc_type, exc_value, tb = result.yields loc = self.current_loc - message = ("Error while evaluating value from environment:\n" + - "".join(traceback.format_exception(*exc_info))) + if issubclass(exc_type, UserReportableError): + message = exc_value.message + else: + message = ("Error while evaluating value from environment:\n" + + "".join(traceback.format_exception(exc_type, exc_value, tb))) ir.emit_error(loc, message) raise EmittedError(loc, message) if result.type == PartialEvalType.NOT_EVALUATED: @@ -480,15 +483,3 @@ class PartialEvalImporter(BaseNodeVisitor): partial_eval_result = name_ref.load(self.fctx.environment) logging.debug("PARTIAL EVAL {} -> {}", name_ref, partial_eval_result) self.partial_eval_result = partial_eval_result - - -class EmittedError(Exception): - """Exception subclass that indicates an error diagnostic has been emitted. - - By throwing, this lets us abort and handle at a higher level so as not - to duplicate diagnostics. - """ - - def __init__(self, loc, message): - super().__init__("%s (at %r)" % (message, loc)) - self.loc = loc diff --git a/python/npcomp/compiler/interfaces.py b/python/npcomp/compiler/interfaces.py index b9cba3586..034d40532 100644 --- a/python/npcomp/compiler/interfaces.py +++ b/python/npcomp/compiler/interfaces.py @@ -13,6 +13,7 @@ from .target import * __all__ = [ "Configuration", + "EmittedError", "Environment", "NameReference", "NameResolver", @@ -20,12 +21,52 @@ __all__ = [ "PartialEvalType", "PartialEvalResult", "LiveValueRef", + "UserReportableError", "ValueCoder", "ValueCoderChain", ] _NotImplementedType = type(NotImplemented) +################################################################################ +# Exceptions +################################################################################ + + +class EmittedError(Exception): + """Exception subclass that indicates an error diagnostic has been emitted. + + By throwing, this lets us abort and handle at a higher level so as not + to duplicate diagnostics. + """ + + def __init__(self, loc, message): + super().__init__(loc, message) + + @property + def loc(self): + return self.args[0] + + @property + def message(self): + return self.args[1] + + +class UserReportableError(Exception): + """Used to raise an error with a message that should be reported to the user. + + Raising this error indicates that the error message is well formed and + makes sense without a traceback. + """ + + def __init__(self, message): + super().__init__(message) + + @property + def message(self): + return self.args[0] + + ################################################################################ # Name resolution ################################################################################ @@ -177,8 +218,8 @@ class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")): @staticmethod def error_message(message: str) -> "PartialEvalResult": try: - raise RuntimeError(message) - except RuntimeError: + raise UserReportableError(message) + except UserReportableError: return PartialEvalResult.error() diff --git a/python/npcomp/compiler/test_config.py b/python/npcomp/compiler/test_config.py index 536f6a3c7..147f172b5 100644 --- a/python/npcomp/compiler/test_config.py +++ b/python/npcomp/compiler/test_config.py @@ -4,6 +4,7 @@ """Various configuration helpers for testing.""" import ast +import functools from . import logging from .frontend import * @@ -20,13 +21,33 @@ def create_import_dump_decorator(*, config = create_test_config(target_factory=target_factory) logging.debug("Testing with config: {}", config) - def decorator(f): + def do_import(f): fe = ImportFrontend(config=config) fe.import_global_function(f) print("// -----") print(fe.ir_module.to_asm()) return f + def decorator(*args, expect_error=None): + if len(args) == 0: + # Higher order decorator. + return functools.partial(decorator, expect_error=expect_error) + + assert len(args) == 1 + try: + return do_import(f=args[0]) + except EmittedError as e: + if expect_error and e.message == expect_error: + print("// EXPECTED_ERROR:", repr(e.message)) + pass + elif expect_error: + print("// MISMATCHED_ERROR:", repr(e.message)) + raise AssertionError("Expected error '{}' but got '{}'".format( + expect_error, e.message)) + else: + print("// UNEXPECTED_ERROR:", repr(e.message)) + raise e + return decorator