Add ability to annotate TorchScript classes.

The first use case is to annotate certain program constructs as either
exported or private. In this commit we plumb it down to
GlobalizeObjectGraph which makes use of this information.

Recommended review order:
1. class_annotator.h/.cpp + `test/module_import/annotations/*`
    - New abstractions to communicate with Python code and annotate.
2. IR changes in TorchOps.td
    - Adding "private" attribute to various things.
3. ivalue_import.cpp changes
    - Module + ClassAnnotator = annotated IR
4. GlobalizeObjectGraph.cpp + tests
    - use new "private" attributes to create "private" IR.
    - also, tweak some of the op-deleting mechanics, which were triggering
      some memory errors / assertions

With this, we can run the classifier through and inline it as follows:
```
frontends/pytorch/utils/pt_util.py --import --exported-name forward ~/tmp/classifier.pt \
| npcomp-opt -torch-globalize-object-graph -inline
```
IR: https://gist.github.com/silvasean/32dcad9f6270557f412094a77cecdd69
pull/169/head
Sean Silva 2021-02-19 16:21:21 -08:00
parent c424c24ed8
commit a375ccf9da
18 changed files with 702 additions and 36 deletions

View File

@ -12,6 +12,7 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(NPCOMPTorchMLIRExt SHARED
builder/acap_dispatch.cpp
builder/class_annotator.cpp
builder/debug.cpp
builder/func_builder.cpp
builder/graph_importer.cpp

View File

@ -0,0 +1,213 @@
//===- class_annotator.cpp ------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "class_annotator.h"
#include <stdexcept>
using namespace torch_mlir;
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
// Prefix every line of `s` with `linePrefix`.
//
// Every line of the result is newline-terminated, including the last one,
// regardless of whether `s` itself ends with a newline. An empty `s`
// produces an empty result.
static std::string indentString(const std::string &linePrefix,
                                const std::string &s) {
  // Use the direction-specific stream types: we only read from the input
  // and only write to the output.
  std::istringstream is(s);
  std::ostringstream os;
  std::string line;
  while (std::getline(is, line)) {
    os << linePrefix << line << "\n";
  }
  return os.str();
}
//===----------------------------------------------------------------------===//
// ClassAnnotation
//===----------------------------------------------------------------------===//
// Build the parallel annotation arrays for `classType`: one
// default-initialized (i.e. "exported") annotation per attribute and per
// method, in the same order as the arrays held on the ClassType itself.
ClassAnnotation::ClassAnnotation(c10::ClassTypePtr classType)
    : classType(classType) {
  attributeAnnotations.resize(classType->getAttributes().size());
  methodAnnotations.resize(classType->methods().size());
}
// Get the attribute annotations, parallel to `classType->getAttributes()`.
std::vector<AttributeAnnotation> &
ClassAnnotation::getAttributeAnnotations() {
  // Halfhearted attempt to ensure consistency if the class type has
  // been mutated.
  //
  // We can't easily guard against attributes being removed and
  // then other attributes being added, or types changed, etc. without
  // effectively mirroring the entire ClassType.
  assert(attributeAnnotations.size() == classType->getAttributes().size() &&
         "annotations out of sync. class has been mutated");
  return attributeAnnotations;
}
// Get the method annotations, parallel to `classType->methods()`.
std::vector<MethodAnnotation> &
ClassAnnotation::getMethodAnnotations() {
  // Halfhearted attempt to ensure consistency if the class type has
  // been mutated.
  //
  // We can't easily guard against methods being removed, added, or changed.
  assert(methodAnnotations.size() == classType->methods().size() &&
         "annotations out of sync. class has been mutated");
  return methodAnnotations;
}
//===----------------------------------------------------------------------===//
// ClassAnnotator
//===----------------------------------------------------------------------===//
// Recursively mark every attribute and method reachable from `classType`
// (including those of transitively contained submodule classes) as
// not-exported.
static void exportNoneRecurse(ClassAnnotator &classAnnotator,
                              c10::ClassType *classType) {
  ClassAnnotation &annotation =
      classAnnotator.getOrCreateClassAnnotation(classType);
  for (AttributeAnnotation &attributeAnnotation :
       annotation.getAttributeAnnotations())
    attributeAnnotation.isExported = false;
  for (MethodAnnotation &methodAnnotation : annotation.getMethodAnnotations())
    methodAnnotation.isExported = false;
  // Recurse into any attribute whose type is itself a class (a submodule).
  for (const auto &classAttribute : classType->getAttributes()) {
    auto childClassType = classAttribute.getType()->cast<c10::ClassType>();
    if (childClassType)
      exportNoneRecurse(classAnnotator, childClassType.get());
  }
}
// Mark everything reachable from `rootClassType` as not-exported.
// See the class comment in class_annotator.h for the intended usage pattern.
void ClassAnnotator::exportNone(c10::ClassType &rootClassType) {
  exportNoneRecurse(*this, &rootClassType);
}
void ClassAnnotator::exportPath(std::vector<std::string> exportedPath,
c10::ClassType &rootClassType) {
if (exportedPath.size() == 0) {
throw std::invalid_argument(
"Empty exported path. Can only export a property of a class.");
}
c10::ClassType *classType = &rootClassType;
// Reverse so that pop_back gives us the initial atoms first.
std::reverse(exportedPath.begin(), exportedPath.end());
while (exportedPath.size() != 1) {
// This will throw in case of missing attribute.
c10::TypePtr childType = classType->getAttribute(exportedPath.back());
c10::ClassTypePtr childClassType = childType->cast<c10::ClassType>();
if (!childClassType) {
std::stringstream ss;
ss << "class '" << classType->name()->qualifiedName()
<< "' does not have a submodule in attribute '" << exportedPath.back()
<< "'";
throw std::invalid_argument(ss.str());
}
exportedPath.pop_back();
classType = childClassType.get();
}
if (!classType->findAttribute(exportedPath.back()) &&
!classType->findMethod(exportedPath.back())) {
std::stringstream ss;
ss << "class '" << classType->name()->qualifiedName()
<< "' does not have a method or attribute called '"
<< exportedPath.back() << "'";
throw std::invalid_argument(ss.str());
}
ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType);
std::vector<AttributeAnnotation> &attributeAnnotations =
classAnnotation.getAttributeAnnotations();
const std::vector<c10::ClassAttribute> &classAttributes =
classType->getAttributes();
for (int i = 0, e = classAttributes.size(); i != e; i++) {
if (classAttributes[i].getName() == exportedPath.back()) {
attributeAnnotations[i].isExported = true;
}
}
std::vector<MethodAnnotation> &methodAnnotations =
classAnnotation.getMethodAnnotations();
const std::vector<torch::jit::Function *> &methods = classType->methods();
for (int i = 0, e = methods.size(); i != e; i++) {
if (methods[i]->name() == exportedPath.back()) {
methodAnnotations[i].isExported = true;
}
}
}
// Return all annotations created so far, keyed by raw ClassType pointer.
const ClassAnnotationMap &ClassAnnotator::getAnnotationMap() {
  return classAnnotations;
}
// Look up the annotation record for `classType`, creating a default one on
// first use.
ClassAnnotation &
ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
  auto found = classAnnotations.find(classType);
  if (found != classAnnotations.end())
    return *found->second;
  // First time we see this class: create a fresh annotation. Hold the
  // ClassType via shared ptr so the raw-pointer map key stays valid.
  auto annotation = std::make_unique<ClassAnnotation>(
      classType->shared_from_this()->cast<c10::ClassType>());
  auto inserted = classAnnotations.insert({classType, std::move(annotation)});
  return *inserted.first->second;
}
//===----------------------------------------------------------------------===//
// toString methods
//===----------------------------------------------------------------------===//
std::string AttributeAnnotation::toString(const std::string &name) {
std::stringstream ss;
ss << "AttributeAnnotation('" << name << "') {\n";
ss << " isExported = " << (isExported ? "true" : "false") << "\n";
ss << "}\n";
return ss.str();
}
std::string MethodAnnotation::toString(const std::string &name) {
std::stringstream ss;
ss << "MethodAnnotation('" << name << "') {\n";
ss << " isExported = " << (isExported ? "true" : "false") << "\n";
ss << "}\n";
return ss.str();
}
// Render a human-readable summary of this class's annotations: the class
// name followed by each attribute/method annotation, indented one level.
std::string ClassAnnotation::toString() {
  std::stringstream ss;
  ss << "ClassAnnotation('" << classType->name()->qualifiedName() << "') {\n";
  const auto &attrs = classType->getAttributes();
  for (size_t i = 0, e = attrs.size(); i != e; i++) {
    ss << indentString("  ",
                       attributeAnnotations[i].toString(attrs[i].getName()));
  }
  const auto &methods = classType->methods();
  for (size_t i = 0, e = methods.size(); i != e; i++) {
    ss << indentString("  ", methodAnnotations[i].toString(methods[i]->name()));
  }
  ss << "}\n";
  return ss.str();
}
// Render a human-readable summary of all collected class annotations.
// This is exposed as Python `__repr__`, so keep it readable.
std::string ClassAnnotator::toString() {
  std::stringstream ss;
  ss << "ClassAnnotator {\n";
  for (const auto &entry : classAnnotations)
    ss << indentString("  ", entry.second->toString());
  ss << "}\n";
  return ss.str();
}
// Expose ClassAnnotator to Python.
// `__repr__` is the toString() dump, which is user-visible (see the
// "change detector" test for its format), so keep it readable.
void torch_mlir::initClassAnnotatorBindings(py::module &m) {
  py::class_<ClassAnnotator>(m, "ClassAnnotator")
      .def(py::init<>())
      .def("exportPath", &ClassAnnotator::exportPath)
      .def("exportNone", &ClassAnnotator::exportNone)
      .def("__repr__", &ClassAnnotator::toString);
}

View File

@ -0,0 +1,125 @@
//===- class_annotations.h --------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
// Utilities for annotating Torch `c10::ClassType`
//
// We cannot intrusively add metadata to the ClassType, so we instead
// keep a parallel data structure.
//
// An annotation injects extra knowledge about the program which is not
// otherwise deducible. Thus, it is important that all annotations have a safe
// "no extra knowledge" state.
//
// Annotations should not be thought of at the MLIR level. They should express
// information at the level of the user-observable program semantics independent
// of implementation.
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_CLASS_ANNOTATOR_H
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_CLASS_ANNOTATOR_H
#include "../pybind.h"
namespace torch_mlir {
// An annotation on a class's attribute (corresponds to a c10::ClassAttribute).
struct AttributeAnnotation {
  // Whether external access to this attribute is allowed.
  // The default "no knowledge" state of the program is that all attributes
  // can be externally accessed.
  bool isExported = true;

  // Render a human-readable summary, labeled with `name` (the annotation
  // does not itself know the attribute's name).
  std::string toString(const std::string &name);
};
// An annotation on a class's method (corresponds to a torch::jit::Function).
struct MethodAnnotation {
  // Whether external calls to this method are allowed.
  // The default "no knowledge" state of the program is that all methods
  // can be externally called.
  bool isExported = true;

  // Render a human-readable summary, labeled with `name` (the annotation
  // does not itself know the method's name).
  std::string toString(const std::string &name);
};
// Annotations on a c10::ClassType.
//
// A c10::ClassType consists of attributes and methods, which are stored in
// arrays (the array elements know their names, but the storage is not keyed on
// the name). For each, we have an array of annotations that parallels the
// corresponding array (of either attributes or methods) held on the
// c10::ClassType.
//
// Note that c10::ClassType is in principle mutable, which can cause
// this data structure to get out of sync with it (this would be a problem with
// parallel arrays or string-keyed data structures). However, in practice the
// types tend to not change after being created from TorchScript.
//
// We make some mild efforts to check for mutation of the underlying class
// type, but they don't provide firm guarantees. Caveat emptor.
class ClassAnnotation {
public:
  ClassAnnotation(c10::ClassTypePtr classType);

  // Get the attribute annotations.
  // The length and order is the same as `classType->getAttributes()`.
  std::vector<AttributeAnnotation> &getAttributeAnnotations();

  // Get the method annotations.
  // The length and order is the same as `classType->methods()`.
  std::vector<MethodAnnotation> &getMethodAnnotations();

  // Render a human-readable summary of all annotations on this class.
  std::string toString();

private:
  // The c10::ClassType that we are annotating.
  //
  // Use a shared ptr type to keep the `ClassType *` alive.
  // We use a raw ptr as the map key where this class is the map value.
  c10::ClassTypePtr classType;
  std::vector<AttributeAnnotation> attributeAnnotations;
  std::vector<MethodAnnotation> methodAnnotations;
};
// A map of annotations on `c10::ClassType`s, keyed by raw pointer.
// (The mapped ClassAnnotation holds a shared ptr that keeps the key alive.)
using ClassAnnotationMap =
    std::unordered_map<c10::ClassType *, std::unique_ptr<ClassAnnotation>>;
// A collection of class annotations + methods to create the annotations.
class ClassAnnotator {
public:
  ClassAnnotator() = default;

