From fb52a73cbe001b32073861b484f3561e0f793748 Mon Sep 17 00:00:00 2001
From: Gleb Kazantaev
Date: Wed, 2 Aug 2023 10:29:11 -0400
Subject: [PATCH] LTC->MLIR Debug Info support (#1922)

* LTC->MLIR Debug Info support

* SW-95317 Propagate Lazy->Jit->MLIR scope name.

* Enhance location information based on op names

Currently, the location information attached to the ops only records the
filename, line number, and column number. Attaching the operation name as
well helps identify the type of computation just by looking at an execution
profile.

* Update locations logic; update the debug-info.py test

* Use {scope}/{op_name} format to track names by default

---------

Co-authored-by: Gleb Kazantaev
Co-authored-by: Mark Browning
Co-authored-by: Vimal Patel
---
 .../mlir_lowering_context.cpp                 | 54 ++++++++++++++++---
 .../base_lazy_backend/utils/string_utils.h    | 18 +++++++
 .../csrc/base_lazy_backend/utils/sys_utils.h  |  8 +++
 .../reference_lazy_backend_pybind.cpp         |  7 +++
 .../jit_ir/csrc/torch_to_mlir_utils.cpp       | 50 +++++++++++++++--
 .../importer/jit_ir/node_import/debug-info.py | 11 ++--
 6 files changed, 128 insertions(+), 20 deletions(-)

diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
index 0182952f8..4823b4929 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
@@ -15,6 +15,7 @@
 #include
 #include
 #include
+#include
 #include "torch-mlir-c/Registration.h"
 #include "torch-mlir-c/Transforms.h"
 #include "mlir-c/IR.h"
@@ -205,13 +206,46 @@ void TorchMlirLoweringContext::AssignOutputOp(
     const Output& output, torch::jit::Value* op) {
   PRINT_FUNCTION();

-  // TODO (antoniojkim): Do we need this?
-  // auto torch_mlir_node =
-  //     NodeCast<TorchMlirNode>(output.node, output.node->op());
-  // if (!torch_mlir_node->getPythonStacktrace().empty()) {
-  //   op->node()->s_(
-  //       c10::Symbol::attr("source"), torch_mlir_node->getPythonStacktrace());
-  // }
+  auto torch_mlir_node =
+      NodeCast<TorchMlirNode>(output.node, output.node->op());
+
+  std::vector<std::string> source_files, functions;
+  std::vector<int64_t> line_numbers;
+  const auto& metadata = torch_mlir_node->metadata();
+  const auto& frames = metadata.frame_info;
+  if (!frames.empty()) {
+    static std::vector<std::string> g_roots =
+        string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":");
+
+    std::for_each(frames.rbegin(), frames.rend(),
+                  [&](const torch::lazy::SourceLocation& location) {
+                    functions.push_back(location.function);
+                    line_numbers.push_back(location.line);
+
+                    std::string file_name = location.file;
+                    for (const std::string& root : g_roots) {
+                      if (startswith(file_name, root)) {
+                        // location.file starts with root, strip it off
+                        file_name = file_name.substr(root.size());
+                        break;
+                      }
+                    }
+                    source_files.push_back(file_name);
+                  });
+
+    if (!source_files.empty()) {
+      op->node()->ss_(
+          c10::Symbol::attr("source_files"), source_files);
+      op->node()->ss_(
+          c10::Symbol::attr("functions"), functions);
+      op->node()->is_(
+          c10::Symbol::attr("line_numbers"), line_numbers);
+    }
+  }
+  auto scope = ::c10::Symbol::scope(metadata.scope);
+  op->node()->setScope(
+      c10::make_intrusive<torch::jit::Scope>()->push(scope));
+
   emitted_outputs_[output] = std::move(op);
 }

@@ -424,7 +458,11 @@ const std::string TorchMlirComputation::to_string() const {
     *ss_ptr << std::string(part.data, part.length);
   };
   std::stringstream ss;
-  mlirOperationPrint(mlirModuleGetOperation(module_op_), print_callback, &ss);
+
+  // Setup flags for MLIR serialization.
+  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+  mlirOpPrintingFlagsEnableDebugInfo(flags, FLAGS_torch_lazy_ir_debug, false);
+  mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags, print_callback, &ss);
   return ss.str();
 }
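For illustration only (not part of the patch): a minimal standalone sketch of the root-path stripping that the new AssignOutputOp code applies to LTC frame file names before storing them in the "source_files" attribute. The two local helpers only approximate string_split/startswith from string_utils.h so the example compiles on its own, and the root list and sample path are made-up values.

    #include <iostream>
    #include <string>
    #include <vector>

    // Local stand-in for string_split(): split on a single separator character.
    static std::vector<std::string> split(const std::string& s, char sep) {
      std::vector<std::string> out;
      std::size_t begin = s.find_first_not_of(sep);
      while (begin != std::string::npos) {
        std::size_t end = s.find_first_of(sep, begin);
        out.push_back(s.substr(begin, end - begin));
        begin = (end == std::string::npos) ? end : s.find_first_not_of(sep, end + 1);
      }
      return out;
    }

    // Local stand-in for startswith().
    static bool has_prefix(const std::string& s, const std::string& prefix) {
      return s.rfind(prefix, 0) == 0;
    }

    int main() {
      // Comparable to LTC_IR_DEBUG_ROOT_PATH="/src/pytorch/:/src/torch-mlir/".
      std::vector<std::string> roots = split("/src/pytorch/:/src/torch-mlir/", ':');

      // A frame file name as it might appear in torch::lazy::SourceLocation::file.
      std::string file = "/src/pytorch/torch/nn/modules/linear.py";
      for (const std::string& root : roots) {
        if (has_prefix(file, root)) {
          file = file.substr(root.size());  // strip the configured root prefix
          break;
        }
      }
      std::cout << file << std::endl;  // prints "torch/nn/modules/linear.py"
      return 0;
    }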
diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h
index c4c2ea79d..281331992 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h
+++ b/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h
@@ -22,6 +22,24 @@ std::string string_join(const std::vector<T>& v, const std::string& delimiter)
   return joined.str();
 }

+inline std::vector<std::string> string_split(
+    const std::string& str,
+    const std::string& sep
+) {
+  std::vector<std::string> tokens;
+  std::size_t pos1 = str.find_first_not_of(sep);
+  while (pos1 != std::string::npos) {
+    std::size_t pos2 = str.find_first_of(sep, pos1);
+    if (pos2 == std::string::npos) {
+      tokens.push_back(str.substr(pos1));
+      pos1 = pos2;
+    } else {
+      tokens.push_back(str.substr(pos1, pos2 - pos1));
+      pos1 = str.find_first_not_of(sep, pos2 + 1);
+    }
+  }
+  return tokens;
+}

 /*
  * Returns true if str starts with prefix
diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h
index 6cb47895a..5ae149049 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h
+++ b/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h
@@ -14,6 +14,14 @@ static T GetEnv(const std::string& name, const T& default_value = T(0)) {
   return T(std::atoi(env));
 }

+static std::string GetEnvString(const std::string& name, const std::string& default_value) {
+  const char* env = std::getenv(name.c_str());
+  if (!env) {
+    return default_value;
+  }
+  return std::string(env);
+}
+
 static bool GetEnvBool(const char* name, bool defval) {
   const char* env = std::getenv(name);
   if (env == nullptr) {
diff --git a/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp b/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp
index b2ff81c67..c575d9dd2 100644
--- a/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp
+++ b/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//

 #include "torch/csrc/jit/python/pybind.h"
+#include "torch/csrc/lazy/core/config.h"
 #include "torch/csrc/lazy/backend/backend_interface.h"

 #include
@@ -25,6 +26,7 @@ namespace py = pybind11;
 namespace {

 bool verbose = sys_util::GetEnv("VERBOSE", false);
+bool ir_debug = sys_util::GetEnv("LTC_IR_DEBUG", false);

 struct NoGilSection {
   NoGilSection() : state(PyEval_SaveThread()) {}
@@ -52,6 +54,11 @@ void Initialize() {
   if (verbose) {
     std::cout << "MLIR LTC PyTorch Plugin Initialized." << std::endl;
   }
+
+  if (ir_debug) {
+    FLAGS_torch_lazy_ir_debug = true;
+    std::cout << "Enabled lazy tensor IR debugging." << std::endl;
+  }
 }

 /**
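Again purely illustrative (not part of the patch): a small sketch of how the two environment variables introduced here are read, assuming the string_utils.h and sys_utils.h headers above are reachable through a "utils/..." include path; that include path and the sample values in the comments are assumptions for the example, not something this patch defines.

    #include <iostream>
    #include <string>
    #include <vector>

    #include "utils/string_utils.h"  // string_split()
    #include "utils/sys_utils.h"     // sys_util::GetEnv(), sys_util::GetEnvString()

    int main() {
      // e.g. LTC_IR_DEBUG=1 is what makes Initialize() set FLAGS_torch_lazy_ir_debug.
      bool ir_debug = sys_util::GetEnv("LTC_IR_DEBUG", false);

      // e.g. LTC_IR_DEBUG_ROOT_PATH="/src/pytorch/:/src/torch-mlir/" yields a
      // colon-separated list of root prefixes to strip from frame file names.
      std::vector<std::string> roots =
          string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":");

      std::cout << "ir_debug=" << ir_debug << std::endl;
      for (const std::string& root : roots) {
        std::cout << "debug root: " << root << std::endl;
      }
      return 0;
    }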
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
index e0420022d..4296abdc8 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
@@ -11,6 +11,8 @@
 #include "function_importer.h"
 #include "ivalue_importer.h"

+#include
+
 #include
 #include

@@ -407,15 +409,53 @@ MlirAttribute torch_mlir::importAttribute(MlirLocation loc,

 MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
                                                  torch::jit::Node *node) {
-  auto flc = node->sourceRange().file_line_col();
-  if (flc) {
+  MlirLocation loc = mlirLocationUnknownGet(context);
+
+  if (node->hasAttribute(c10::Symbol::attr("source_files"))) {
+    const auto &sourceFiles = node->ss(c10::Symbol::attr("source_files"));
+    const auto &lineNumbers = node->is(c10::Symbol::attr("line_numbers"));
+    const auto &functions = node->ss(c10::Symbol::attr("functions"));
+
+    // Chain a sequence of calls to construct single MlirLocation.
+    for (const auto i : c10::irange(sourceFiles.size())) {
+      MlirLocation newLoc = mlirLocationNameGet(
+          context, toMlirStringRef(functions[i]),
+          mlirLocationFileLineColGet(context, toMlirStringRef(sourceFiles[i]),
+                                     lineNumbers[i],
+                                     0 /* column is not available */
+                                     ));
+      loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc));
+    }
+    if (sourceFiles.size() == 1) {
+      // Somehow a callstack depth of 1...
+      // Disambiguate function name from scope name below.
+      loc = mlirLocationCallSiteGet(loc, mlirLocationUnknownGet(context));
+    }
+  } else if (auto flc = node->sourceRange().file_line_col()) {
     const std::string &file = std::get<0>(*flc);
     int line = std::get<1>(*flc);
     int col = std::get<2>(*flc);
-    return mlirLocationFileLineColGet(context, toMlirStringRef(file), line,
-                                      col);
+    loc = mlirLocationFileLineColGet(context, toMlirStringRef(file), line, col);
   }
-  return mlirLocationUnknownGet(context);
+
+  std::string locationName;
+  auto scopeName = node->scopeName();
+  if (!scopeName.empty()) {
+    locationName = scopeName;
+  }
+
+  if (const c10::FunctionSchema *schema = node->maybeSchema()) {
+    if (!locationName.empty()) {
+      locationName += "/";
+    }
+    locationName += schema->operator_name().name;
+  }
+
+  if (!locationName.empty()) {
+    loc = mlirLocationNameGet(context, toMlirStringRef(locationName), loc);
+  }
+
+  return loc;
 }

 std::vector
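For illustration only (not part of the patch): a sketch of the nested location shape that getMlirLocationFromNode now builds, written directly against the MLIR C API. The file names, line numbers, and the "model/aten::add" scope/op name are invented for the example; column 0 mirrors the fact that LTC frames carry no column information.

    #include <stdio.h>

    #include "mlir-c/IR.h"
    #include "mlir-c/Support.h"

    // Stream chunks of the printed location to stdout.
    static void printToStdout(MlirStringRef str, void *userData) {
      (void)userData;
      fwrite(str.data, 1, str.length, stdout);
    }

    int main() {
      MlirContext ctx = mlirContextCreate();

      // Innermost LTC frame: a function name wrapping a file:line location.
      MlirLocation callee = mlirLocationNameGet(
          ctx, mlirStringRefCreateFromCString("forward"),
          mlirLocationFileLineColGet(
              ctx, mlirStringRefCreateFromCString("model.py"), 42, 0));

      // An outer frame, chained with a callsite location as in the loop above.
      MlirLocation caller = mlirLocationNameGet(
          ctx, mlirStringRefCreateFromCString("main"),
          mlirLocationFileLineColGet(
              ctx, mlirStringRefCreateFromCString("train.py"), 7, 0));
      MlirLocation chained = mlirLocationCallSiteGet(callee, caller);

      // Finally the "{scope}/{op_name}" name location is layered on top.
      MlirLocation named = mlirLocationNameGet(
          ctx, mlirStringRefCreateFromCString("model/aten::add"), chained);

      mlirLocationPrint(named, printToStdout, /*userData=*/NULL);
      printf("\n");

      mlirContextDestroy(ctx);
      return 0;
    }

The CHECK-DAG patterns in the updated debug-info.py test below match the outermost name location of exactly this form, e.g. loc("aten::add"(...)).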
diff --git a/test/python/importer/jit_ir/node_import/debug-info.py b/test/python/importer/jit_ir/node_import/debug-info.py
index b6543ed61..f7b441a12 100644
--- a/test/python/importer/jit_ir/node_import/debug-info.py
+++ b/test/python/importer/jit_ir/node_import/debug-info.py
@@ -17,14 +17,11 @@ mb = ModuleBuilder()
 @mb.import_function
 @torch.jit.script
 def add3(t0, t1, t2):
-    # TODO: Checks for debug info are quite hard with the new trailing debug
-    # attribute print. See if this can be improved.
-    # CHECK: loc({{.*}}debug-info.py":[[# @LINE + 1]]
+    # CHECK-DAG: torch.aten.add.Tensor {{.*}} loc("aten::add"({{.*}}debug-info.py":[[# @LINE + 1]]
     intermediate = t0 + t1
-    # CHECK: loc({{.*}}debug-info.py":[[# @LINE + 1]]
-    final = intermediate + t2
-    return final
+    # CHECK-DAG: torch.aten.mul.Tensor {{.*}} loc("aten::mul"({{.*}}debug-info.py":[[# @LINE + 1]]
+    return intermediate * t2

 # Verify again with debug info present. Just checking that it makes it in there.
-mb.module.operation.print(enable_debug_info=True)
+mb.module.operation.print(enable_debug_info=True, use_local_scope=True)
 print()