mirror of https://github.com/llvm/torch-mlir
Fix evaluation message reporting and add checks to tests.
parent
d5e3902b6d
commit
2d4b0843c1
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from npcomp.compiler import test_config
|
from npcomp.compiler import test_config
|
||||||
|
from npcomp.compiler.frontend import EmittedError
|
||||||
|
|
||||||
import_global = test_config.create_import_dump_decorator()
|
import_global = test_config.create_import_dump_decorator()
|
||||||
|
|
||||||
|
@ -26,3 +27,14 @@ def global_add():
|
||||||
# CHECK: %[[R_TENSOR:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[A]], %[[B]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*x!basicpy.UnknownType>
|
# CHECK: %[[R_TENSOR:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[A]], %[[B]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*x!basicpy.UnknownType>
|
||||||
# CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray<?>
|
# CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray<?>
|
||||||
return np.add(a, b)
|
return np.add(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@import_global(expect_error="ufunc call does not currently support keyword args"
|
||||||
|
)
|
||||||
|
def keywords_not_supported():
|
||||||
|
return np.add(a, b, out=b)
|
||||||
|
|
||||||
|
|
||||||
|
@import_global(expect_error="ufunc numpy.add expected 2 inputs but got 1")
|
||||||
|
def mismatched_arg_count():
|
||||||
|
return np.add(a)
|
||||||
|
|
|
@ -7,6 +7,7 @@ Frontend to the compiler, allowing various ways to import code.
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
|
import textwrap
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from _npcomp.mlir import ir
|
from _npcomp.mlir import ir
|
||||||
|
@ -76,6 +77,7 @@ class ImportFrontend:
|
||||||
filename = inspect.getsourcefile(f)
|
filename = inspect.getsourcefile(f)
|
||||||
source_lines, start_lineno = inspect.getsourcelines(f)
|
source_lines, start_lineno = inspect.getsourcelines(f)
|
||||||
source = "".join(source_lines)
|
source = "".join(source_lines)
|
||||||
|
source = textwrap.dedent(source)
|
||||||
ast_root = ast.parse(source, filename=filename)
|
ast_root = ast.parse(source, filename=filename)
|
||||||
ast.increment_lineno(ast_root, start_lineno - 1)
|
ast.increment_lineno(ast_root, start_lineno - 1)
|
||||||
ast_fd = ast_root.body[0]
|
ast_fd = ast_root.body[0]
|
||||||
|
|
|
@ -46,10 +46,13 @@ class FunctionContext:
|
||||||
def check_partial_evaluated(self, result: PartialEvalResult):
|
def check_partial_evaluated(self, result: PartialEvalResult):
|
||||||
"""Checks that a PartialEvalResult has evaluated without error."""
|
"""Checks that a PartialEvalResult has evaluated without error."""
|
||||||
if result.type == PartialEvalType.ERROR:
|
if result.type == PartialEvalType.ERROR:
|
||||||
exc_info = result.yields
|
exc_type, exc_value, tb = result.yields
|
||||||
loc = self.current_loc
|
loc = self.current_loc
|
||||||
message = ("Error while evaluating value from environment:\n" +
|
if issubclass(exc_type, UserReportableError):
|
||||||
"".join(traceback.format_exception(*exc_info)))
|
message = exc_value.message
|
||||||
|
else:
|
||||||
|
message = ("Error while evaluating value from environment:\n" +
|
||||||
|
"".join(traceback.format_exception(exc_type, exc_value, tb)))
|
||||||
ir.emit_error(loc, message)
|
ir.emit_error(loc, message)
|
||||||
raise EmittedError(loc, message)
|
raise EmittedError(loc, message)
|
||||||
if result.type == PartialEvalType.NOT_EVALUATED:
|
if result.type == PartialEvalType.NOT_EVALUATED:
|
||||||
|
@ -480,15 +483,3 @@ class PartialEvalImporter(BaseNodeVisitor):
|
||||||
partial_eval_result = name_ref.load(self.fctx.environment)
|
partial_eval_result = name_ref.load(self.fctx.environment)
|
||||||
logging.debug("PARTIAL EVAL {} -> {}", name_ref, partial_eval_result)
|
logging.debug("PARTIAL EVAL {} -> {}", name_ref, partial_eval_result)
|
||||||
self.partial_eval_result = partial_eval_result
|
self.partial_eval_result = partial_eval_result
|
||||||
|
|
||||||
|
|
||||||
class EmittedError(Exception):
|
|
||||||
"""Exception subclass that indicates an error diagnostic has been emitted.
|
|
||||||
|
|
||||||
By throwing, this lets us abort and handle at a higher level so as not
|
|
||||||
to duplicate diagnostics.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, loc, message):
|
|
||||||
super().__init__("%s (at %r)" % (message, loc))
|
|
||||||
self.loc = loc
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ from .target import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Configuration",
|
"Configuration",
|
||||||
|
"EmittedError",
|
||||||
"Environment",
|
"Environment",
|
||||||
"NameReference",
|
"NameReference",
|
||||||
"NameResolver",
|
"NameResolver",
|
||||||
|
@ -20,12 +21,52 @@ __all__ = [
|
||||||
"PartialEvalType",
|
"PartialEvalType",
|
||||||
"PartialEvalResult",
|
"PartialEvalResult",
|
||||||
"LiveValueRef",
|
"LiveValueRef",
|
||||||
|
"UserReportableError",
|
||||||
"ValueCoder",
|
"ValueCoder",
|
||||||
"ValueCoderChain",
|
"ValueCoderChain",
|
||||||
]
|
]
|
||||||
|
|
||||||
_NotImplementedType = type(NotImplemented)
|
_NotImplementedType = type(NotImplemented)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Exceptions
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class EmittedError(Exception):
|
||||||
|
"""Exception subclass that indicates an error diagnostic has been emitted.
|
||||||
|
|
||||||
|
By throwing, this lets us abort and handle at a higher level so as not
|
||||||
|
to duplicate diagnostics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, loc, message):
|
||||||
|
super().__init__(loc, message)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loc(self):
|
||||||
|
return self.args[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message(self):
|
||||||
|
return self.args[1]
|
||||||
|
|
||||||
|
|
||||||
|
class UserReportableError(Exception):
|
||||||
|
"""Used to raise an error with a message that should be reported to the user.
|
||||||
|
|
||||||
|
Raising this error indicates that the error message is well formed and
|
||||||
|
makes sense without a traceback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message(self):
|
||||||
|
return self.args[0]
|
||||||
|
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Name resolution
|
# Name resolution
|
||||||
################################################################################
|
################################################################################
|
||||||
|
@ -177,8 +218,8 @@ class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def error_message(message: str) -> "PartialEvalResult":
|
def error_message(message: str) -> "PartialEvalResult":
|
||||||
try:
|
try:
|
||||||
raise RuntimeError(message)
|
raise UserReportableError(message)
|
||||||
except RuntimeError:
|
except UserReportableError:
|
||||||
return PartialEvalResult.error()
|
return PartialEvalResult.error()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
"""Various configuration helpers for testing."""
|
"""Various configuration helpers for testing."""
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
|
import functools
|
||||||
|
|
||||||
from . import logging
|
from . import logging
|
||||||
from .frontend import *
|
from .frontend import *
|
||||||
|
@ -20,13 +21,33 @@ def create_import_dump_decorator(*,
|
||||||
config = create_test_config(target_factory=target_factory)
|
config = create_test_config(target_factory=target_factory)
|
||||||
logging.debug("Testing with config: {}", config)
|
logging.debug("Testing with config: {}", config)
|
||||||
|
|
||||||
def decorator(f):
|
def do_import(f):
|
||||||
fe = ImportFrontend(config=config)
|
fe = ImportFrontend(config=config)
|
||||||
fe.import_global_function(f)
|
fe.import_global_function(f)
|
||||||
print("// -----")
|
print("// -----")
|
||||||
print(fe.ir_module.to_asm())
|
print(fe.ir_module.to_asm())
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
def decorator(*args, expect_error=None):
|
||||||
|
if len(args) == 0:
|
||||||
|
# Higher order decorator.
|
||||||
|
return functools.partial(decorator, expect_error=expect_error)
|
||||||
|
|
||||||
|
assert len(args) == 1
|
||||||
|
try:
|
||||||
|
return do_import(f=args[0])
|
||||||
|
except EmittedError as e:
|
||||||
|
if expect_error and e.message == expect_error:
|
||||||
|
print("// EXPECTED_ERROR:", repr(e.message))
|
||||||
|
pass
|
||||||
|
elif expect_error:
|
||||||
|
print("// MISMATCHED_ERROR:", repr(e.message))
|
||||||
|
raise AssertionError("Expected error '{}' but got '{}'".format(
|
||||||
|
expect_error, e.message))
|
||||||
|
else:
|
||||||
|
print("// UNEXPECTED_ERROR:", repr(e.message))
|
||||||
|
raise e
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue