[MHLO] bert-tiny and resnet18 example from torchscript to mhlo (#1266)

Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
Co-authored-by: Vremold <xremold@gamil.com>
pull/1261/head
武家伟 2022-08-24 07:44:36 +08:00 committed by GitHub
parent 2374098d71
commit 1106b9aeae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 0 deletions

View File

@ -0,0 +1,14 @@
import torch
import torchvision.models as models
import torch_mlir
model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))
print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}")

View File

@ -0,0 +1,24 @@
import torch
import torch_mlir
from transformers import BertForMaskedLM
# Wrap the bert model to avoid multiple returns problem
class BertTinyWrapper(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False)
def forward(self, data):
return self.bert(data)[0]
model = BertTinyWrapper()
model.eval()
data = torch.randint(30522, (2, 128))
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))
print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")