2021-02-19 10:31:06 +08:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
"""
|
|
|
|
Utility for handling common tasks for exported `.pt` model files.
|
|
|
|
|
|
|
|
Usage:
|
|
|
|
# Dump PyTorch data structures for .pt file.
|
|
|
|
# This does not involve any MLIR code.
|
|
|
|
$ pt_util.py --dump model.pt
|
|
|
|
|
|
|
|
# Import the .pt file into MLIR.
|
|
|
|
$ pt_util.py --import model.pt
|
|
|
|
"""
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch_mlir
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
parser = argparse.ArgumentParser(
|
|
|
|
description="Utility for .pt files")
|
|
|
|
parser.add_argument("pt_file", metavar="PT_FILE", type=str,
|
|
|
|
help="the .pt file to import")
|
|
|
|
parser.add_argument("--dump", action="store_true",
|
|
|
|
help="dump the pytorch module")
|
|
|
|
parser.add_argument("--import", action="store_true",
|
|
|
|
help="import the pytorch module")
|
2021-02-20 08:21:21 +08:00
|
|
|
parser.add_argument("--exported-name", action="append",
|
|
|
|
help="""
|
|
|
|
Name to export, such as `my.submodule.forward`(default = export all).
|
|
|
|
Can pass repeatedly.
|
|
|
|
""")
|
2021-02-19 10:31:06 +08:00
|
|
|
args = parser.parse_args()
|
|
|
|
# TODO: Investigate why "cpu" is needed.
|
|
|
|
module = torch.jit.load(args.pt_file, map_location="cpu")
|
2021-02-20 08:21:21 +08:00
|
|
|
|
2021-02-19 10:31:06 +08:00
|
|
|
if args.dump:
|
|
|
|
module._c.dump(code=True, attrs=False, params=False)
|
2021-02-20 08:21:21 +08:00
|
|
|
|
2021-02-19 10:31:06 +08:00
|
|
|
# `import` is a Python keyword, so getattr is needed.
|
|
|
|
if getattr(args, "import", False):
|
2021-02-20 08:21:21 +08:00
|
|
|
class_annotator = torch_mlir.ClassAnnotator()
|
|
|
|
if args.exported_name is not None:
|
|
|
|
class_annotator.exportNone(module._c._type())
|
|
|
|
for name in args.exported_name:
|
|
|
|
class_annotator.exportPath(name.split("."), module._c._type())
|
|
|
|
mb = torch_mlir.ModuleBuilder()
|
|
|
|
mb.import_module(module._c, class_annotator)
|
2021-02-19 10:31:06 +08:00
|
|
|
mb.module.operation.print(large_elements_limit=16)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
main()
|