mirror of https://github.com/llvm/torch-mlir
Fix evaluation message reporting and add checks to tests.
parent
d5e3902b6d
commit
2d4b0843c1
|
@ -2,6 +2,7 @@
|
|||
|
||||
import numpy as np
|
||||
from npcomp.compiler import test_config
|
||||
from npcomp.compiler.frontend import EmittedError
|
||||
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
|
@ -26,3 +27,14 @@ def global_add():
|
|||
# CHECK: %[[R_TENSOR:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[A]], %[[B]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*x!basicpy.UnknownType>
|
||||
# CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray<?>
|
||||
return np.add(a, b)
|
||||
|
||||
|
||||
@import_global(expect_error="ufunc call does not currently support keyword args"
|
||||
)
|
||||
def keywords_not_supported():
|
||||
return np.add(a, b, out=b)
|
||||
|
||||
|
||||
@import_global(expect_error="ufunc numpy.add expected 2 inputs but got 1")
|
||||
def mismatched_arg_count():
|
||||
return np.add(a)
|
||||
|
|
|
@ -7,6 +7,7 @@ Frontend to the compiler, allowing various ways to import code.
|
|||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Optional
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
|
@ -76,6 +77,7 @@ class ImportFrontend:
|
|||
filename = inspect.getsourcefile(f)
|
||||
source_lines, start_lineno = inspect.getsourcelines(f)
|
||||
source = "".join(source_lines)
|
||||
source = textwrap.dedent(source)
|
||||
ast_root = ast.parse(source, filename=filename)
|
||||
ast.increment_lineno(ast_root, start_lineno - 1)
|
||||
ast_fd = ast_root.body[0]
|
||||
|
|
|
@ -46,10 +46,13 @@ class FunctionContext:
|
|||
def check_partial_evaluated(self, result: PartialEvalResult):
|
||||
"""Checks that a PartialEvalResult has evaluated without error."""
|
||||
if result.type == PartialEvalType.ERROR:
|
||||
exc_info = result.yields
|
||||
exc_type, exc_value, tb = result.yields
|
||||
loc = self.current_loc
|
||||
if issubclass(exc_type, UserReportableError):
|
||||
message = exc_value.message
|
||||
else:
|
||||
message = ("Error while evaluating value from environment:\n" +
|
||||
"".join(traceback.format_exception(*exc_info)))
|
||||
"".join(traceback.format_exception(exc_type, exc_value, tb)))
|
||||
ir.emit_error(loc, message)
|
||||
raise EmittedError(loc, message)
|
||||
if result.type == PartialEvalType.NOT_EVALUATED:
|
||||
|
@ -480,15 +483,3 @@ class PartialEvalImporter(BaseNodeVisitor):
|
|||
partial_eval_result = name_ref.load(self.fctx.environment)
|
||||
logging.debug("PARTIAL EVAL {} -> {}", name_ref, partial_eval_result)
|
||||
self.partial_eval_result = partial_eval_result
|
||||
|
||||
|
||||
class EmittedError(Exception):
|
||||
"""Exception subclass that indicates an error diagnostic has been emitted.
|
||||
|
||||
By throwing, this lets us abort and handle at a higher level so as not
|
||||
to duplicate diagnostics.
|
||||
"""
|
||||
|
||||
def __init__(self, loc, message):
|
||||
super().__init__("%s (at %r)" % (message, loc))
|
||||
self.loc = loc
|
||||
|
|
|
@ -13,6 +13,7 @@ from .target import *
|
|||
|
||||
__all__ = [
|
||||
"Configuration",
|
||||
"EmittedError",
|
||||
"Environment",
|
||||
"NameReference",
|
||||
"NameResolver",
|
||||
|
@ -20,12 +21,52 @@ __all__ = [
|
|||
"PartialEvalType",
|
||||
"PartialEvalResult",
|
||||
"LiveValueRef",
|
||||
"UserReportableError",
|
||||
"ValueCoder",
|
||||
"ValueCoderChain",
|
||||
]
|
||||
|
||||
_NotImplementedType = type(NotImplemented)
|
||||
|
||||
################################################################################
|
||||
# Exceptions
|
||||
################################################################################
|
||||
|
||||
|
||||
class EmittedError(Exception):
|
||||
"""Exception subclass that indicates an error diagnostic has been emitted.
|
||||
|
||||
By throwing, this lets us abort and handle at a higher level so as not
|
||||
to duplicate diagnostics.
|
||||
"""
|
||||
|
||||
def __init__(self, loc, message):
|
||||
super().__init__(loc, message)
|
||||
|
||||
@property
|
||||
def loc(self):
|
||||
return self.args[0]
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self.args[1]
|
||||
|
||||
|
||||
class UserReportableError(Exception):
|
||||
"""Used to raise an error with a message that should be reported to the user.
|
||||
|
||||
Raising this error indicates that the error message is well formed and
|
||||
makes sense without a traceback.
|
||||
"""
|
||||
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self.args[0]
|
||||
|
||||
|
||||
################################################################################
|
||||
# Name resolution
|
||||
################################################################################
|
||||
|
@ -177,8 +218,8 @@ class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")):
|
|||
@staticmethod
|
||||
def error_message(message: str) -> "PartialEvalResult":
|
||||
try:
|
||||
raise RuntimeError(message)
|
||||
except RuntimeError:
|
||||
raise UserReportableError(message)
|
||||
except UserReportableError:
|
||||
return PartialEvalResult.error()
|
||||
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
"""Various configuration helpers for testing."""
|
||||
|
||||
import ast
|
||||
import functools
|
||||
|
||||
from . import logging
|
||||
from .frontend import *
|
||||
|
@ -20,13 +21,33 @@ def create_import_dump_decorator(*,
|
|||
config = create_test_config(target_factory=target_factory)
|
||||
logging.debug("Testing with config: {}", config)
|
||||
|
||||
def decorator(f):
|
||||
def do_import(f):
|
||||
fe = ImportFrontend(config=config)
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
|
||||
def decorator(*args, expect_error=None):
|
||||
if len(args) == 0:
|
||||
# Higher order decorator.
|
||||
return functools.partial(decorator, expect_error=expect_error)
|
||||
|
||||
assert len(args) == 1
|
||||
try:
|
||||
return do_import(f=args[0])
|
||||
except EmittedError as e:
|
||||
if expect_error and e.message == expect_error:
|
||||
print("// EXPECTED_ERROR:", repr(e.message))
|
||||
pass
|
||||
elif expect_error:
|
||||
print("// MISMATCHED_ERROR:", repr(e.message))
|
||||
raise AssertionError("Expected error '{}' but got '{}'".format(
|
||||
expect_error, e.message))
|
||||
else:
|
||||
print("// UNEXPECTED_ERROR:", repr(e.message))
|
||||
raise e
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue