LTC->MLIR Debug Info support (#1922)

* LTC->MLIR Debug Info support

* SW-95317 Propagate Lazy->Jit->MLIR scope name.

* Enhance location information based on op names

Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.

* Update locations logic; updated debug-info.py test

* Use {scope}/{op_name} format to track names by default

---------

Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>
pull/2369/head
Gleb Kazantaev 2023-08-02 10:29:11 -04:00 committed by GitHub
parent 4c24472dea
commit fb52a73cbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 128 additions and 20 deletions

View File

@ -15,6 +15,7 @@
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/config.h>
#include "torch-mlir-c/Registration.h"
#include "torch-mlir-c/Transforms.h"
#include "mlir-c/IR.h"
@ -205,13 +206,46 @@ void TorchMlirLoweringContext::AssignOutputOp(
const Output& output, torch::jit::Value* op) {
PRINT_FUNCTION();
// TODO (antoniojkim): Do we need this?
// auto torch_mlir_node =
// NodeCast<TorchMlirNode>(output.node, output.node->op());
// if (!torch_mlir_node->getPythonStacktrace().empty()) {
// op->node()->s_(
// c10::Symbol::attr("source"), torch_mlir_node->getPythonStacktrace());
// }
auto torch_mlir_node =
NodeCast<TorchMlirNode>(output.node, output.node->op());
std::vector<std::string> source_files, functions;
std::vector<int64_t> line_numbers;
const auto& metadata = torch_mlir_node->metadata();
const auto& frames = metadata.frame_info;
if (!frames.empty()) {
static std::vector<std::string> g_roots =
string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":");
std::for_each(frames.rbegin(), frames.rend(),
[&](const torch::lazy::SourceLocation& location) {
functions.push_back(location.function);
line_numbers.push_back(location.line);
std::string file_name = location.file;
for (const std::string& root : g_roots) {
if (startswith(file_name, root)) {
// location.file starts with root, strip it off
file_name = file_name.substr(root.size());
break;
}
}
source_files.push_back(file_name);
});
if (!source_files.empty()) {
op->node()->ss_(
c10::Symbol::attr("source_files"), source_files);
op->node()->ss_(
c10::Symbol::attr("functions"), functions);
op->node()->is_(
c10::Symbol::attr("line_numbers"), line_numbers);
}
}
auto scope = ::c10::Symbol::scope(metadata.scope);
op->node()->setScope(
c10::make_intrusive<torch::jit::Scope>()->push(scope));
emitted_outputs_[output] = std::move(op);
}
@ -424,7 +458,11 @@ const std::string TorchMlirComputation::to_string() const {
*ss_ptr << std::string(part.data, part.length);
};
std::stringstream ss;
mlirOperationPrint(mlirModuleGetOperation(module_op_), print_callback, &ss);
// Setup flags for MLIR serialization.
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
mlirOpPrintingFlagsEnableDebugInfo(flags, FLAGS_torch_lazy_ir_debug, false);
mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags, print_callback, &ss);
return ss.str();
}

View File

