mirror of https://github.com/llvm/torch-mlir
Add trivial inliner interfaces.
With this + manually setting private visibility on everything, a simple classifier can be reduced to this IR, which is looking pretty lean and mean: https://gist.github.com/silvasean/19e7e2e21a61ff197aeac0dd864d188f Also, include a utility script for importing `.pt` models. ``` pt_util.py --import classifier.pt | npcomp-opt -torch-globalize-object-graph ```pull/168/head
parent
cecf1fbba5
commit
8486968925
|
@ -0,0 +1,42 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Utility for handling common tasks for exported `.pt` model files.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Dump PyTorch data structures for .pt file.
|
||||||
|
# This does not involve any MLIR code.
|
||||||
|
$ pt_util.py --dump model.pt
|
||||||
|
|
||||||
|
# Import the .pt file into MLIR.
|
||||||
|
$ pt_util.py --import model.pt
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Utility for .pt files")
|
||||||
|
parser.add_argument("pt_file", metavar="PT_FILE", type=str,
|
||||||
|
help="the .pt file to import")
|
||||||
|
parser.add_argument("--dump", action="store_true",
|
||||||
|
help="dump the pytorch module")
|
||||||
|
parser.add_argument("--import", action="store_true",
|
||||||
|
help="import the pytorch module")
|
||||||
|
args = parser.parse_args()
|
||||||
|
# TODO: Investigate why "cpu" is needed.
|
||||||
|
module = torch.jit.load(args.pt_file, map_location="cpu")
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
if args.dump:
|
||||||
|
module._c.dump(code=True, attrs=False, params=False)
|
||||||
|
# `import` is a Python keyword, so getattr is needed.
|
||||||
|
if getattr(args, "import", False):
|
||||||
|
mb.import_module(module._c)
|
||||||
|
mb.module.operation.print(large_elements_limit=16)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -8,6 +8,7 @@
|
||||||
|
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
|
@ -15,6 +16,28 @@ using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
using namespace mlir::NPCOMP::Basicpy;
|
using namespace mlir::NPCOMP::Basicpy;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dialect Interfaces
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct BasicpyInlinerInterface : public DialectInlinerInterface {
|
||||||
|
using DialectInlinerInterface::DialectInlinerInterface;
|
||||||
|
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||||||
|
BlockAndValueMapping &valueMapping) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
|
||||||
|
BlockAndValueMapping &) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dialect Class
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void BasicpyDialect::initialize() {
|
void BasicpyDialect::initialize() {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
|
@ -22,6 +45,7 @@ void BasicpyDialect::initialize() {
|
||||||
>();
|
>();
|
||||||
addTypes<BoolType, BytesType, DictType, EllipsisType, ListType, NoneType,
|
addTypes<BoolType, BytesType, DictType, EllipsisType, ListType, NoneType,
|
||||||
SlotObjectType, StrType, TupleType, UnknownType>();
|
SlotObjectType, StrType, TupleType, UnknownType>();
|
||||||
|
addInterfaces<BasicpyInlinerInterface>();
|
||||||
|
|
||||||
// TODO: Make real ops for everything we need.
|
// TODO: Make real ops for everything we need.
|
||||||
allowUnknownOperations();
|
allowUnknownOperations();
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
|
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
@ -16,6 +17,24 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::NPCOMP::Torch;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dialect Interfaces
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct TorchInlinerInterface : public DialectInlinerInterface {
|
||||||
|
using DialectInlinerInterface::DialectInlinerInterface;
|
||||||
|
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||||||
|
BlockAndValueMapping &valueMapping) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
|
||||||
|
BlockAndValueMapping &) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Tablegen Type Definitions
|
// Tablegen Type Definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -32,6 +51,7 @@ void TorchDialect::initialize() {
|
||||||
#define GET_TYPEDEF_LIST
|
#define GET_TYPEDEF_LIST
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||||
>();
|
>();
|
||||||
|
addInterfaces<TorchInlinerInterface>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Type TorchDialect::parseType(DialectAsmParser &parser) const {
|
Type TorchDialect::parseType(DialectAsmParser &parser) const {
|
||||||
|
|
Loading…
Reference in New Issue