[torchdynamo] Use decompositions to support a few ops

pull/1649/head
Sean Silva 2022-11-29 14:01:42 +00:00
parent b4b92c990e
commit 88db99946b
2 changed files with 37 additions and 22 deletions

View File

@ -64,6 +64,10 @@ TORCHDYNAMO_XFAIL_SET = {
"NllLossModuleBackward1DSum_basic",
"NllLossModuleBackward1DWeight_basic",
"NllLossModuleBackward1D_basic",
# TypeError: uniform() missing 2 required keyword-only arguments: 'dtype' and 'device'
# RuntimeError: Failed running call_function aten.uniform(...
# https://github.com/pytorch/torchdynamo/issues/1954
"UniformNoCorrelationModule_basic",
# Decomposition assertion:
# assert device is not None or dtype is not None or memory_format is not None
# https://github.com/pytorch/pytorch/issues/89633
@ -78,24 +82,10 @@ TORCHDYNAMO_XFAIL_SET = {
# These are probably due to slightly different ops being recorded by
# torchdynamo vs. torchscript.
# error: unsupported by backend contract: tensor with unknown rank
# %3 = torch.operator "aten._adaptive_avg_pool2d"(%1, %2) : (!torch.tensor<*,f32>, !torch.list<int>) -> !torch.tensor
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
# No upstream decompositions.
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
# See also: https://github.com/pytorch/torchdynamo/issues/327
"AtenEmbeddingBagSumExample_basic",
# %42 = torch.operator "aten.std.correction"(%7, %none, %int1, %false) : (!torch.tensor<*,f64>, !torch.none, !torch.int, !torch.bool) -> !torch.tensor
"BernoulliModule_basic",
"DropoutTrainModule_basic",
"StdBiasedModule_basic",
"StdDimBiasedModule_basic",
"StdDimEmptyDimModule_basic",
"StdDimKeepDimFalseModule_basic",
"StdDimKeepDimTrueModule_basic",
"StdDimNoneDimModule_basic",
"StdUnbiasedModule_basic",
"UniformModule_basic",
"UniformNoCorrelationModule_basic",
# %1 = torch.operator "aten.scalar_tensor"(%float8.000000e00, %int6, %int0, %cpu, %none) : (!torch.float, !torch.int, !torch.int, !torch.Device, !torch.none) -> !torch.tensor
"ElementwiseWhereScalarModule_basic",
"ElementwiseWhereScalarOtherModule_basic",
@ -109,20 +99,18 @@ TORCHDYNAMO_XFAIL_SET = {
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
# %4 = torch.operator "aten.dot"(%1, %3) : (!torch.tensor<*,f32>, !torch.tensor<*,f32>) -> !torch.tensor
"Matmul_dot",
# %4 = torch.operator "aten.squeeze_.dim"(%3, %int0) : (!torch.tensor<*,f32>, !torch.int) -> !torch.tensor
"Matmul_vecmat",
# https://github.com/llvm/torch-mlir/issues/1611
# error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible
"Aten_EmbeddingBagExample_basic",
# error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal
"BernoulliFloatModule_basic",
# error: failed to legalize operation 'torch.aten.bernoulli.Tensor' that was explicitly marked illegal
"BernoulliTensorModule_basic",
# error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
"ElementwiseFlattenBroadcastModule_basic",
"FlattenRank0Module_basic",
"UniformModule_basic",
# error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion
"TModuleRank1_basic",
}

View File

@ -8,13 +8,39 @@ from typing import List
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch._src.compile_utils import strip_overloads
from torch._decomp import get_decompositions
import warnings
# https://github.com/pytorch/pytorch/issues/89064
warnings.filterwarnings("ignore", module="torch.jit._check")
def _get_decomposition_table():
"""Get a decomposition table suitable for Torch-MLIR.
Sometimes TorchDynamo traces slightly different ops than what TorchScript
captures. Historically we have been driven by the ops captured by
TorchScript, so we try to decompose the ops captured by TorchDynamo into
other ops that we already support.
There isn't a highly principled solution here. Torch-MLIR currently supports
a somewhat random set of ops, added in a demand-driven way over time,
including direct backend support and decompositions internal to Torch-MLIR.
As described in the
[long-term roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md),
eventually this situation is expected to be made a lot more principled
by aligning more with how Torch-MLIR would have looked if some of the new
upstream PyTorch infra had been available at the beginning -- in particular
the new decomposition infra and PrimTorch.
"""
aten = torch.ops.aten
return get_decompositions([
aten._adaptive_avg_pool2d,
aten.std.correction,
aten.dot,
])
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""Canonicalize single-element tuple returns to just the element.
@ -55,7 +81,8 @@ def make_simple_dynamo_backend(user_backend):
def wrapper_backend(fx_graph: torch.fx.GraphModule,
example_inputs: List[torch.Tensor]):
did_unwrap = _unwrap_single_tuple_return(fx_graph)
dispatcher_ops = make_fx(fx_graph)(*example_inputs)
dispatcher_ops = make_fx(
fx_graph, decomposition_table=_get_decomposition_table())(*example_inputs)
strip_overloads(dispatcher_ops)
user_callable = user_backend(dispatcher_ops, example_inputs)