mirror of https://github.com/llvm/torch-mlir
[MHLO] bert-tiny and resnet18 example from torchscript to mhlo (#1266)
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com> Co-authored-by: Jiawei Wu <xremold@gmail.com> Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com> Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com> Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com> Co-authored-by: Vremold <xremold@gamil.com>pull/1261/head
parent
2374098d71
commit
1106b9aeae
|
@ -0,0 +1,14 @@
|
|||
import torch
|
||||
import torchvision.models as models
|
||||
import torch_mlir
|
||||
|
||||
model = models.resnet18(pretrained=True)
|
||||
model.eval()
|
||||
data = torch.randn(2,3,200,200)
|
||||
out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
|
||||
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
|
||||
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
outf.write(str(module))
|
||||
|
||||
print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}")
|
|
@ -0,0 +1,24 @@
|
|||
import torch
|
||||
import torch_mlir
|
||||
|
||||
from transformers import BertForMaskedLM
|
||||
|
||||
# Wrap the bert model to avoid multiple returns problem
|
||||
class BertTinyWrapper(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False)
|
||||
|
||||
def forward(self, data):
|
||||
return self.bert(data)[0]
|
||||
|
||||
model = BertTinyWrapper()
|
||||
model.eval()
|
||||
data = torch.randint(30522, (2, 128))
|
||||
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
|
||||
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
|
||||
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
outf.write(str(module))
|
||||
|
||||
print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")
|
Loading…
Reference in New Issue