diff --git a/python/test/compile_api/already_traced.py b/python/test/compile_api/already_traced.py new file mode 100644 index 000000000..a719eb743 --- /dev/null +++ b/python/test/compile_api/already_traced.py @@ -0,0 +1,28 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s + +import torch +import torch_mlir + +class BasicModule(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.sin(x) + +example_arg = torch.ones(2, 3) +example_args = torch_mlir.ExampleArgs.get(example_arg) + +traced = torch.jit.trace(BasicModule(), example_arg) +print(torch_mlir.compile(traced, example_args)) +# CHECK: module +# CHECK-DAG: func.func @forward + +traced = torch.jit.trace(BasicModule(), example_arg) +try: + # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. + torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg)) +except Exception as e: + print(e) diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 9bcf4ada2..3f08bb173 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -314,7 +314,7 @@ def compile(model: torch.nn.Module, # backend. This separation should be visible at the Python API level, and # we can implement a deliberately simplified API like `torch_mlir.compile` # on top of those building blocks. - if isinstance(model, torch.jit._script.RecursiveScriptModule): + if isinstance(model, torch.jit.ScriptModule): # If the user already converted the model to JIT IR themselves, just # do some basic error checking, but take the model as-is. for method_name in example_args._get_methods():