  // Export the path `exportedPath`, where the root of the traversal
  // is at `rootClassType`.
  //
  // For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should
  // have a submodule `a` and that submodule should have a method or attribute
  // `b`.
  //
  // Throws if `exportedPath` is empty or does not resolve to a method or
  // attribute.
  void exportPath(std::vector<std::string> exportedPath,
                  c10::ClassType &rootClassType);

  // Mark everything as not-exported.
  //
  // This is kind of useless by itself, but together with `exportPath` allows
  // exporting a subset of known names out of a larger collection of unknown
  // names.
  void exportNone(c10::ClassType &rootClassType);

  // The annotations collected so far.
  const ClassAnnotationMap &getAnnotationMap();

  // Get the ClassAnnotation corresponding to `classType`, creating a
  // default one if needed.
  ClassAnnotation &getOrCreateClassAnnotation(c10::ClassType *classType);

  // Render a human-readable summary of all annotations.
  std::string toString();

private:
  ClassAnnotationMap classAnnotations;
};
void initClassAnnotatorBindings(py::module &m);
} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_CLASS_ANNOTATOR_H

View File

@ -6,6 +6,7 @@
//===----------------------------------------------------------------------===//
#include "ivalue_importer.h"
#include "class_annotator.h"
#include "graph_importer.h"
#include <unordered_map>
@ -93,20 +94,24 @@ namespace {
/// (PyTorch allows this!).
class IValueImporter {
public:
IValueImporter(MlirBlock importBlock, MlirContext context)
: importBlock(importBlock), context(context), typeMapper(context) {}
IValueImporter(MlirBlock importBlock, MlirContext context,
ClassAnnotator &annotator)
: importBlock(importBlock), context(context), typeMapper(context),
annotator(annotator) {}
MlirValue importIValue(c10::IValue value);
private:
MlirValue rawImportIValue(c10::IValue value);
MlirValue importModule(torch::jit::Module jitModule);
void importMethod(torch::jit::Function *function, MlirBlock classTypeBody);
void importMethod(torch::jit::Function *function, MlirBlock classTypeBody,
const MethodAnnotation &methodAnnotation);
void importClassType(c10::ClassType *classType);
MlirBlock importBlock;
MlirContext context;
TypeMapper typeMapper;
ClassAnnotator &annotator;
// Map tracking already-imported values.
std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;
@ -274,7 +279,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
}
void IValueImporter::importMethod(torch::jit::Function *function,
MlirBlock classTypeBody) {
MlirBlock classTypeBody,
const MethodAnnotation &methodAnnotation) {
// We make an effort for the func op's symbol name to be useful for debugging,
// but still clearly non-load-bearing.
std::string symName =
@ -286,13 +292,18 @@ void IValueImporter::importMethod(torch::jit::Function *function,
mlirStringAttrGet(context, toMlirStringRef("private")));
mlirBlockInsertOwnedOperationBefore(
importBlock, mlirBlockGetTerminator(importBlock), func);
c10::optional<MlirNamedAttribute> isPrivate;
if (!methodAnnotation.isExported) {
isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
}
createMlirOperationAtEnd(
classTypeBody, "torch.method", mlirLocationUnknownGet(context),
toMlirNamedAttribute(
"name",
mlirStringAttrGet(context, toMlirStringRef(function->name()))),
toMlirNamedAttribute("function", mlirFlatSymbolRefAttrGet(
context, toMlirStringRef(symName))));
context, toMlirStringRef(symName))),
isPrivate);
}
void IValueImporter::importClassType(c10::ClassType *classType) {
@ -313,7 +324,17 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr));
MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
for (const c10::ClassAttribute &classAttribute : classType->getAttributes()) {
ClassAnnotation &classAnnotation =
annotator.getOrCreateClassAnnotation(classType);
const auto &attributeAnnotations = classAnnotation.getAttributeAnnotations();
const auto &classAttributes = classType->getAttributes();
for (int i = 0, e = classAttributes.size(); i != e; i++) {
const c10::ClassAttribute &classAttribute = classAttributes[i];
c10::optional<MlirNamedAttribute> isPrivate;
if (!attributeAnnotations[i].isExported) {
isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
}
createMlirOperationAtEnd(
classTypeBody, "torch.attr", loc,
toMlirNamedAttribute(
@ -321,21 +342,24 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
context, toMlirStringRef(classAttribute.getName()))),
toMlirNamedAttribute("type",
mlirTypeAttrGet(typeMapper.mapFromTorchType(
loc, classAttribute.getType()))));
loc, classAttribute.getType()))),
isPrivate);
}
for (torch::jit::Function *function : classType->methods()) {
importMethod(function, classTypeBody);
const auto &methodAnnotations = classAnnotation.getMethodAnnotations();
const auto &methods = classType->methods();
for (int i = 0, e = methods.size(); i != e; i++) {
importMethod(methods[i], classTypeBody, methodAnnotations[i]);
}
createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
}
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
MlirContext context) {
MlirContext context, ClassAnnotator &annotator) {
// When debugging module importing, it can be useful to dump as so:
// if (ivalue.isModule())
// ivalue.toModule().dump(true, false, false);
IValueImporter importer(block, context);
IValueImporter importer(block, context, annotator);
importer.importIValue(ivalue);
}

View File

@ -12,6 +12,7 @@
#include "../pybind.h"
#include "func_builder.h"
#include "class_annotator.h"
#include "mlir-c/IR.h"
@ -23,7 +24,8 @@ namespace torch_mlir {
/// Main entry-point for importing torch IValue's .
/// Recursively imports `ivalue`, inserting operations at the end of `block`.
void importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context);
void importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context,
ClassAnnotator &annotator);
} // namespace torch_mlir

View File

@ -15,6 +15,8 @@
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "c10/util/Optional.h"
namespace torch_mlir {
inline MlirStringRef toMlirStringRef(const std::string &s) {
@ -62,6 +64,13 @@ inline void addToMlirOperationState(MlirOperationState &state,
mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
}
template <typename T>
void addToMlirOperationState(MlirOperationState &state, c10::optional<T> o) {
if (o.has_value()) {
addToMlirOperationState(state, o.value());
}
}
inline void addToMlirOperationState(MlirOperationState &state) {}
template <typename T, typename U, typename... Ts>

View File

@ -134,9 +134,15 @@ ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
return function;
}
void ModuleBuilder::importModule(torch::jit::Module jitModule) {
void ModuleBuilder::importModule(torch::jit::Module jitModule,
py::object maybeClassAnnotator) {
ClassAnnotator dummyAnnotator;
ClassAnnotator *classAnnotator = &dummyAnnotator;
if (!maybeClassAnnotator.is_none()) {
classAnnotator = py::cast<ClassAnnotator *>(maybeClassAnnotator);
}
importIValue(jitModule._ivalue(), mlirModuleGetBody(module),
mlirModuleGetContext(module));
mlirModuleGetContext(module), *classAnnotator);
}
FuncBuilder::Inserter ModuleBuilder::createInserter() {
@ -160,5 +166,6 @@ void ModuleBuilder::bind(py::module &m) {
.def("capture_function", &ModuleBuilder::startCaptureFunction,
py::keep_alive<0, 1>())
.def("import_function", &ModuleBuilder::importFunction)
.def("import_module", &ModuleBuilder::importModule);
.def("import_module", &ModuleBuilder::importModule, py::arg("module"),
py::arg("classAnnotator") = py::none());
}

View File

@ -11,6 +11,7 @@
#include "../pybind.h"
#include "acap_dispatch.h"
#include "class_annotator.h"
#include "mlir-c/IR.h"
@ -44,8 +45,11 @@ public:
torch::jit::StrongFunctionPtr
importFunction(torch::jit::StrongFunctionPtr function);
// Imports a torch::jit::Module into the current module.
void importModule(torch::jit::Module jitModule);
// Imports a torch::jit::Module into the current module, using the
// annotations, if not none, provided in `maybeClassAnnotator` which should be
// a ClassAnnotator.
void importModule(torch::jit::Module jitModule,
py::object maybeClassAnnotator);
private:
FuncBuilder::Inserter createInserter();

View File

@ -13,6 +13,7 @@
#include "../init_python_bindings.h"
#include "acap_dispatch.h"
#include "module_builder.h"
#include "class_annotator.h"
using namespace torch_mlir;
namespace py = pybind11;
@ -139,4 +140,6 @@ void torch_mlir::InitBuilderBindings(py::module &m) {
m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);
ModuleBuilder::bind(m);
initClassAnnotatorBindings(m);
}

View File

@ -12,4 +12,5 @@ from _torch_mlir import *
__all__ = [
"debug_trace_to_stderr",
"ModuleBuilder",
"ClassAnnotator",
]

View File

@ -0,0 +1,74 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import typing

import torch
import torch_mlir

# RUN: %PYTHON %s | FileCheck %s

mb = torch_mlir.ModuleBuilder()


# Submodule deliberately mixes names that will and won't be exported below,
# so the annotator repr shows both states.
class Submodule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.exported = 1
        self.not_exported = 2

    def forward(self):
        return self.not_exported_method()

    def not_exported_method(self):
        return


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.s = Submodule()

    def forward(self):
        return self.s.forward()


test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)

annotator = torch_mlir.ClassAnnotator()
class_type = recursivescriptmodule._c._type()
# Start from "nothing exported", then selectively re-export two names on the
# submodule.
annotator.exportNone(class_type)
annotator.exportPath(['s', 'exported'], class_type)
annotator.exportPath(['s', 'forward'], class_type)

# "Change detector" test + "documentation" for the repr of `ClassAnnotator`.
# This is semi-load-bearing because users interact with this class and repr
# will show up in error messages, so should be pretty readable.
# CHECK: ClassAnnotator {
# CHECK:   ClassAnnotation('__torch__.Submodule') {
# CHECK:     AttributeAnnotation('exported') {
# CHECK:       isExported = true
# CHECK:     }
# CHECK:     AttributeAnnotation('not_exported') {
# CHECK:       isExported = false
# CHECK:     }
# CHECK:     MethodAnnotation('forward') {
# CHECK:       isExported = true
# CHECK:     }
# CHECK:     MethodAnnotation('not_exported_method') {
# CHECK:       isExported = false
# CHECK:     }
# CHECK:   }
# CHECK:   ClassAnnotation('__torch__.TestModule') {
# CHECK:     AttributeAnnotation('s') {
# CHECK:       isExported = false
# CHECK:     }
# CHECK:     MethodAnnotation('forward') {
# CHECK:       isExported = false
# CHECK:     }
# CHECK:   }
# CHECK: }
print(annotator)

View File

@ -0,0 +1,46 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import typing

import torch
import torch_mlir

# RUN: %PYTHON %s | FileCheck %s

mb = torch_mlir.ModuleBuilder()


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self):
        return


test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)

annotator = torch_mlir.ClassAnnotator()
class_type = recursivescriptmodule._c._type()

# Exercise the user-facing error messages for invalid export paths.
try:
    annotator.exportPath(['a'], class_type)
except Exception as e:
    # CHECK: class '__torch__.TestModule' does not have a method or attribute called 'a'
    print(e)

try:
    annotator.exportPath([], class_type)
except Exception as e:
    # CHECK: Empty exported path. Can only export a property of a class.
    print(e)

try:
    annotator.exportPath(['a', 'b'], class_type)
except Exception as e:
    # This error is generated by PyTorch itself, so be a bit defensive about changes.
    # CHECK: __torch__.TestModule {{.*}} 'a'
    print(e)

# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, annotator)
mb.module.operation.print()

View File

@ -0,0 +1,49 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import typing

import torch
import torch_mlir

# RUN: %PYTHON %s | npcomp-opt | FileCheck %s

mb = torch_mlir.ModuleBuilder()


# Submodule deliberately mixes names that will and won't be exported below.
class Submodule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.exported = 1
        self.not_exported = 2

    def forward(self):
        return self.not_exported_method()

    def not_exported_method(self):
        return


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.s = Submodule()

    def forward(self):
        return self.s.forward()


test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)

annotator = torch_mlir.ClassAnnotator()
class_type = recursivescriptmodule._c._type()
# Non-exported names on the submodule should be imported with the `private`
# marker on the corresponding torch.attr/torch.method ops.
# CHECK-LABEL: torch.class_type @__torch__.Submodule {
# CHECK: torch.attr "exported" : i64
# CHECK: torch.attr private "not_exported" : i64
# CHECK: torch.method "forward", @{{.*}}
# CHECK: torch.method private "not_exported_method", @{{.*}}
# CHECK: }
annotator.exportNone(class_type)
annotator.exportPath(['s', 'exported'], class_type)
annotator.exportPath(['s', 'forward'], class_type)

# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, annotator)
mb.module.operation.print()

View File

@ -0,0 +1,41 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import typing

import torch
import torch_mlir

# RUN: %PYTHON %s | npcomp-opt | FileCheck %s

mb = torch_mlir.ModuleBuilder()


# Module deliberately mixes names that will and won't be exported below.
class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.exported = 1
        self.not_exported = 2

    def forward(self):
        return self.not_exported_method()

    def not_exported_method(self):
        return


test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)

annotator = torch_mlir.ClassAnnotator()
class_type = recursivescriptmodule._c._type()
# Non-exported names should be imported with the `private` marker on the
# corresponding torch.attr/torch.method ops.
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
# CHECK: torch.attr "exported" : i64
# CHECK: torch.attr private "not_exported" : i64
# CHECK: torch.method "forward", @{{.*}}
# CHECK: torch.method private "not_exported_method", @{{.*}}
# CHECK: }
annotator.exportNone(class_type)
annotator.exportPath(['exported'], class_type)
annotator.exportPath(['forward'], class_type)

# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, annotator)
mb.module.operation.print()

View File

@ -26,15 +26,27 @@ def main():
help="dump the pytorch module")
parser.add_argument("--import", action="store_true",
help="import the pytorch module")
parser.add_argument("--exported-name", action="append",
help="""
Name to export, such as `my.submodule.forward`(default = export all).
Can pass repeatedly.
""")
args = parser.parse_args()
# TODO: Investigate why "cpu" is needed.
module = torch.jit.load(args.pt_file, map_location="cpu")
mb = torch_mlir.ModuleBuilder()
if args.dump:
module._c.dump(code=True, attrs=False, params=False)
# `import` is a Python keyword, so getattr is needed.
if getattr(args, "import", False):
mb.import_module(module._c)
class_annotator = torch_mlir.ClassAnnotator()
if args.exported_name is not None:
class_annotator.exportNone(module._c._type())
for name in args.exported_name:
class_annotator.exportPath(name.split("."), module._c._type())
mb = torch_mlir.ModuleBuilder()
mb.import_module(module._c, class_annotator)
mb.module.operation.print(large_elements_limit=16)

View File

@ -190,13 +190,24 @@ def Torch_MethodOp : Torch_Op<"method", [
method `name` which calls `function`. `function` is an unbound function.
That is, it explicitly takes the torch.nn.Module as a parameter (no implicit
"self" object).
If `private` is present, it indicates that external calls cannot be made
to this method.
}];
let arguments = (ins StrAttr:$name, FlatSymbolRefAttr:$function);
// We don't use sym_visibility because that only applies to Symbol's, and
// some of the related concepts like "nested" visibility are specific to
// symbols.
let arguments = (ins
StrAttr:$name,
FlatSymbolRefAttr:$function,
// `private` is a C++ keyword, so use `isPrivate`.
UnitAttr:$isPrivate
);
let results = (outs);
let assemblyFormat = [{
$name `,` $function attr-dict
(`private` $isPrivate^)? $name `,` $function attr-dict
}];
}
@ -207,13 +218,24 @@ def Torch_AttrOp : Torch_Op<"attr", [
let description = [{
This op declaratively specifies that torch.nn.Module's of the parent
torch.class_type must have an attribute `name` of type `type`.
If `private` is present, it indicates that the value of this attribute
cannot be accessed externally.
}];
let arguments = (ins StrAttr:$name, TypeAttr:$type);
// We don't use sym_visibility because that only applies to Symbol's, and
// some of the related concepts like "nested" visibility are specific to
// symbols.
let arguments = (ins
StrAttr:$name,
TypeAttr:$type,
// `private` is a C++ keyword, so use `isPrivate`
UnitAttr:$isPrivate
);
let results = (outs);
let assemblyFormat = [{
$name `:` $type attr-dict
(`private` $isPrivate^)? $name `:` $type attr-dict
}];
}
@ -233,14 +255,19 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
let description = [{
Represents a slot with global storage. The slot semantics are the same
as Python's: getting or setting a slot is done by object identity.
The `typeBound` is a type that the contained type is a subtype of.
}];
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$typeBound);
let arguments = (ins
SymbolNameAttr:$sym_name,
OptionalAttr<StrAttr>:$sym_visibility,
TypeAttr:$typeBound
);
let results = (outs);
let assemblyFormat = [{
$sym_name attr-dict `:` $typeBound
($sym_visibility^)? $sym_name attr-dict `:` $typeBound
}];
let extraClassDeclaration = [{

View File

@ -188,7 +188,10 @@ ObjectGraphGlobalizer::recursivelyTraverseClassType(ClassTypeOp classType) {
} else {
auto linkageName = llvm::join(nameStack, ".");
auto globalSlot = globalBuilder.create<GlobalSlotOp>(
attr->getLoc(), linkageName, TypeAttr::get(attr.type()));
attr->getLoc(), linkageName, /*sym_visibility=*/nullptr,
TypeAttr::get(attr.type()));
if (attr.isPrivate())
globalSlot.setVisibility(SymbolTable::Visibility::Private);
AttrOfClass attrOfClass = {classType, attr.name()};
assert(globalSlotForAttr.find(attrOfClass) == globalSlotForAttr.end());
globalSlotForAttr[attrOfClass] = globalSlot;
@ -289,6 +292,7 @@ LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
// of methods with a single instance of the corresponding type just gets
// arbitrarily tricky to rewrite. E.g. what if the user creates a list
// of modules, or there is an scf.if selecting between modules, etc.
SmallVector<Operation *> toErase;
auto rewriteOpWithNnModuleTypeOperand = [&](Operation *op) {
if (auto primSetAttr = dyn_cast<PrimSetAttrOp>(op)) {
auto classType = symbolTable.lookup<ClassTypeOp>(
@ -297,9 +301,8 @@ LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
OpBuilder(primSetAttr)
.create<GlobalSlotSetOp>(primSetAttr.getLoc(), globalSlot.sym_name(),
primSetAttr.value());
primSetAttr.erase();
}
if (auto primGetAttr = dyn_cast<PrimGetAttrOp>(op)) {
toErase.push_back(primSetAttr);
} else if (auto primGetAttr = dyn_cast<PrimGetAttrOp>(op)) {
// If the return value is NnModuleType, then we don't need to do anything.
// Our verification earlier ensured that there are no uses that
// we won't properly rewrite.
@ -317,9 +320,8 @@ LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
globalSlot.sym_name());
primGetAttr.replaceAllUsesWith(globalSlotGet.getOperation());
}
primGetAttr.erase();
}
if (auto primCallMethod = dyn_cast<PrimCallMethodOp>(op)) {
toErase.push_back(primGetAttr);
} else if (auto primCallMethod = dyn_cast<PrimCallMethodOp>(op)) {
auto classType = symbolTable.lookup<ClassTypeOp>(primCallMethod.receiver()
.getType()
.cast<NnModuleType>()
@ -334,7 +336,7 @@ LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
.create<CallOp>(primCallMethod.getLoc(), linkageName,
primCallMethod.getType(), newOperands);
primCallMethod.replaceAllUsesWith(call);
primCallMethod.erase();
toErase.push_back(primCallMethod);
}
};
for (auto classType : module.getOps<ClassTypeOp>()) {
@ -345,9 +347,15 @@ LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
FuncOp func = symbolTable.lookup<FuncOp>(method.function());
if (failed(verifyMethodConformsToSubset(func)))
return failure();
if (!method.isPrivate())
func.setVisibility(SymbolTable::Visibility::Public);
func.setName(it->second);
func.walk(rewriteOpWithNnModuleTypeOperand);
for (Operation *op : toErase) {
op->dropAllDefinedValueUses();
op->erase();
}
toErase.clear();
SmallVector<unsigned> argsToErase;
for (auto arg : llvm::enumerate(func.getArguments())) {
if (!arg.value().getType().isa<NnModuleType>())
@ -360,14 +368,17 @@ LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
// as a user of module types.
}
}
return success();
}
void ObjectGraphGlobalizer::removeObjectGraph() {
for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
if (isa<ClassTypeOp, NnModuleOp>(op))
if (isa<ClassTypeOp, NnModuleOp>(op)) {
op.dropAllDefinedValueUses();
op.erase();
}
}
}
namespace {

View File

@ -0,0 +1,17 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @c {
// CHECK: torch.global_slot "private" @float : f64
torch.attr private "float" : f64
torch.method private "forward", @method
}
// CHECK: func private @forward() {
func private @method(%arg0: !torch.nn.Module<"c">) {
return
}
%c42 = std.constant 42.0 : f64
torch.nn_module {
torch.slot "float", %c42 : f64
} : !torch.nn.Module<"c">