From 1106b9aeae867f1ed44fd8f90abf140fc8f9534c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E5=AE=B6=E4=BC=9F?= <73166454+Vremold@users.noreply.github.com> Date: Wed, 24 Aug 2022 07:44:36 +0800 Subject: [PATCH] [MHLO] bert-tiny and resnet18 example from torchscript to mhlo (#1266) Co-authored-by: Bairen Yi Co-authored-by: Jiawei Wu Co-authored-by: Tianyou Guo Co-authored-by: Xu Yan Co-authored-by: Ziheng Jiang Co-authored-by: Vremold --- examples/torchscript_mhlo_backend_resnet.py | 14 +++++++++++ examples/torchscript_mhlo_backend_tinybert.py | 24 +++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 examples/torchscript_mhlo_backend_resnet.py create mode 100644 examples/torchscript_mhlo_backend_tinybert.py diff --git a/examples/torchscript_mhlo_backend_resnet.py b/examples/torchscript_mhlo_backend_resnet.py new file mode 100644 index 000000000..bb481f6c3 --- /dev/null +++ b/examples/torchscript_mhlo_backend_resnet.py @@ -0,0 +1,14 @@ +import torch +import torchvision.models as models +import torch_mlir + +model = models.resnet18(pretrained=True) +model.eval() +data = torch.randn(2,3,200,200) +out_mhlo_mlir_path = "./resnet18_mhlo.mlir" + +module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False) +with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf: + outf.write(str(module)) + +print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}") diff --git a/examples/torchscript_mhlo_backend_tinybert.py b/examples/torchscript_mhlo_backend_tinybert.py new file mode 100644 index 000000000..62827361e --- /dev/null +++ b/examples/torchscript_mhlo_backend_tinybert.py @@ -0,0 +1,24 @@ +import torch +import torch_mlir + +from transformers import BertForMaskedLM + +# Wrap the bert model to avoid multiple returns problem +class BertTinyWrapper(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False) + + def forward(self, data): + return self.bert(data)[0] + +model = BertTinyWrapper() +model.eval() +data = torch.randint(30522, (2, 128)) +out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir" + +module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True) +with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf: + outf.write(str(module)) + +print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")