mirror of https://github.com/llvm/torch-mlir
torch_mlir.compile: allow custom backend_legal_ops set
Allow customizing `backend_legal_ops` for the "torch" output type, since we don't know which backend will be used (it might be a custom backend). We don't allow customizing `backend_legal_ops` for the other output types (Linalg, TOSA, MHLO) since those backends control their set of legal ops directly. Fixes #1418
parent 61db1b5c4d
commit 6403c0e56f
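As a hedged illustration of the new API surface (not part of the commit itself), the sketch below assumes the `torch_mlir.compile` signature introduced in the diff that follows; the module class and op name mirror the new test case and are illustrative only.

# Hedged usage sketch (assumes the `torch_mlir.compile` API as modified in
# the diff below; the module and op name mirror the new test case).
import torch
import torch_mlir


class MatmulAdd(torch.nn.Module):
    def forward(self, x, y, z):
        return torch.ops.aten.addmm(x, y, z)


# Dynamic-shaped float32 placeholders stand in for real example tensors.
placeholder = torch_mlir.TensorPlaceholder([-1, -1], torch.float32)

# With output_type="torch", ops named in `backend_legal_ops` are treated as
# legal and therefore not decomposed by the lowering pipeline.
module = torch_mlir.compile(
    MatmulAdd(), [placeholder] * 3,
    output_type="torch",
    backend_legal_ops=["torch.aten.addmm"])
print(module)

# Passing `backend_legal_ops` with any backend output type (e.g.
# "linalg-on-tensors", "tosa", "mhlo") raises an exception, because those
# backends control their own set of legal ops.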
@@ -0,0 +1,23 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+
+import torch_mlir
+
+class AddmmModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+    def forward(self, x, y, z):
+        return torch.ops.aten.addmm(x, y, z)
+
+example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
+
+print(torch_mlir.compile(AddmmModule(), example_args,
+                         output_type="torch", backend_legal_ops=["torch.aten.addmm"]))
+# CHECK-LABEL: @forward
+# CHECK: torch.aten.addmm
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
-from typing import Sequence, Union, List
+from typing import Optional, Sequence, Union, List
 from enum import Enum
 
 import sys
@@ -140,6 +140,7 @@ def compile(model: torch.nn.Module,
             output_type: Union[str, "OutputType"] = OutputType.TORCH,
             use_tracing: bool = False,
             ignore_traced_shapes = False,
+            backend_legal_ops: Optional[Sequence[str]] = None,
             verbose: bool = False):
     """Convert a PyTorch model to MLIR.
 
@@ -161,6 +162,9 @@
             `TensorPlaceholder`'s used as `example_args`. Also,
             strictly-speaking, this option covers dtypes too, but we just say
             "shapes" to be succinct.
+        backend_legal_ops: A list of ops that should be considered legal for
+            the backend. An op that is considered legal will not be decomposed.
+            This option is only valid with the `"torch"` output type.
         verbose: If true, print extra information about the conversion.
 
     Returns:
@@ -171,6 +175,19 @@
     if ignore_traced_shapes and not use_tracing:
         raise Exception("`ignore_traced_shapes` requires `use_tracing`")
 
+    # We only allow `backend_legal_ops` to be specified for the `"torch"`
+    # output type because the other output types actually invoke their
+    # respective backends (Linalg, TOSA, or MHLO), and those backends have
+    # very specific requirements about the ops which are legal.
+    # See `BACKEND_LEGAL_OPS` for more details.
+    if backend_legal_ops is not None:
+        if output_type != OutputType.TORCH:
+            raise Exception("`backend_legal_ops` is only valid with the "
+                            "`torch` output type")
+        backend_legal_ops = list(sorted(set(backend_legal_ops)))
+    else:
+        backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
+
     # Special case -- many models have just one input, so canonicalize a single
     # tensor to a list of a single tensor to make the API more ergonomic.
     if isinstance(example_args, (torch.Tensor, TensorPlaceholder)):
@@ -253,7 +270,6 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
     if output_type == OutputType.RAW:
         return mb.module
 
-    backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
     option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
     run_pipeline_with_repro_report(
         mb.module,
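For reference, a small standalone sketch (op names illustrative, not tied to any particular backend) of how the legal-op list is deduplicated, sorted, and rendered into the `backend-legal-ops` pipeline option used in the last hunk:

# Standalone sketch of the normalization and option-string construction in
# the hunks above; the op names are illustrative.
user_ops = ["torch.aten.addmm", "torch.aten.addmm", "torch.aten.flatten.using_ints"]

# Mirrors `list(sorted(set(backend_legal_ops)))`: duplicates dropped,
# deterministic order for the generated pipeline options.
backend_legal_ops = list(sorted(set(user_ops)))

option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
print(option_string)
# {backend-legal-ops=torch.aten.addmm,torch.aten.flatten.using_ints}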