mirror of https://github.com/llvm/torch-mlir
[torchdynamo] Use decompositions to support a few ops
parent
b4b92c990e
commit
88db99946b
|
@ -64,6 +64,10 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"NllLossModuleBackward1DSum_basic",
|
||||
"NllLossModuleBackward1DWeight_basic",
|
||||
"NllLossModuleBackward1D_basic",
|
||||
# TypeError: uniform() missing 2 required keyword-only arguments: 'dtype' and 'device'
|
||||
# RuntimeError: Failed running call_function aten.uniform(...
|
||||
# https://github.com/pytorch/torchdynamo/issues/1954
|
||||
"UniformNoCorrelationModule_basic",
|
||||
# Decomposition assertion:
|
||||
# assert device is not None or dtype is not None or memory_format is not None
|
||||
# https://github.com/pytorch/pytorch/issues/89633
|
||||
|
@ -78,24 +82,10 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# These are probably due to slightly different ops being recorded by
|
||||
# torchdynamo vs. torchscript.
|
||||
|
||||
# error: unsupported by backend contract: tensor with unknown rank
|
||||
# %3 = torch.operator "aten._adaptive_avg_pool2d"(%1, %2) : (!torch.tensor<*,f32>, !torch.list<int>) -> !torch.tensor
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
# No upstream decompositions.
|
||||
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
|
||||
# See also: https://github.com/pytorch/torchdynamo/issues/327
|
||||
"AtenEmbeddingBagSumExample_basic",
|
||||
# %42 = torch.operator "aten.std.correction"(%7, %none, %int1, %false) : (!torch.tensor<*,f64>, !torch.none, !torch.int, !torch.bool) -> !torch.tensor
|
||||
"BernoulliModule_basic",
|
||||
"DropoutTrainModule_basic",
|
||||
"StdBiasedModule_basic",
|
||||
"StdDimBiasedModule_basic",
|
||||
"StdDimEmptyDimModule_basic",
|
||||
"StdDimKeepDimFalseModule_basic",
|
||||
"StdDimKeepDimTrueModule_basic",
|
||||
"StdDimNoneDimModule_basic",
|
||||
"StdUnbiasedModule_basic",
|
||||
"UniformModule_basic",
|
||||
"UniformNoCorrelationModule_basic",
|
||||
# %1 = torch.operator "aten.scalar_tensor"(%float8.000000e00, %int6, %int0, %cpu, %none) : (!torch.float, !torch.int, !torch.int, !torch.Device, !torch.none) -> !torch.tensor
|
||||
"ElementwiseWhereScalarModule_basic",
|
||||
"ElementwiseWhereScalarOtherModule_basic",
|
||||
|
@ -109,20 +99,18 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"IndexPutImpl2DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||
# %4 = torch.operator "aten.dot"(%1, %3) : (!torch.tensor<*,f32>, !torch.tensor<*,f32>) -> !torch.tensor
|
||||
"Matmul_dot",
|
||||
# %4 = torch.operator "aten.squeeze_.dim"(%3, %int0) : (!torch.tensor<*,f32>, !torch.int) -> !torch.tensor
|
||||
"Matmul_vecmat",
|
||||
|
||||
# https://github.com/llvm/torch-mlir/issues/1611
|
||||
# error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible
|
||||
"Aten_EmbeddingBagExample_basic",
|
||||
# error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal
|
||||
"BernoulliFloatModule_basic",
|
||||
# error: failed to legalize operation 'torch.aten.bernoulli.Tensor' that was explicitly marked illegal
|
||||
"BernoulliTensorModule_basic",
|
||||
# error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
|
||||
"ElementwiseFlattenBroadcastModule_basic",
|
||||
"FlattenRank0Module_basic",
|
||||
"UniformModule_basic",
|
||||
# error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion
|
||||
"TModuleRank1_basic",
|
||||
}
|
||||
|
|
|
@ -8,13 +8,39 @@ from typing import List
|
|||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from functorch._src.compile_utils import strip_overloads
|
||||
|
||||
from torch._decomp import get_decompositions
|
||||
|
||||
import warnings
|
||||
# https://github.com/pytorch/pytorch/issues/89064
|
||||
warnings.filterwarnings("ignore", module="torch.jit._check")
|
||||
|
||||
|
||||
def _get_decomposition_table():
|
||||
"""Get a decomposition table suitable for Torch-MLIR.
|
||||
|
||||
Sometimes TorchDynamo traces slightly different ops than what TorchScript
|
||||
captures. Historically we have been driven by the ops captured by
|
||||
TorchScript, so we try to decompose the ops captured by TorchDynamo into
|
||||
other ops that we already support.
|
||||
|
||||
There isn't a highly principled solution here. Torch-MLIR currently supports
|
||||
a somewhat random set of ops, added in a demand-driven way over time,
|
||||
including direct backend support and decompositions internal to Torch-MLIR.
|
||||
As described in the
|
||||
[long-term roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md),
|
||||
eventually this situation is expected to be made a lot more principled
|
||||
by aligning more with how Torch-MLIR would have looked if some of the new
|
||||
upstream PyTorch infra had been available at the beginning -- in particular
|
||||
the new decomposition infra and PrimTorch.
|
||||
"""
|
||||
aten = torch.ops.aten
|
||||
return get_decompositions([
|
||||
aten._adaptive_avg_pool2d,
|
||||
aten.std.correction,
|
||||
aten.dot,
|
||||
])
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""Canonicalize single-element tuple returns to just the element.
|
||||
|
||||
|
@ -55,7 +81,8 @@ def make_simple_dynamo_backend(user_backend):
|
|||
def wrapper_backend(fx_graph: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor]):
|
||||
did_unwrap = _unwrap_single_tuple_return(fx_graph)
|
||||
dispatcher_ops = make_fx(fx_graph)(*example_inputs)
|
||||
dispatcher_ops = make_fx(
|
||||
fx_graph, decomposition_table=_get_decomposition_table())(*example_inputs)
|
||||
strip_overloads(dispatcher_ops)
|
||||
user_callable = user_backend(dispatcher_ops, example_inputs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue