Add trivial inliner interfaces.

With this + manually setting private visibility on everything, a simple
classifier can be reduced to this IR, which is looking pretty lean and
mean:
https://gist.github.com/silvasean/19e7e2e21a61ff197aeac0dd864d188f

Also, include a utility script for importing `.pt` models.

```
pt_util.py --import classifier.pt | npcomp-opt -torch-globalize-object-graph
```
pull/168/head
Sean Silva 2021-02-18 18:31:06 -08:00
parent cecf1fbba5
commit 8486968925
3 changed files with 86 additions and 0 deletions

View File

@ -0,0 +1,42 @@
#!/usr/bin/env python3
"""
Utility for handling common tasks for exported `.pt` model files.
Usage:
# Dump PyTorch data structures for .pt file.
# This does not involve any MLIR code.
$ pt_util.py --dump model.pt
# Import the .pt file into MLIR.
$ pt_util.py --import model.pt
"""
import torch
import torch_mlir
import argparse
def main():
parser = argparse.ArgumentParser(
description="Utility for .pt files")
parser.add_argument("pt_file", metavar="PT_FILE", type=str,
help="the .pt file to import")
parser.add_argument("--dump", action="store_true",
help="dump the pytorch module")
parser.add_argument("--import", action="store_true",
help="import the pytorch module")
args = parser.parse_args()
# TODO: Investigate why "cpu" is needed.
module = torch.jit.load(args.pt_file, map_location="cpu")
mb = torch_mlir.ModuleBuilder()
if args.dump:
module._c.dump(code=True, attrs=False, params=False)
# `import` is a Python keyword, so getattr is needed.
if getattr(args, "import", False):
mb.import_module(module._c)
mb.module.operation.print(large_elements_limit=16)
if __name__ == "__main__":
main()

View File

@ -8,6 +8,7 @@
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "llvm/ADT/TypeSwitch.h"
@ -15,6 +16,28 @@ using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Basicpy;
//===----------------------------------------------------------------------===//
// Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct BasicpyInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
return true;
}
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
BlockAndValueMapping &) const final {
return true;
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Dialect Class
//===----------------------------------------------------------------------===//
void BasicpyDialect::initialize() {
addOperations<
#define GET_OP_LIST
@ -22,6 +45,7 @@ void BasicpyDialect::initialize() {
>();
addTypes<BoolType, BytesType, DictType, EllipsisType, ListType, NoneType,
SlotObjectType, StrType, TupleType, UnknownType>();
addInterfaces<BasicpyInlinerInterface>();
// TODO: Make real ops for everything we need.
allowUnknownOperations();

View File

@ -8,6 +8,7 @@
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "llvm/ADT/StringExtras.h"
@ -16,6 +17,24 @@
using namespace mlir;
using namespace mlir::NPCOMP::Torch;
//===----------------------------------------------------------------------===//
// Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct TorchInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
return true;
}
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
BlockAndValueMapping &) const final {
return true;
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//
@ -32,6 +51,7 @@ void TorchDialect::initialize() {
#define GET_TYPEDEF_LIST
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
>();
addInterfaces<TorchInlinerInterface>();
}
Type TorchDialect::parseType(DialectAsmParser &parser) const {