2022-12-23 00:39:55 +08:00
|
|
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
|
|
|
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
|
|
|
|
import torch
|
2024-02-07 11:07:59 +08:00
|
|
|
from torch_mlir import torchscript
|
2022-12-23 00:39:55 +08:00
|
|
|
|
2024-04-28 05:16:31 +08:00
|
|
|
|
2022-12-23 00:39:55 +08:00
|
|
|
class BasicModule(torch.nn.Module):
    """Tiny module used as the compile target: applies the ATen ``sin`` op."""

    def forward(self, x):
        # Call the ATen op directly so the traced graph contains aten::sin.
        out = torch.ops.aten.sin(x)
        return out
|
|
|
|
|
2024-04-28 05:16:31 +08:00
|
|
|
|
2022-12-23 00:39:55 +08:00
|
|
|
# A single 2x3 tensor of ones serves both as the tracing input and as the
# example argument handed to the compiler.
example_arg = torch.ones(2, 3)
example_args = torchscript.ExampleArgs.get(example_arg)

# Success path: trace the module, compile it, and print the resulting IR.
traced = torch.jit.trace(BasicModule(), example_arg)
compiled_module = torchscript.compile(traced, example_args)
print(compiled_module)

# CHECK: module
# CHECK-DAG: func.func @forward

# Failure path: ask the compiler for a method the traced model does not
# export. The raised message is printed so the test output can be matched.
# A fresh trace is used here rather than reusing the one compiled above.
traced = torch.jit.trace(BasicModule(), example_arg)
try:
    # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
    bad_example_args = torchscript.ExampleArgs().add_method("nonexistent", example_arg)
    torchscript.compile(traced, bad_example_args)
except Exception as err:
    print(err)
|