Additional information in error message (#2783)

See the change in the test for what the new message looks like.
pull/2831/head
James Newling 2024-01-30 08:28:08 -08:00 committed by GitHub
parent e18fcebd3a
commit 1e882f5803
2 changed files with 23 additions and 4 deletions
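
For reference, assembling the pieces from the C++ change and the CHECK lines in the updated test below, the full error text for the test case (one annotation supplied for the two-parameter forward(self, tensor)) should read roughly:

    There must be one argument annotation per function parameter. Including 'self' the number of argument annotations is: 1. The number of function parameters is: 2. The function signature is (__torch__.TestModule self, Tensor tensor).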

@@ -9,6 +9,7 @@
 #include "class_annotator.h"
 
+#include <sstream>
 #include <stdexcept>
 
 using namespace torch_mlir;
@@ -150,11 +151,26 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
 }
 
 static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
-                               std::vector<ArgAnnotation> argAnnotations,
+                               const std::vector<ArgAnnotation> &argAnnotations,
                                torch::jit::Function *function) {
   if (argAnnotations.size() != function->num_inputs()) {
-    throw std::invalid_argument("Arg annotations should have one entry per "
-                                "function parameter (including self).");
+    std::ostringstream oss;
+    oss << "There must be one argument annotation per function parameter. "
+        << "Including 'self' the number of argument annotations is: "
+        << argAnnotations.size()
+        << ". The number of function parameters is: " << function->num_inputs()
+        << ". ";
+    const auto &args = function->getSchema().arguments();
+    if (args.size() > 0) {
+      oss << "The function signature is (";
+      oss << args[0];
+      for (auto iter = args.begin() + 1; iter != args.end(); iter++) {
+        oss << ", " << *iter;
+      }
+      oss << ')' << '.';
+    }
+
+    throw std::invalid_argument(oss.str());
   }
   if (!methodAnnotation.argAnnotations.has_value()) {
     methodAnnotation.argAnnotations.emplace(function->num_inputs(),
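
The pattern above streams the two counts and then joins the schema arguments with ", " inside parentheses. The snippet below is a minimal, self-contained sketch of that same pattern, not the actual importer code: plain std::string values stand in for the real torch::jit schema arguments, and the helper name buildMismatchMessage is purely illustrative.

// Minimal, self-contained sketch of the message-building pattern above.
// Plain std::string stands in for the real torch::jit schema arguments;
// buildMismatchMessage is a hypothetical helper, not part of the change.
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

static std::string buildMismatchMessage(size_t numAnnotations,
                                        const std::vector<std::string> &params) {
  std::ostringstream oss;
  oss << "There must be one argument annotation per function parameter. "
      << "Including 'self' the number of argument annotations is: "
      << numAnnotations
      << ". The number of function parameters is: " << params.size() << ". ";
  if (!params.empty()) {
    // Join the parameter descriptions as "(first, second, ...)".
    oss << "The function signature is (";
    oss << params[0];
    for (auto iter = params.begin() + 1; iter != params.end(); iter++) {
      oss << ", " << *iter;
    }
    oss << ").";
  }
  return oss.str();
}

int main() {
  // One annotation supplied for a two-parameter forward(self, tensor).
  try {
    throw std::invalid_argument(buildMismatchMessage(
        1, {"__torch__.TestModule self", "Tensor tensor"}));
  } catch (const std::exception &e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}

Run as-is, the sketch prints the same message text that the updated test below checks for.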

@@ -33,7 +33,10 @@ except Exception as e:
 try:
     annotator.annotateArgs(class_type, ['forward'], [None])
 except Exception as e:
-    # CHECK: Arg annotations should have one entry per function parameter (including self).
+    # CHECK: There must be one argument annotation per function parameter.
+    # CHECK-SAME: Including 'self' the number of argument annotations is: 1.
+    # CHECK-SAME: The number of function parameters is: 2.
+    # CHECK-SAME: The function signature is (__torch__.TestModule self, Tensor tensor)
     print(e)
 
 try: