mirror of https://github.com/llvm/torch-mlir
Allow passing traced `torch.nn.Module`s into `torch_mlir.compile` (#1743)
This commit adds support for passing to `torch_mlir.compile` the result of running `torch.jit.trace` on a model by relaxing the condition that checks if the model is already in JIT IR to allow any `torch.jit.ScriptModule`. Fixes https://github.com/llvm/torch-mlir/issues/1739pull/1750/head snapshot-20221223.696
parent
52669cbbd5
commit
3260a1ea6e
|
@ -0,0 +1,28 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
class BasicModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.sin(x)
|
||||
|
||||
example_arg = torch.ones(2, 3)
|
||||
example_args = torch_mlir.ExampleArgs.get(example_arg)
|
||||
|
||||
traced = torch.jit.trace(BasicModule(), example_arg)
|
||||
print(torch_mlir.compile(traced, example_args))
|
||||
# CHECK: module
|
||||
# CHECK-DAG: func.func @forward
|
||||
|
||||
traced = torch.jit.trace(BasicModule(), example_arg)
|
||||
try:
|
||||
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
||||
torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg))
|
||||
except Exception as e:
|
||||
print(e)
|
|
@ -314,7 +314,7 @@ def compile(model: torch.nn.Module,
|
|||
# backend. This separation should be visible at the Python API level, and
|
||||
# we can implement a deliberately simplified API like `torch_mlir.compile`
|
||||
# on top of those building blocks.
|
||||
if isinstance(model, torch.jit._script.RecursiveScriptModule):
|
||||
if isinstance(model, torch.jit.ScriptModule):
|
||||
# If the user already converted the model to JIT IR themselves, just
|
||||
# do some basic error checking, but take the model as-is.
|
||||
for method_name in example_args._get_methods():
|
||||
|
|
Loading…
Reference in New Issue