diff --git a/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp index b144e946b..47f7a974c 100644 --- a/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp +++ b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp @@ -9,6 +9,7 @@ #include "class_annotator.h" +#include #include using namespace torch_mlir; @@ -150,11 +151,26 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) { } static void fillArgAnnotations(MethodAnnotation &methodAnnotation, - std::vector argAnnotations, + const std::vector &argAnnotations, torch::jit::Function *function) { if (argAnnotations.size() != function->num_inputs()) { - throw std::invalid_argument("Arg annotations should have one entry per " - "function parameter (including self)."); + + std::ostringstream oss; + oss << "There must be one argument annotation per function parameter. " + << "Including 'self' the number of argument annotations is: " + << argAnnotations.size() + << ". The number of function parameters is: " << function->num_inputs() + << ". "; + const auto &args = function->getSchema().arguments(); + if (args.size() > 0) { + oss << "The function signature is ("; + oss << args[0]; + for (auto iter = args.begin() + 1; iter != args.end(); iter++) { + oss << ", " << *iter; + } + oss << ')' << '.'; + } + throw std::invalid_argument(oss.str()); } if (!methodAnnotation.argAnnotations.has_value()) { methodAnnotation.argAnnotations.emplace(function->num_inputs(), diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py index 26eaa5bd0..0979d0422 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py @@ -33,7 +33,10 @@ except Exception as e: try: annotator.annotateArgs(class_type, ['forward'], [None]) except Exception as e: - # CHECK: Arg annotations should have one entry per function parameter (including self). + # CHECK: There must be one argument annotation per function parameter. + # CHECK-SAME: Including 'self' the number of argument annotations is: 1. + # CHECK-SAME: The number of function parameters is: 2. + # CHECK-SAME: The function signature is (__torch__.TestModule self, Tensor tensor) print(e) try: