mirror of https://github.com/llvm/torch-mlir
Add trivial inliner interfaces.
With this + manually setting private visibility on everything, a simple classifier can be reduced to this IR, which is looking pretty lean and mean: https://gist.github.com/silvasean/19e7e2e21a61ff197aeac0dd864d188f Also, include a utility script for importing `.pt` models. ``` pt_util.py --import classifier.pt | npcomp-opt -torch-globalize-object-graph ```pull/168/head
parent
cecf1fbba5
commit
8486968925
|
@ -0,0 +1,42 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Utility for handling common tasks for exported `.pt` model files.
|
||||
|
||||
Usage:
|
||||
# Dump PyTorch data structures for .pt file.
|
||||
# This does not involve any MLIR code.
|
||||
$ pt_util.py --dump model.pt
|
||||
|
||||
# Import the .pt file into MLIR.
|
||||
$ pt_util.py --import model.pt
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Utility for .pt files")
|
||||
parser.add_argument("pt_file", metavar="PT_FILE", type=str,
|
||||
help="the .pt file to import")
|
||||
parser.add_argument("--dump", action="store_true",
|
||||
help="dump the pytorch module")
|
||||
parser.add_argument("--import", action="store_true",
|
||||
help="import the pytorch module")
|
||||
args = parser.parse_args()
|
||||
# TODO: Investigate why "cpu" is needed.
|
||||
module = torch.jit.load(args.pt_file, map_location="cpu")
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
if args.dump:
|
||||
module._c.dump(code=True, attrs=False, params=False)
|
||||
# `import` is a Python keyword, so getattr is needed.
|
||||
if getattr(args, "import", False):
|
||||
mb.import_module(module._c)
|
||||
mb.module.operation.print(large_elements_limit=16)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
|
@ -15,6 +16,28 @@ using namespace mlir;
|
|||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Basicpy;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct BasicpyInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||||
BlockAndValueMapping &valueMapping) const final {
|
||||
return true;
|
||||
}
|
||||
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
|
||||
BlockAndValueMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect Class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void BasicpyDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
|
@ -22,6 +45,7 @@ void BasicpyDialect::initialize() {
|
|||
>();
|
||||
addTypes<BoolType, BytesType, DictType, EllipsisType, ListType, NoneType,
|
||||
SlotObjectType, StrType, TupleType, UnknownType>();
|
||||
addInterfaces<BasicpyInlinerInterface>();
|
||||
|
||||
// TODO: Make real ops for everything we need.
|
||||
allowUnknownOperations();
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
@ -16,6 +17,24 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct TorchInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||||
BlockAndValueMapping &valueMapping) const final {
|
||||
return true;
|
||||
}
|
||||
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
|
||||
BlockAndValueMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tablegen Type Definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -32,6 +51,7 @@ void TorchDialect::initialize() {
|
|||
#define GET_TYPEDEF_LIST
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||
>();
|
||||
addInterfaces<TorchInlinerInterface>();
|
||||
}
|
||||
|
||||
Type TorchDialect::parseType(DialectAsmParser &parser) const {
|
||||
|
|
Loading…
Reference in New Issue