//===- ATenOpReport.cpp -----------------------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/Dialect/ATen/ATenOpReport.h" #include "npcomp/Dialect/ATen/ATenDialect.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Pass/Pass.h" #include #include #define DEBUG_TYPE "aten-op-stats" using namespace mlir; namespace { std::string getAsString(std::map &m, std::string &e) { return m.count(e) ? std::to_string(m[e]) : " "; } /// Query operations through the StatisticsOpInterface and print the result /// in a human-readable way. This replicates the functionality in various /// network analysis tools and is a stepping stone toward using the information /// as an analysis to drive optimization. 
struct ATenOpReportPass : public PassWrapper> {
private:
  // Where the report goes. When null, runOnOperation() emits the report as a
  // warning on the "graph" function instead. The pass only stores and
  // dereferences this pointer and never frees it, so the pointee must stay
  // valid until the pass has run.
  std::string *output;
  // Substrings matched against each op's statistic names. Each entry becomes
  // one JSON field holding the sum of all statistics whose name contains it.
  std::vector tableFields;
  // Layer name for every op visited under "graph". It is rebuilt on each
  // runOnOperation() call and read by emitJSONReport().
  std::map opToName;

public:
  /// Creates a pass that reports via a warning on the "graph" function.
  ATenOpReportPass()
      : output(nullptr),
        tableFields({"reads", "writes", "activation_in", "activation_out",
                     "parameters_in", "ops:MAC", "ops:==", "ops:>", "ops:*",
                     "ops:+", "ops:/", "ops:sqrt", "ops:-", "grad"}) {}

  /// Creates a pass that stores the report in `*output` instead of emitting
  /// a warning. A null `output` behaves like the default constructor.
  // NOTE(review): this field list duplicates the default constructor's. A
  // shared initializer would keep the two lists from drifting apart.
  ATenOpReportPass(std::string *output)
      : output(output),
        tableFields({"reads", "writes", "activation_in", "activation_out",
                     "parameters_in", "ops:MAC", "ops:==", "ops:>", "ops:*",
                     "ops:+", "ops:/", "ops:sqrt", "ops:-", "grad"}) {}

  /// Builds the report as pretty-printed JSON with a 2-space indent and a
  /// trailing newline. The top-level object has one entry per op under
  /// "graph" that passes the `mlir::dyn_cast` check below. That check is
  /// presumably the StatisticsOpInterface named in the pass description —
  /// confirm. Each entry is keyed by the op's name in opToName, and its
  /// fields are the per-`tableFields` sums.
  ///
  /// This method does not itself check that "graph" exists or that opToName
  /// is populated; it relies on runOnOperation() having done both.
  std::string emitJSONReport() {
    llvm::json::Object top;

    auto graph = getOperation().lookupSymbol("graph");
    graph.walk([&](Operation *op) {
      if (auto stats = mlir::dyn_cast(op)) {
        // Name for this layer, as assigned by runOnOperation().
        std::string layerName = opToName[op];
        // Raw stats for this layer, keyed by statistic name.
        std::map layerStatsMap = stats.getStatistics();
        // JSON version of the stats being built for this layer.
        llvm::json::Object layerStatsJSON;
        // For each field f in tableFields, sum every entry of layerStatsMap
        // whose name contains f as a substring. Fields with no matching
        // entry are left out of this layer's object.
        for (auto &f : tableFields) {
          for (auto &p : layerStatsMap) {
            if (p.first.find(f) != std::string::npos) {
              // The first access through operator[] inserts a null value.
              // getAsInteger() reports that as absent, so the else branch
              // starts the sum.
              if (auto count = layerStatsJSON[f].getAsInteger())
                layerStatsJSON[f] = (int64_t)p.second + *count;
              else
                layerStatsJSON[f] = (int64_t)p.second;
            }
          }
        }
        // NOTE(review): ops that share a layer name overwrite each other's
        // entry here, so only the last one visited is reported.
        top[layerName] = llvm::json::Value(std::move(layerStatsJSON));
      }
    });

    llvm::json::Value topv(std::move(top));
    std::string ret;
    llvm::raw_string_ostream ss(ret);
    ss << llvm::formatv("{0:2}", topv) << "\n";
    return ss.str();
  }

  /// Names every op reached by walking the "graph" symbol, then delivers the
  /// JSON report. The report goes to `*output` when set, and otherwise is
  /// emitted as a warning on "graph". Fails the pass if
  /// `lookupSymbol("graph")` returns null.
  void runOnOperation() override {
    // This pass only inspects the IR, so every analysis stays valid.
    markAllAnalysesPreserved();

    auto module = getOperation();

    // Check that a function called "graph" exists.
    auto graph = module.lookupSymbol("graph");
    if (!graph) {
      // NOTE(review): MLIR diagnostic messages conventionally carry no
      // trailing newline; consider dropping the "\n".
      emitError(mlir::UnknownLoc::get(module.getContext()),
                "OpReportPass failed: can't find a graph function\n");
      signalPassFailure();
      return;
    }

    // An op without a "layer_name" attribute is named "unknown-layer-<N>",
    // where N is the op's position in this walk. The counter advances for
    // every visited op, named or not.
    unsigned currentLayer = 0;
    opToName.clear();
    graph.walk([&](Operation *op) {
      auto attr = op->getAttrOfType("layer_name");
      if (attr)
        opToName[op] = attr.getValue().str();
      else
        opToName[op] = "unknown-layer-" + std::to_string(currentLayer);
      currentLayer++;
    });

    std::string report = emitJSONReport();
    if (output) {
      *output = report;
    } else {
      graph.emitWarning(report);
    }
  }
};

} // namespace

namespace mlir {
namespace NPCOMP {
namespace aten {

/// Creates the op-report pass in its warning-emitting form (default
/// constructor). make_unique's template argument is not visible in this
/// copy; it is presumably ATenOpReportPass.
std::unique_ptr createATenOpReportPass() {
  return std::make_unique();
}

/// Creates the op-report pass so that it writes its report into `report`.
/// The pass keeps a pointer to `report`, so `report` must outlive the pass
/// run.
std::unique_ptr createATenOpReportPass(std::string &report) {
  return std::make_unique(&report);
}

/// Registers the pass under the command-line argument "aten-op-report".
// NOTE(review): the qualified name is redundant inside these namespaces and
// is inconsistent with the two factories above. It does make the compiler
// require a prior declaration, presumably in ATenOpReport.h.
void mlir::NPCOMP::aten::registerATenOpReportPass() {
  PassRegistration("aten-op-report", "Generate ATen operation report");
}

} // namespace aten
} // namespace NPCOMP
} // namespace mlir