@ -22,6 +22,24 @@ std::string string_join(const std::vector<T>& v, const std::string& delimiter) {
return joined.str();
}
inline std::vector<std::string> string_split(
const std::string& str,
const std::string& sep
) {
std::vector<std::string> tokens;
std::size_t pos1 = str.find_first_not_of(sep);
while (pos1 != std::string::npos) {
std::size_t pos2 = str.find_first_of(sep, pos1);
if (pos2 == std::string::npos) {
tokens.push_back(str.substr(pos1));
pos1 = pos2;
} else {
tokens.push_back(str.substr(pos1, pos2 - pos1));
pos1 = str.find_first_not_of(sep, pos2 + 1);
}
}
return tokens;
}
/*
* Returns true if str starts with prefix

View File

@ -14,6 +14,14 @@ static T GetEnv(const std::string& name, const T& default_value = T(0)) {
return T(std::atoi(env));
}
static std::string GetEnvString(const std::string& name, const std::string& default_value) {
const char* env = std::getenv(name.c_str());
if (!env) {
return default_value;
}
return std::string(env);
}
static bool GetEnvBool(const char* name, bool defval) {
const char* env = std::getenv(name);
if (env == nullptr) {

View File

@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//
#include "torch/csrc/jit/python/pybind.h"
#include "torch/csrc/lazy/core/config.h"
#include "torch/csrc/lazy/backend/backend_interface.h"
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
@ -25,6 +26,7 @@ namespace py = pybind11;
namespace {
bool verbose = sys_util::GetEnv("VERBOSE", false);
bool ir_debug = sys_util::GetEnv("LTC_IR_DEBUG", false);
struct NoGilSection {
NoGilSection() : state(PyEval_SaveThread()) {}
@ -52,6 +54,11 @@ void Initialize() {
if (verbose) {
std::cout << "MLIR LTC PyTorch Plugin Initialized." << std::endl;
}
if (ir_debug) {
FLAGS_torch_lazy_ir_debug = true;
std::cout << "Enabled lazy tensor IR debugging." << std::endl;
}
}
/**

View File

@ -11,6 +11,8 @@
#include "function_importer.h"
#include "ivalue_importer.h"
#include <c10/util/irange.h>
#include <ATen/TensorUtils.h>
#include <unordered_map>
@ -407,15 +409,53 @@ MlirAttribute torch_mlir::importAttribute(MlirLocation loc,
MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
torch::jit::Node *node) {
auto flc = node->sourceRange().file_line_col();
if (flc) {
MlirLocation loc = mlirLocationUnknownGet(context);
if (node->hasAttribute(c10::Symbol::attr("source_files"))) {
const auto &sourceFiles = node->ss(c10::Symbol::attr("source_files"));
const auto &lineNumbers = node->is(c10::Symbol::attr("line_numbers"));
const auto &functions = node->ss(c10::Symbol::attr("functions"));
// Chain a sequence of calls to construct single MlirLocation.
for (const auto i : c10::irange(sourceFiles.size())) {
MlirLocation newLoc = mlirLocationNameGet(
context, toMlirStringRef(functions[i]),
mlirLocationFileLineColGet(context, toMlirStringRef(sourceFiles[i]),
lineNumbers[i],
0 /* column is not available */
));
loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc));
}
if (sourceFiles.size() == 1) {
// Somehow a callstack depth of 1...
// Disambiguate function name from scope name below.
loc = mlirLocationCallSiteGet(loc, mlirLocationUnknownGet(context));
}
} else if (auto flc = node->sourceRange().file_line_col()) {
const std::string &file = std::get<0>(*flc);
int line = std::get<1>(*flc);
int col = std::get<2>(*flc);
return mlirLocationFileLineColGet(context, toMlirStringRef(file), line,
col);
loc = mlirLocationFileLineColGet(context, toMlirStringRef(file), line, col);
}
return mlirLocationUnknownGet(context);
std::string locationName;
auto scopeName = node->scopeName();
if (!scopeName.empty()) {
locationName = scopeName;
}
if (const c10::FunctionSchema *schema = node->maybeSchema()) {
if (!locationName.empty()) {
locationName += "/";
}
locationName += schema->operator_name().name;
}
if (!locationName.empty()) {
loc = mlirLocationNameGet(context, toMlirStringRef(locationName), loc);
}
return loc;
}
std::vector<MlirType>

View File

@ -17,14 +17,11 @@ mb = ModuleBuilder()
@mb.import_function
@torch.jit.script
def add3(t0, t1, t2):
# TODO: Checks for debug info are quite hard with the new trailing debug
# attribute print. See if this can be improved.
# CHECK: loc({{.*}}debug-info.py":[[# @LINE + 1]]
# CHECK-DAG: torch.aten.add.Tensor {{.*}} loc("aten::add"({{.*}}debug-info.py":[[# @LINE + 1]]
intermediate = t0 + t1
# CHECK: loc({{.*}}debug-info.py":[[# @LINE + 1]]
final = intermediate + t2
return final
# CHECK-DAG: torch.aten.mul.Tensor {{.*}} loc("aten::mul"({{.*}}debug-info.py":[[# @LINE + 1]]
return intermediate * t2
# Verify again with debug info present. Just checking that it makes it in there.
mb.module.operation.print(enable_debug_info=True)
mb.module.operation.print(enable_debug_info=True, use_local_scope=True)
print()