mirror of https://github.com/llvm/torch-mlir
Additional information in error message (#2783)
See change in test for what the new message looks like.pull/2831/head
parent
e18fcebd3a
commit
1e882f5803
|
@ -9,6 +9,7 @@
|
||||||
|
|
||||||
#include "class_annotator.h"
|
#include "class_annotator.h"
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
@ -150,11 +151,26 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
|
static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
|
||||||
std::vector<ArgAnnotation> argAnnotations,
|
const std::vector<ArgAnnotation> &argAnnotations,
|
||||||
torch::jit::Function *function) {
|
torch::jit::Function *function) {
|
||||||
if (argAnnotations.size() != function->num_inputs()) {
|
if (argAnnotations.size() != function->num_inputs()) {
|
||||||
throw std::invalid_argument("Arg annotations should have one entry per "
|
|
||||||
"function parameter (including self).");
|
std::ostringstream oss;
|
||||||
|
oss << "There must be one argument annotation per function parameter. "
|
||||||
|
<< "Including 'self' the number of argument annotations is: "
|
||||||
|
<< argAnnotations.size()
|
||||||
|
<< ". The number of function parameters is: " << function->num_inputs()
|
||||||
|
<< ". ";
|
||||||
|
const auto &args = function->getSchema().arguments();
|
||||||
|
if (args.size() > 0) {
|
||||||
|
oss << "The function signature is (";
|
||||||
|
oss << args[0];
|
||||||
|
for (auto iter = args.begin() + 1; iter != args.end(); iter++) {
|
||||||
|
oss << ", " << *iter;
|
||||||
|
}
|
||||||
|
oss << ')' << '.';
|
||||||
|
}
|
||||||
|
throw std::invalid_argument(oss.str());
|
||||||
}
|
}
|
||||||
if (!methodAnnotation.argAnnotations.has_value()) {
|
if (!methodAnnotation.argAnnotations.has_value()) {
|
||||||
methodAnnotation.argAnnotations.emplace(function->num_inputs(),
|
methodAnnotation.argAnnotations.emplace(function->num_inputs(),
|
||||||
|
|
|
@ -33,7 +33,10 @@ except Exception as e:
|
||||||
try:
|
try:
|
||||||
annotator.annotateArgs(class_type, ['forward'], [None])
|
annotator.annotateArgs(class_type, ['forward'], [None])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# CHECK: Arg annotations should have one entry per function parameter (including self).
|
# CHECK: There must be one argument annotation per function parameter.
|
||||||
|
# CHECK-SAME: Including 'self' the number of argument annotations is: 1.
|
||||||
|
# CHECK-SAME: The number of function parameters is: 2.
|
||||||
|
# CHECK-SAME: The function signature is (__torch__.TestModule self, Tensor tensor)
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Reference in New Issue