mirror of https://github.com/llvm/torch-mlir
Additional information in error message (#2783)
See change in test for what the new message looks like.pull/2831/head
parent
e18fcebd3a
commit
1e882f5803
|
@ -9,6 +9,7 @@
|
|||
|
||||
#include "class_annotator.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
@ -150,11 +151,26 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
|
|||
}
|
||||
|
||||
static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
|
||||
std::vector<ArgAnnotation> argAnnotations,
|
||||
const std::vector<ArgAnnotation> &argAnnotations,
|
||||
torch::jit::Function *function) {
|
||||
if (argAnnotations.size() != function->num_inputs()) {
|
||||
throw std::invalid_argument("Arg annotations should have one entry per "
|
||||
"function parameter (including self).");
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << "There must be one argument annotation per function parameter. "
|
||||
<< "Including 'self' the number of argument annotations is: "
|
||||
<< argAnnotations.size()
|
||||
<< ". The number of function parameters is: " << function->num_inputs()
|
||||
<< ". ";
|
||||
const auto &args = function->getSchema().arguments();
|
||||
if (args.size() > 0) {
|
||||
oss << "The function signature is (";
|
||||
oss << args[0];
|
||||
for (auto iter = args.begin() + 1; iter != args.end(); iter++) {
|
||||
oss << ", " << *iter;
|
||||
}
|
||||
oss << ')' << '.';
|
||||
}
|
||||
throw std::invalid_argument(oss.str());
|
||||
}
|
||||
if (!methodAnnotation.argAnnotations.has_value()) {
|
||||
methodAnnotation.argAnnotations.emplace(function->num_inputs(),
|
||||
|
|
|
@ -33,7 +33,10 @@ except Exception as e:
|
|||
try:
|
||||
annotator.annotateArgs(class_type, ['forward'], [None])
|
||||
except Exception as e:
|
||||
# CHECK: Arg annotations should have one entry per function parameter (including self).
|
||||
# CHECK: There must be one argument annotation per function parameter.
|
||||
# CHECK-SAME: Including 'self' the number of argument annotations is: 1.
|
||||
# CHECK-SAME: The number of function parameters is: 2.
|
||||
# CHECK-SAME: The function signature is (__torch__.TestModule self, Tensor tensor)
|
||||
print(e)
|
||||
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue