torch_mlir.compile: allow custom backend_legal_ops set

Allow customizing `backend_legal_ops` for "torch" output type, since we
don't know which backend will be used (it might be a custom backend).
We don't allow customizing the `backend_legal_ops` for the other output
types (Linalg, TOSA, MHLO) since those backends control their set of
legal ops directly.

Fixes #1418
pull/1476/head
Sean Silva 2022-10-11 15:44:16 +00:00
parent 61db1b5c4d
commit 6403c0e56f
2 changed files with 41 additions and 2 deletions

View File

@ -0,0 +1,23 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
import torch
import torch_mlir
class AddmmModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y, z):
return torch.ops.aten.addmm(x, y, z)
example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
print(torch_mlir.compile(AddmmModule(), example_args,
output_type="torch", backend_legal_ops=["torch.aten.addmm"]))
# CHECK-LABEL: @forward
# CHECK: torch.aten.addmm

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import Sequence, Union, List
from typing import Optional, Sequence, Union, List
from enum import Enum
import sys
@ -140,6 +140,7 @@ def compile(model: torch.nn.Module,
output_type: Union[str, "OutputType"] = OutputType.TORCH,
use_tracing: bool = False,
ignore_traced_shapes = False,
backend_legal_ops: Optional[Sequence[str]] = None,
verbose: bool = False):
"""Convert a PyTorch model to MLIR.
@ -161,6 +162,9 @@ def compile(model: torch.nn.Module,
`TensorPlaceholder`'s used as `example_args`. Also,
strictly-speaking, this option covers dtypes too, but we just say
"shapes" to be succinct.
backend_legal_ops: A list of ops that should be considered legal for
the backend. An op that is considered legal will not be decomposed.
This option is only valid with the `"torch"` output type.
verbose: If true, print extra information about the conversion.
Returns:
@ -171,6 +175,19 @@ def compile(model: torch.nn.Module,
if ignore_traced_shapes and not use_tracing:
raise Exception("`ignore_traced_shapes` requires `use_tracing`")
# We only allow `backend_legal_ops` to be specified for the `"torch"`
# output type because the other output types actually invoke their
# respective backends (Linalg, TOSA, or MHLO), and those backends have
# very specific requirements about the ops which are legal.
# See `BACKEND_LEGAL_OPS` for more details.
if backend_legal_ops is not None:
if output_type != OutputType.TORCH:
raise Exception("`backend_legal_ops` is only valid with the "
"`torch` output type")
backend_legal_ops = list(sorted(set(backend_legal_ops)))
else:
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
# Special case -- many models have just one input, so canonicalize a single
# tensor to a list of a single tensor to make the API more ergonomic.
if isinstance(example_args, (torch.Tensor, TensorPlaceholder)):
@ -253,7 +270,6 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
if output_type == OutputType.RAW:
return mb.module
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
run_pipeline_with_repro_report(
mb.module,