mirror of https://github.com/llvm/torch-mlir
[MHLO] bert-tiny and resnet18 example from torchscript to mhlo (#1266)
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com> Co-authored-by: Jiawei Wu <xremold@gmail.com> Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com> Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com> Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com> Co-authored-by: Vremold <xremold@gamil.com>pull/1261/head
parent
2374098d71
commit
1106b9aeae
|
@ -0,0 +1,14 @@
|
||||||
|
import torch
|
||||||
|
import torchvision.models as models
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
model = models.resnet18(pretrained=True)
|
||||||
|
model.eval()
|
||||||
|
data = torch.randn(2,3,200,200)
|
||||||
|
out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
|
||||||
|
|
||||||
|
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
|
||||||
|
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||||
|
outf.write(str(module))
|
||||||
|
|
||||||
|
print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}")
|
|
@ -0,0 +1,24 @@
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
from transformers import BertForMaskedLM
|
||||||
|
|
||||||
|
# Wrap the bert model to avoid multiple returns problem
|
||||||
|
class BertTinyWrapper(torch.nn.Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False)
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
return self.bert(data)[0]
|
||||||
|
|
||||||
|
model = BertTinyWrapper()
|
||||||
|
model.eval()
|
||||||
|
data = torch.randint(30522, (2, 128))
|
||||||
|
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
|
||||||
|
|
||||||
|
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
|
||||||
|
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||||
|
outf.write(str(module))
|
||||||
|
|
||||||
|
print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")
|
Loading…
Reference in New Issue