diff --git a/examples/torchscript_mhlo_backend_resnet.py b/examples/torchscript_mhlo_backend_resnet.py
new file mode 100644
index 000000000..bb481f6c3
--- /dev/null
+++ b/examples/torchscript_mhlo_backend_resnet.py
@@ -0,0 +1,16 @@
+import torch
+import torchvision.models as models
+import torch_mlir
+
+# Load a pretrained ResNet-18 and switch it to inference mode.
+model = models.resnet18(pretrained=True)
+model.eval()
+data = torch.randn(2, 3, 200, 200)  # example NCHW input batch
+out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
+
+# Compile the model down to the MHLO dialect.
+module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
+with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
+    outf.write(str(module))
+
+print(f"MHLO IR of resnet18 successfully written to {out_mhlo_mlir_path}")
diff --git a/examples/torchscript_mhlo_backend_tinybert.py b/examples/torchscript_mhlo_backend_tinybert.py
new file mode 100644
index 000000000..62827361e
--- /dev/null
+++ b/examples/torchscript_mhlo_backend_tinybert.py
@@ -0,0 +1,25 @@
+import torch
+import torch_mlir
+
+from transformers import BertForMaskedLM
+
+# Wrap the BERT model so it returns a single tensor instead of a tuple of outputs.
+class BertTinyWrapper(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False)
+
+    def forward(self, data):
+        return self.bert(data)[0]
+
+model = BertTinyWrapper()
+model.eval()
+data = torch.randint(30522, (2, 128))  # random token ids; 30522 is the BERT vocabulary size
+out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
+
+# Trace the wrapped model (use_tracing=True) and lower it to the MHLO dialect.
+module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
+with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
+    outf.write(str(module))
+
+print(f"MHLO IR of tiny bert successfully written to {out_mhlo_mlir_path}")
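
Both scripts follow the same pattern: build an eager-mode model, call torch_mlir.compile with OutputType.MHLO, and serialize the resulting module as text. A minimal follow-up sketch of verifying the output (a hypothetical helper, not part of this patch; it assumes the ResNet example has already been run from the repository root, and that MHLO-dialect ops print with an "mhlo." prefix):

    # sanity_check_mhlo.py -- hypothetical helper, not included in this patch
    with open("./resnet18_mhlo.mlir", "r", encoding="utf-8") as f:
        ir = f.read()
    # Ops in the MHLO dialect are printed with an "mhlo." prefix.
    assert "mhlo." in ir, "expected MHLO-dialect ops in the emitted module"
    print(f"OK: {ir.count('mhlo.')} occurrences of mhlo ops found")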