mirror of https://github.com/llvm/torch-mlir

Pure-Python FX importer. (#2098)
Co-authored-by: Sean Silva <silvasean@google.com>
branch: pull/2119/head
tag: snapshot-20230512.836
parent: e161f2511a
commit: c3cd7471b4
@@ -106,7 +106,7 @@ def main():
         xfail_set = LTC_XFAIL_SET
         crashing_set = set()
     elif args.config == "torchdynamo":
-        config = TorchDynamoTestConfig()
+        config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
         xfail_set = TORCHDYNAMO_XFAIL_SET
         crashing_set = TORCHDYNAMO_CRASHING_SET
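With this change the torchdynamo test config receives the backend to run against as an explicit constructor argument instead of hard-coding it internally. A minimal sketch of wiring it up outside the runner (import paths assumed from this diff and the existing e2e test layout):

    # Sketch, not part of the diff: handing the TorchDynamo config a backend.
    from torch_mlir_e2e_test.configs import TorchDynamoTestConfig
    from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

    config = TorchDynamoTestConfig(refbackend.RefBackendLinalgOnTensorsBackend())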
@@ -56,20 +56,7 @@ TORCHDYNAMO_XFAIL_SET = {
    "ElementwiseWhereScalarSelfModule_basic",
    "ElementwiseWhereScalarOtherStaticModule_basic",
    "ElementwiseWhereScalarSelfStaticModule_basic",
    # %7 = torch.operator "aten._index_put_impl_.hacked_twin"(%1, %6, %5, %true, %false) : (!torch.tensor<*,f32>, !torch.list<tensor>, !torch.tensor<*,f32>, !torch.bool, !torch.bool) -> !torch.tensor
    "IndexPutImpl1DFloatAccumulateModule_basic",
    "IndexPutImpl1DFloatNonAccumulateModule_basic",
    "IndexPutImpl1DIntAccumulateModule_basic",
    "IndexPutImpl1DIntNonAccumulateModule_basic",
    "IndexPutImpl2DFloatAccumulateModule_basic",
    "IndexPutImpl2DFloatNonAccumulateModule_basic",
    "IndexPutImpl2DIndexModule_basic",
    "IndexPutImpl3DFloatAccumulateModule_basic",
    "IndexPutImpl3DFloatNonAccumulateModule_basic",

    # https://github.com/llvm/torch-mlir/issues/1611
    # error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible
    "Aten_EmbeddingBagExample_basic",
    # error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal
    "BernoulliFloatModule_basic",
    "BernoulliPModule_basic",
@@ -77,8 +64,6 @@ TORCHDYNAMO_XFAIL_SET = {
    "ElementwiseFlattenBroadcastModule_basic",
    "FlattenRank0Module_basic",
    "UniformModule_basic",
    # error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion
    "TModuleRank1_basic",
    # error: unsupported by backend contract: tensor with unknown rank
    # note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
    "ElementwisePreluModule_basic",
@@ -107,7 +92,7 @@ TORCHDYNAMO_XFAIL_SET = {
     'TensorToBoolZeroRank_basic',
     'TensorToBool_basic',

-    # torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
+    # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
     'AtenSubFloatModule_basic',
     'BoolFloatFalseModule_basic',
     'BoolFloatTrueModule_basic',
@@ -120,8 +105,10 @@ TORCHDYNAMO_XFAIL_SET = {
    'SubFloatModule_basic',
    'TensorToFloatZeroRank_basic',
    'TensorToFloat_basic',
    # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}

    # torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}

    # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
    'AddIntModule_basic',
    'AtenIntTensorCharDtypeModule_basic',
    'BoolIntFalseModule_basic',
@@ -138,78 +125,153 @@ TORCHDYNAMO_XFAIL_SET = {
    'TensorToInt_basic',
    'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic',
    'ViewCollapseDynamicWithAtenSizeIntModule_basic',
    # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}

    # torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)}
    # ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)}
    'SortIntListReverse_basic',

    # torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {}
    # ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {}
    'SortIntList_basic',

    # torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
    # START tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
    'AtenFloatScalarModule_basic',
    'AtenIntBoolOpModule_basic',
    'OneHotModule_basic',
    'QuantizedMLP_basic',
    'ScalarImplicitFloatModule_basic',
    'ScalarImplicitIntModule_basic',
    # END tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default

    # torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
    # START tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
    'BincountMinlengthModule_basic',
    'BincountModule_basic',
    'BincountStaticSizeModule_basic',
    # END tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default

    # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool
    'BoolFloatConstantModule_basic',
    'BoolIntConstantModule_basic',

    # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__
    'ContainsIntList_False',
    'ContainsIntList_True',

    # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all
    'AllBoolFalseModule_basic',
    'AllBoolTrueModule_basic',

    # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any
    'AnyBoolFalseModule_basic',
    'AnyBoolTrueModule_basic',

    # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt
    'SqrtIntConstantModule_basic',

    # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
    # START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
    'AtenIntBoolOpConstFalseModule_basic',
    'AtenIntBoolOpConstTrueModule_basic',
    'IntFloatModule_basic',
    'PowIntFloatModule_basic',
    # END tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int

    # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len
    'LenStrModule_basic',

    # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel
    'NumelModule_basic',
    'NumelZeroRankModule_basic',

    # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max
    'PrimMaxIntModule_basic',

    # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min
    'PrimMinIntModule_basic',

    # empty graph
    # START tests failing due to: empty graph in dynamo
    'IsFloatingPointFloat_True',
    'IsFloatingPointInt_False',
    'TorchPrimLoopForLikeModule_basic',
    'TorchPrimLoopWhileLikeModule_basic',
    # END tests failing due to: empty graph in dynamo

    # Forming aten.view_as_real and aten.view_as_imag instead of aten.real and aten.imag op.
    # Complex ops
    # ERROR due to: backend never runs because of empty frame
    'ConstantBoolParameterModule_basic',

    # START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "AddCDivModule_basic",
    "ElementwiseMulScalarModule_basic",
    "ElementwiseMulScalarModule_float",
    "NativeGroupNormBackwardModule_basic",
    "UpSampleNearest2dDynamicSize_basic",
    "UpSampleNearest2dStaticFactor_basic",
    "UpSampleNearest2dStaticSize_basic",
    "UpSampleNearest2d_basic",
    # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'

    # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "BatchNorm1DModule_basic",
    "BatchNorm1DWith2DInputModule_basic",
    "BatchNorm2DModule_basic",
    "BatchNorm3DModule_basic",
    "ElementwiseAddScalarFloatModule_basic",
    "ElementwiseAddScalarInt64Module_basic",
    "ElementwiseAddScalarIntModule_basic",
    "MobilenetV3Module_basic",
    "NativeBatchNorm1DModule_basic",
    "NativeBatchNorm2DModule_basic",
    "NativeBatchNorm3DModule_basic",
    "NativeBatchNormNoneWeightModule_basic",
    "NativeGroupNormModule_basic",
    "ResNet18Module_basic",
    "ResNet18StaticModule_basic",
    # END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'

    # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
    "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
    "HBC_basic",

    # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "ElementwiseDivScalarModule_basic",

    # ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
    "ElementwiseMulScalarModule_int",

    # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "ElementwiseSubScalarFloatModule_basic",
    "ElementwiseSubScalarIntModule_basic",

    # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
    "ElementwiseDivRoundingModeFloorModule_basic",
    "ElementwiseDivRoundingModeTruncModule_basic",

    # ERROR: Exception: Unsupported op: get_attr
    "NumToTensorFloatModule_basic",
    "NumToTensorIntModule_basic",
    "TensorFloatModule_basic",
    "TensorIntModule_basic",

    # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.randn.generator
    "RandnGeneratorF64Module_basic",
    "RandnGeneratorModule_basic",

    # START tests failing due to: complex floating point ops
    "AtenComplexImagModule_basic",
    "AtenComplexRealModule_basic",
    # END tests failing due to: complex floating point ops
}

# See https://github.com/llvm/torch-mlir/issues/2050
TORCHDYNAMO_CRASHING_SET = {
    # No upstream decompositions.
    # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
    # See also: https://github.com/pytorch/torchdynamo/issues/327
    "Aten_EmbeddingBagExample_basic",
    # https://github.com/pytorch/pytorch/issues/100838
    "BaddbmmDifferentDtypesModule_basic",
    "FullModuleInt3D_basic",
    "ThresholdBackward1dIntModule_basic",
    "ThresholdBackward2dIntModule_basic",
    "ThresholdBackward3dIntModule_basic",
    # See https://github.com/llvm/torch-mlir/issues/2050
    "ElementwiseCloneChannelsLastMemoryFormatModule_basic",
    "ElementwiseCloneContiguousModule_basic",
    "ElementwiseCloneModule_basic",
@@ -226,11 +288,7 @@ TORCHDYNAMO_CRASHING_SET = {
    "PermuteModule_basic",
    "PermuteNegativeIndexModule_basic",
    "SelectIntNegativeDimAndIndexStaticModule_basic",
    "SliceModule_basic",
    "SliceNegIdxModule_basic",
    "SliceOutOfLowerBoundStartIndexModule_basic",
    "SliceSizeTwoStepModule_basic",
    "SliceStaticModule_basic",
    "TestMultipleTensorAndPrimitiveTypesReturn_basic",
    "TModuleRank2_basic",
    "ToCopyModule_basic",
    "TransposeIntModule_basic",
@@ -51,6 +51,7 @@ if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
   ADD_TO_PARENT TorchMLIRPythonSources
   SOURCES
     __init__.py
+    _dynamo_fx_importer.py
     compiler_utils.py
     dynamo.py
 )
@@ -0,0 +1,116 @@ (new file: dynamo_fx_importer/basic.py)
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

from typing import List

import torch
import torch.fx
import torch._dynamo as dynamo
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_compilation_context, set_model_name

from torch_mlir.compiler_utils import TorchMlirCompilerError
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
from torch_mlir_e2e_test.configs.torchdynamo import jit


@make_boxed_compiler
def my_aot_autograd_backend(gm: torch.fx.GraphModule,
                            example_inputs: List[torch.Tensor]):
    print(gm.graph)
    *_, model_name, nth_graph = get_aot_compilation_context()
    mlir_module = import_fx_graph_as_func(gm.graph, model_name)
    print(mlir_module.operation.get_asm(enable_debug_info=True))
    return gm


my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend)


# CHECK: module attributes {torch.debug_module_name = "basic"} {
# CHECK-NEXT:   func.func @basic(%[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK-NEXT:     %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> loc(#[[LOC:.*]])
# CHECK-NEXT:     return %[[TANH]] : !torch.vtensor<[3,4],f32>
# CHECK-NEXT:   }
# CHECK-NEXT: }
# CHECK-NEXT: #[[LOC]] = loc("{{.*}}/dynamo_fx_importer/basic.py":{{[0-9]+}}:{{[0-9]+}})
@dynamo.optimize(my_backend)
def basic(x):
    return torch.tanh(x)


set_model_name("basic")
basic(torch.randn(3, 4))


# CHECK-LABEL: func.func @literals_list_device_int_none_dtype() -> !torch.vtensor<[3,4],f16> {
# CHECK:   %[[INT3:.*]] = torch.constant.int 3
# CHECK:   %[[INT4:.*]] = torch.constant.int 4
# CHECK:   %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list<int>
# CHECK:   %[[INT5:.*]] = torch.constant.int 5
# CHECK:   %[[NONE0:.*]] = torch.constant.none
# CHECK:   %[[DEVICE_CPU:.*]] = torch.constant.device "cpu"
# CHECK:   %[[NONE1:.*]] = torch.constant.none
# CHECK:   %[[RANDN:.*]] = torch.aten.randn %[[LIST]], %[[INT5]], %[[NONE0]], %[[DEVICE_CPU]], %[[NONE1]] : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.none -> !torch.vtensor<[3,4],f16>
# CHECK:   return %[[RANDN]] : !torch.vtensor<[3,4],f16>
@dynamo.optimize(my_backend)
def literals_list_device_int_none_dtype():
    return torch.ops.aten.randn([3, 4],
                                device=torch.device("cpu"),
                                dtype=torch.float16)


set_model_name("literals_list_device_int_none_dtype")
literals_list_device_int_none_dtype()


# CHECK-LABEL: func.func @literals_bool(
# CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK:   %[[NONE0:.*]] = torch.constant.none
# CHECK:   %[[NONE1:.*]] = torch.constant.none
# CHECK:   %[[NONE2:.*]] = torch.constant.none
# CHECK:   %[[BOOL_FALSE:.*]] = torch.constant.bool false
# CHECK:   %[[NONE3:.*]] = torch.constant.none
# CHECK:   %[[EMPTY_LIKE:.*]] = torch.aten.empty_like %[[ARG0]], %[[NONE0]], %[[NONE1]], %[[NONE2]], %[[BOOL_FALSE]], %[[NONE3]] : !torch.vtensor<[3,4],f32>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
# CHECK:   return %[[EMPTY_LIKE]] : !torch.vtensor<[3,4],f32>
@dynamo.optimize(my_backend)
def literals_bool(x):
    return torch.ops.aten.empty_like(x, pin_memory=False)


set_model_name("literals_bool")
literals_bool(torch.randn(3, 4))


# CHECK-LABEL: func.func @literals_float(
# CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK:   %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
# CHECK:   %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
# CHECK:   %[[NONE:.*]] = torch.constant.none
# CHECK:   %[[UNIFORM:.*]] = torch.aten.uniform %[[ARG0]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f32>
# CHECK:   return %[[UNIFORM]] : !torch.vtensor<[3,4],f32>
@dynamo.optimize(my_backend)
def literals_float(x):
    return torch.ops.aten.uniform(x, 0.0, 1.0)


set_model_name("literals_float")
literals_float(torch.randn(3, 4))


# CHECK-LABEL: func.func @literals_str(
# CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK:   %[[STR_TANH:.*]] = torch.constant.str "tanh"
# CHECK:   %[[GELU:.*]] = torch.aten.gelu %[[ARG0]], %[[STR_TANH]] : !torch.vtensor<[3,4],f32>, !torch.str -> !torch.vtensor<[3,4],f32>
# CHECK:   return %[[GELU]] : !torch.vtensor<[3,4],f32>
@dynamo.optimize(my_backend)
def literals_str(x):
    return torch.ops.aten.gelu(x, approximate="tanh")


set_model_name("literals_str")
literals_str(torch.randn(3, 4))
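Each case above follows the same recipe: decorate the function with dynamo.optimize(my_backend), call set_model_name so get_aot_compilation_context() reports a stable name, then invoke once so the backend fires and prints the imported MLIR for FileCheck. Outside the lit harness the file can be exercised directly, assuming LLVM's FileCheck binary is on PATH:

    python basic.py | FileCheck basic.py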
@@ -12,8 +12,8 @@ import tempfile
from torch._functorch.compile_utils import strip_overloads
import torch
import torch.fx

from torch_mlir.passmanager import PassManager
from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library
@@ -248,6 +248,64 @@ BACKEND_LEGAL_OPS = {
}


def _canon_extra_library(extra_library):
    extra_library_file_name = ""
    if len(extra_library) != 0:
        extra_library_dict = {}
        for library_func in extra_library:
            extra_library_dict[library_func.__name__] = library_func
        mlir_library = generate_library(extra_library_dict)

        extra_library_file_name = \
            tempfile.gettempdir() + "/custom_op_extra_library.mlir"
        with open(extra_library_file_name, "w") as f:
            f.write(mlir_library)
    return extra_library_file_name


def _lower_mlir_module(verbose, output_type, module):
    if verbose:
        print("\n====================")
        print("Torch Backend IR")
        print(module)

    if output_type == OutputType.TORCH:
        return module

    if output_type == OutputType.TOSA:
        run_pipeline_with_repro_report(
            module, "builtin.module(torch-backend-to-tosa-backend-pipeline)",
            "Lowering Torch Backend IR -> TOSA Backend IR")
        if verbose:
            print("\n====================")
            print("TOSA Backend IR")
            print(module)
        return module

    if output_type == OutputType.LINALG_ON_TENSORS:
        run_pipeline_with_repro_report(
            module,
            "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
        if verbose:
            print("\n====================")
            print("LINALG Backend IR")
            print(module)
        return module

    elif output_type == OutputType.STABLEHLO:
        run_pipeline_with_repro_report(
            module,
            "builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
            "Lowering Torch Backend IR -> StableHLO Backend IR")
        if verbose:
            print("\n====================")
            print("StableHLO Backend IR")
            print(module)
        return module
    raise Exception(f"Unknown OutputType: {output_type}")
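Factoring these two helpers out of compile() lets the new TorchDynamo path in configs/torchdynamo.py (further down in this diff) reuse the extra-library canonicalization and the backend-contract lowering. A rough sketch of the call pattern they support; `torch_backend_module` is a hypothetical name for Torch Backend IR produced upstream:

    # Hypothetical driver code reusing the helpers above.
    extra_library_file = _canon_extra_library([])  # "" when no extra library functions
    lowered = _lower_mlir_module(verbose=False,
                                 output_type=OutputType.LINALG_ON_TENSORS,
                                 module=torch_backend_module)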

def compile(model: torch.nn.Module,
            example_args: _example_args,
            output_type: Union[str, "OutputType"] = OutputType.TORCH,
@@ -290,18 +348,7 @@ def compile(model: torch.nn.Module,
         An MLIR module that contains the converted model in the specified
         output type.
     """
-    extra_library_file_name = ""
-    if len(extra_library) != 0:
-        extra_library_dict = {}
-        for library_func in extra_library:
-            extra_library_dict[library_func.__name__] = library_func
-        mlir_library = generate_library(extra_library_dict)
-
-        extra_library_file_name = \
-            tempfile.gettempdir() + "/custom_op_extra_library.mlir"
-        with open(extra_library_file_name, "w") as f:
-            f.write(mlir_library)
-
+    extra_library_file_name = _canon_extra_library(extra_library)
     output_type = OutputType.get(output_type)
     example_args = ExampleArgs.get(example_args)
     if ignore_traced_shapes and not use_tracing:
@@ -394,44 +441,4 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
         "Lowering TorchScript IR -> Torch Backend IR",
     )

-    if verbose:
-        print("\n====================")
-        print("Torch Backend IR")
-        print(mb.module)
-
-    if output_type == OutputType.TORCH:
-        return mb.module
-
-    if output_type == OutputType.TOSA:
-        run_pipeline_with_repro_report(
-            mb.module,
-            "builtin.module(torch-backend-to-tosa-backend-pipeline)",
-            "Lowering Torch Backend IR -> TOSA Backend IR")
-        if verbose:
-            print("\n====================")
-            print("TOSA Backend IR")
-            print(mb.module)
-        return mb.module
-
-    if output_type == OutputType.LINALG_ON_TENSORS:
-        run_pipeline_with_repro_report(
-            mb.module,
-            "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
-            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
-        if verbose:
-            print("\n====================")
-            print("LINALG Backend IR")
-            print(mb.module)
-        return mb.module
-
-    elif output_type == OutputType.STABLEHLO:
-        run_pipeline_with_repro_report(
-            mb.module,
-            "builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
-            "Lowering Torch Backend IR -> StableHLO Backend IR")
-        if verbose:
-            print("\n====================")
-            print("StableHLO Backend IR")
-            print(mb.module)
-        return mb.module
-    raise Exception(f"Unknown OutputType: {output_type}")
+    return _lower_mlir_module(verbose, output_type, mb.module)
@@ -0,0 +1,442 @@ (new file: torch_mlir/_dynamo_fx_importer.py)
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# This file implements a pure-Python importer from a restricted subset of
# FX IR into MLIR.
#
# As described in the
# [long-term roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md#refactoring-the-frontend),
# the goal is to import directly into the Torch-MLIR backend contract by
# using the available PyTorch infra for doing functionalization,
# shape inference, etc. Thus, this importer imports a very specific subset
# of the possible FX IR that is co-designed with the PyTorch infra that
# produces the FX graph -- see the `torch_mlir.dynamo` module for that, and
# see the `_verify_fx_graph_conforms_to_subset` function for the operational
# definition.
#
# In fact, because of the generality of FX IR (e.g. the use of raw Python
# callables as node.target), there is really no well-defined way to implement
# a general FX -> MLIR importer. Reuse or extension of this code for other
# FX -> MLIR use cases should be done carefully, and likely will involve
# introducing new concepts or abstractions into the import process.

from typing import Dict, List, Tuple

import operator
import re

import torch

import torch_mlir.ir as ir
import torch_mlir.dialects.func as func_dialect
import torch_mlir.dialects.torch as torch_dialect


def _is_valid_meta_val(val):
    # We currently allow only FakeTensor's or lists of FakeTensor's as
    # meta['val']. However, this could potentially change to also hold a
    # SymInt in the future. See:
    # https://github.com/pytorch/pytorch/issues/90839#issuecomment-1352856661
    if isinstance(val, torch._subclasses.FakeTensor):
        return True
    if isinstance(val, (tuple, list)):
        return all(isinstance(x, torch._subclasses.FakeTensor) for x in val)
    return False


def _verify_fx_graph_conforms_to_subset(g: torch.fx.Graph):
    # TODO: Report errors with source locations if possible.
    def _check_meta_val(node):
        if "val" not in node.meta:
            raise Exception(f"Unsupported: missing node.meta['val']: {node}")
        if not _is_valid_meta_val(node.meta["val"]):
            raise Exception(
                f"Unsupported: node.meta['val'] is not a FakeTensor or list of FakeTensor's: {node}; {node.meta['val']}"
            )

    for node in g.nodes:
        if node.op not in ("placeholder", "call_function", "output"):
            raise Exception(f"Unsupported op: {node.op}")
        if node.op == "placeholder":
            _check_meta_val(node)
        if node.op == "call_function":
            _check_meta_val(node)
            # We only support OpOverload for computations because the `torch`
            # dialect ops model the fully qualified op name, including the
            # overload. We also support operator.getitem because that is how
            # multiple results are modeled.
            if isinstance(node.target, torch._ops.OpOverload):
                for type_ in (r.type for r in node.target._schema.returns):
                    if isinstance(type_, torch.TensorType):
                        continue
                    raise Exception(
                        f"Unsupported: return type {type_} in schema for {node.target}"
                    )
                if len(node.args) != len(node.target._schema.arguments):
                    assert len(node.args) < len(node.target._schema.arguments)
                    for i, argument in enumerate(
                            node.target._schema.arguments[len(node.args):]):
                        if not argument.has_default_value():
                            raise Exception(
                                f"Unsupported: missing default value for argument {i} in schema for {node.target}"
                            )
                continue
            if node.target is operator.getitem:
                continue
            raise Exception(f"Unsupported call_function target: {node.target}")


# ==============================================================================
# Type import
# ==============================================================================


def _torch_type_to_mlir_type_string(t: torch.Type) -> str:
    # This is weird -- for Node's, since they are untyped, we use the
    # node.meta['val'] to get the type (which is a tensor type with sizes and
    # dtype).
    # But for things that are associated with a schema, we use the schema to
    # get the type. This creates problems for things like a list<tensor>
    # because then we don't have sizes or dtypes available.
    if isinstance(t, torch.ListType):
        return f"list<{_torch_type_to_mlir_type_string(t.getElementType())}>"
    if isinstance(t, torch.BoolType):
        return "bool"
    if isinstance(t, torch.IntType):
        return "int"
    if isinstance(t, torch.FloatType):
        return "float"
    if isinstance(t, torch.StringType):
        return "string"
    if isinstance(t, torch.TensorType):
        return "vtensor"
    if isinstance(t, torch.OptionalType):
        return f"optional<{_torch_type_to_mlir_type_string(t.getElementType())}>"
    raise Exception(f"Unsupported type: {t}")


def _torch_type_to_mlir_type(t: torch.Type):
    return ir.Type.parse(f"!torch.{_torch_type_to_mlir_type_string(t)}")


def _convert_dtype_to_mlir_type(dtype: torch.dtype) -> str:
    # See the table in TorchTypes.td:AnyTorchTensorType's documentation.
    if dtype == torch.float16:
        return "f16"
    if dtype == torch.bfloat16:
        return "bf16"
    if dtype == torch.float32:
        return "f32"
    if dtype == torch.float64:
        return "f64"
    if dtype == torch.uint8:
        return "ui8"
    if dtype == torch.int8:
        return "si8"
    if dtype == torch.int16:
        return "si16"
    if dtype == torch.int32:
        return "si32"
    if dtype == torch.int64:
        return "si64"
    if dtype == torch.bool:
        return "i1"
    if dtype == torch.qint8:
        return "!torch.qint8"
    if dtype == torch.quint8:
        return "!torch.quint8"
    if dtype == torch.complex64:
        return "complex<f64>"

    raise Exception(f"Unsupported dtype: {dtype}")


def _import_fake_tensor_as_mlir_type(
        fake_tensor: torch._subclasses.FakeTensor) -> ir.Type:
    # TODO: Find story for how to get dynamically shaped tensors here.
    shape = ",".join(str(d) for d in fake_tensor.shape)
    dtype = _convert_dtype_to_mlir_type(fake_tensor.dtype)
    return ir.Type.parse(f"!torch.vtensor<[{shape}],{dtype}>")
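As an illustration of the two helpers above: a fake tensor of shape (3, 4) with dtype torch.float32 imports as the MLIR type !torch.vtensor<[3,4],f32>. A hedged sketch of exercising the conversion in isolation (it needs an ir.Context with the torch dialect registered, which import_fx_graph_as_func below normally sets up):

    # Sketch only, not part of the diff.
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        ft = torch.empty(3, 4, dtype=torch.float32)
    with ir.Context() as ctx:
        torch_dialect.register_dialect(ctx)
        print(_import_fake_tensor_as_mlir_type(ft))  # !torch.vtensor<[3,4],f32>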

def _mlir_types_for_node(node: torch.fx.Node) -> List[ir.Type]:
    if isinstance(node.meta["val"], (tuple, list)):
        return [_import_fake_tensor_as_mlir_type(v) for v in node.meta["val"]]
    return [_import_fake_tensor_as_mlir_type(node.meta["val"])]


def _extract_function_type_from_graph(g: torch.fx.Graph) -> ir.FunctionType:
    input_types = []
    for node in g.nodes:
        if node.op == "placeholder":
            input_types.append(_mlir_types_for_node(node)[0])
        if node.op == "output":
            # TODO(DNS): Test this or add verifier that it can't happen.
            result_types = torch.fx.map_arg(
                node.args[0], lambda n: _mlir_types_for_node(n)[0])
            # Note: We import directly to the backend contract -- multiple
            # results are modeled with func.func native multiple results
            # rather than as a singleton value / tuple.
            return ir.FunctionType.get(input_types, result_types)


# ==============================================================================
# FX Graph import
# ==============================================================================

DTYPE_TO_INT = {
    # TODO(DNS): Fill in from AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
    torch.uint8: 0,
    torch.int8: 1,
    torch.int16: 2,
    torch.int32: 3,
    torch.int64: 4,
    torch.float16: 5,
    torch.float32: 6,
    torch.float64: 7,
    # torch.complex_half: 8
    torch.complex32: 9,
    torch.complex64: 10,
    torch.bool: 11,
    torch.qint8: 12,
    torch.quint8: 13,
    # torch.qint32: 14
    torch.bfloat16: 15,
}

MEMORY_FORMAT_TO_INT = {
    # https://github.com/pytorch/pytorch/blob/master/c10/core/MemoryFormat.h#L28
    torch.contiguous_format: 0,
    torch.preserve_format: 1,
    torch.channels_last: 2,
    torch.channels_last_3d: 3,
}

LAYOUT_TO_INT = {
    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_layouts.cpp
    torch.strided: 0,
    torch.sparse_coo: 1,
    torch.sparse_csr: 2,
    torch.sparse_csc: 3,
    torch.sparse_bsr: 4,
    torch.sparse_bsc: 5,
}


def _mlir_location_for_node(node: torch.fx.Node) -> ir.Location:
    stack_trace = node.stack_trace
    if stack_trace is None:
        return ir.Location.unknown()
    # TODO: Avoid needing to regex match this.
    # https://github.com/pytorch/pytorch/issues/91000
    m = re.search(r"""File "([^"]+)", line ([0-9]+),""", node.stack_trace)
    filename, line = m.group(1), int(m.group(2))
    return ir.Location.file(filename, line, col=0)
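A quick standalone illustration of the location-extraction regex above, applied to a typical FX stack_trace line (the file name is hypothetical):

    # Verifying the regex used by _mlir_location_for_node.
    import re
    trace = '  File "/tmp/example.py", line 12, in forward\n    return torch.tanh(x)'
    m = re.search(r"""File "([^"]+)", line ([0-9]+),""", trace)
    assert (m.group(1), int(m.group(2))) == ("/tmp/example.py", 12)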

class _FXGraphImporter:

    def __init__(self, g: torch.fx.Graph, func_name: str):
        self._g = g
        self._func_name = func_name
        # For each node, we track a mapping to MLIR Value's.
        # Technically all Node's have a single output (which can be a tuple
        # of values in case of multiple returns), but we treat them as having
        # multiple returns directly. This matches how the Node's
        # node.meta['val'] is set up, since it contains a list with multiple
        # FakeTensor's in case of a tuple return with multiple elements.
        self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
        self._module = ir.Module.create(ir.Location.unknown())
        self._module.operation.attributes[
            "torch.debug_module_name"] = ir.StringAttr.get(func_name)
        function_type = _extract_function_type_from_graph(g)
        func = func_dialect.FuncOp(
            func_name,
            function_type,
            loc=ir.Location.unknown(),  # TODO: Can we do better?
            ip=ir.InsertionPoint(self._module.body),
        )
        self._body_block = ir.Block.create_at_start(func.body,
                                                    function_type.inputs)

    def import_graph(self) -> ir.Module:
        with ir.InsertionPoint(self._body_block):
            num_placeholders_seen = 0
            for node in self._g.nodes:
                with _mlir_location_for_node(node):
                    if node.op == "placeholder":
                        self._env[(
                            node, 0
                        )] = self._body_block.arguments[num_placeholders_seen]
                        num_placeholders_seen += 1
                    if node.op == "call_function":
                        if node.target is operator.getitem:
                            self._env[(node, 0)] = self._env[(node.args[0],
                                                              node.args[1])]
                        else:
                            self._import_op_overload_call(node)
                    if node.op == "output":
                        # Note that the output node is a singleton tuple
                        # holding a tuple of return values (without the
                        # single-element special case).
                        # DNS: Test or verify no literals as results.
                        operands = [
                            self._import_argument(arg) for arg in node.args[0]
                        ]
                        func_dialect.ReturnOp(operands)
        return self._module

    def _import_op_overload_call(self, node: torch.fx.Node):
        assert node.op == "call_function"
        assert isinstance(node.target, torch._ops.OpOverload)
        schema = node.target._schema

        # Extract the `torch` dialect op name.
        namespace, _, unqualified_name = schema.name.partition("::")
        mlir_op_name = f"torch.{namespace}.{unqualified_name}"
        if schema.overload_name != "":
            mlir_op_name += f".{schema.overload_name}"

        # DNS: Unregistered ops
        assert ir.Context.current.is_registered_operation(
            mlir_op_name), f"Unregistered operation: {mlir_op_name}"

        # Construct the Operation.
        result_types = _mlir_types_for_node(node)
        operands = []
        # `schema.arguments` is a bit confusing in this context, since
        # `Argument` is the term that FX uses analogous to MLIR "Value". It is
        # more precise to call them "formal parameters".
        for i, parameter in enumerate(node.target._schema.arguments):
            if parameter.kwarg_only and parameter.name in node.kwargs:
                arg = node.kwargs[parameter.name]
            elif i < len(node.args):
                arg = node.args[i]
            else:
                arg = parameter.default_value
            operands.append(self._import_argument(arg, parameter.type))
        operation = ir.Operation.create(
            mlir_op_name,
            results=result_types,
            operands=operands,
        )
        for i, value in enumerate(operation.results):
            self._env[(node, i)] = value

    def _import_argument(self,
                         arg: torch.fx.node.Argument,
                         expected_type_for_literal=None) -> ir.Value:
        """Import an FX `Argument`, which is analogous to an MLIR `Value`.

        Args:
            arg: The FX `Argument` to import.
            expected_type_for_literal: If `arg` is a literal (such as a Python
                `int` or `float` object), this is the expected JIT IR type.
                This allows disambiguating certain cases, such as importing an
                optional type.
        Returns:
            The imported MLIR `Value`.
        """
        if isinstance(arg, torch.fx.Node):
            return self._env[(arg, 0)]
        assert expected_type_for_literal is not None
        return self._import_literal(arg, expected_type_for_literal)

    def _import_literal(self, arg: torch.fx.node.Argument,
                        expected_type) -> ir.Value:
        if arg is None:
            return torch_dialect.ConstantNoneOp().result
        if isinstance(expected_type, torch.OptionalType):
            return self._import_argument(arg, expected_type.getElementType())
        if isinstance(arg, bool):
            return torch_dialect.ConstantBoolOp(
                ir.IntegerAttr.get(ir.IntegerType.get_signless(1),
                                   arg)).result
        if isinstance(arg, int):
            return torch_dialect.ConstantIntOp(
                ir.IntegerAttr.get(ir.IntegerType.get_signless(64),
                                   arg)).result
        if isinstance(arg, float):
            return torch_dialect.ConstantFloatOp(
                ir.FloatAttr.get_f64(arg)).result
        if isinstance(arg, str):
            return torch_dialect.ConstantStrOp(ir.StringAttr.get(arg)).result
        if isinstance(arg, torch.dtype):
            assert isinstance(expected_type, torch.IntType)
            return self._import_argument(DTYPE_TO_INT[arg], expected_type)
        if isinstance(arg, torch.device):
            # TODO(DNS): Device index? arg.index
            return torch_dialect.ConstantDeviceOp(ir.StringAttr.get(
                arg.type)).result
        if isinstance(arg, torch.memory_format):
            assert isinstance(expected_type, torch.IntType)
            return self._import_argument(MEMORY_FORMAT_TO_INT[arg],
                                         expected_type)
        if isinstance(arg, torch.layout):
            assert isinstance(expected_type, torch.IntType)
            return self._import_argument(LAYOUT_TO_INT[arg], expected_type)
        if isinstance(arg, list):
            assert isinstance(expected_type, torch.ListType)
            element_type = expected_type.getElementType()
            if isinstance(element_type, torch.TensorType):
                assert all(
                    torch.fx.node.map_aggregate(
                        arg, lambda a: _is_valid_meta_val(a.meta.get("val"))))
                els = [self._env[e, 0] for e in arg]
            else:
                element_type = _torch_type_to_mlir_type(element_type)
                els = [
                    self._import_argument(e, element_type) for e in arg
                ]
            return torch_dialect.PrimListConstructOp(
                _torch_type_to_mlir_type(expected_type),
                els,
            ).result
        raise Exception(f"Unsupported literal: {arg}")


def import_fx_graph_as_func(g: torch.fx.Graph, func_name: str) -> ir.Module:
    """Imports the given FX graph as a function in a new MLIR module.

    Args:
        g: The FX graph to import.
        func_name: The sym_name of the `func.func` to import the graph into.
    Returns:
        A new MLIR module containing the imported function.
    """
    # Note that this function imports an fx.Graph instead of an
    # fx.GraphModule. The reason is that the supported subset only involves
    # stateless fx.Graph's, so the state held on the fx.GraphModule is not
    # necessary.
    _verify_fx_graph_conforms_to_subset(g)
    with ir.Context() as context:
        torch_dialect.register_dialect(context)
        return _FXGraphImporter(g, func_name).import_graph()
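The intended entry point is import_fx_graph_as_func fed by AOTAutograd, as exercised by the basic.py test earlier in this diff. A condensed sketch of that flow (mirrors the test, not an additional API):

    # Condensed from basic.py above.
    import torch
    import torch._dynamo as dynamo
    from torch._dynamo.backends.common import aot_autograd
    from torch._functorch.aot_autograd import (make_boxed_compiler,
                                               get_aot_compilation_context,
                                               set_model_name)
    from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func

    @make_boxed_compiler
    def import_backend(gm, example_inputs):
        # AOTAutograd attaches the model name registered via set_model_name.
        *_, model_name, nth_graph = get_aot_compilation_context()
        print(import_fx_graph_as_func(gm.graph, model_name))
        return gm

    set_model_name("tanh_example")
    f = dynamo.optimize(aot_autograd(fw_compiler=import_backend))(
        lambda x: torch.tanh(x))
    f(torch.randn(3, 4))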
@@ -2,96 +2,149 @@ (configs/torchdynamo.py; old and new lines interleave as in the diff view)
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import List

import numpy
from typing import List, Union, Optional, Sequence

import numpy as np
import torch
import torch._dynamo as dynamo
import torch_mlir
import torch_mlir.dynamo
from torch_mlir.dynamo import make_simple_dynamo_backend
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
import torch.utils._pytree as pytree
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import (
    make_boxed_compiler,
    get_aot_compilation_context,
    set_model_name,
)

from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
from torch_mlir.dynamo import _get_decomposition_table
from torch_mlir import (
    _example_args,
    OutputType,
    BACKEND_LEGAL_OPS,
    run_pipeline_with_repro_report,
    _lower_mlir_module,
    _canon_extra_library,
)
from torch_mlir_e2e_test.configs.utils import (
    recursively_convert_to_numpy,
    recursively_convert_from_numpy,
)
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem


def refine_result_type(_result):
    if isinstance(_result, tuple):
        return tuple(refine_result_type(x) for x in _result)
    elif isinstance(_result, np.ndarray):
        return torch.from_numpy(_result)
    elif isinstance(_result, (bool, int, float)):
        return _result
    else:
        raise ValueError(f"Unhandled return type {type(_result)}")


def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
    for node in fx_graph.graph.nodes:
        if node.op == "output":
            assert len(node.args) == 1, "Output node must have a single argument"
            assert len(
                node.args) == 1, "Output node must have a single argument"
            node_arg = node.args[0]
            if node_arg != ():
                return False
    return True


@make_simple_dynamo_backend
def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
                                    example_inputs: List[torch.Tensor]):
# Use the LinalgOnTensors backend, since it is the most complete.
# In theory we could mix and match TorchDynamo with the other backends,
# since they all lower through the same backend contract.
# For now, testing-wise, it doesn't make sense to test those configurations.
# We really just want to check the TorchDynamo frontend.
#
# Longer-term we will need to do something more sophisticated here.
# As per the long-term roadmap:
# https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md#refactoring-the-frontend
# We will eventually have a configuration that uses new PyTorch infra and
# skips the entire "frontend" part. We currently don't have any code
# for that right now since it is still very early stages, but eventually
# this Config should test that path (and maybe the current behavior can
# be moved to a `legacy_frontend_via_torchdynamo` config).
def jit(
    model: torch.nn.Module,
    example_args: _example_args,
    output_type: Union[str, "OutputType"] = OutputType.TORCH,
    backend_legal_ops: Optional[Sequence[str]] = None,
    extra_library=None,
    verbose: bool = False,
):
    if extra_library is None:
        extra_library = []
    import torch._dynamo as dynamo

    # Torch-MLIR does not support returning an empty tuple. The reason is
    # that both returning an empty tuple and returning `None` results in MLIR
    # functions that have as a return type `()`. In other words, there is no
    # way of differentiating between the two.
    assert not _returns_empty_tuple(fx_graph), "encountered graph that does not return anything"
    mlir_module = None

    mlir_module = torch_mlir.compile(
        fx_graph, example_inputs, output_type="linalg-on-tensors")
    backend = refbackend.RefBackendLinalgOnTensorsBackend()
    compiled = backend.compile(mlir_module)
    loaded = backend.load(compiled)
    extra_library_file_name = _canon_extra_library(extra_library)
    output_type = OutputType.get(output_type)
    if backend_legal_ops is not None:
        if output_type != OutputType.TORCH:
            raise Exception("`backend_legal_ops` is only valid with the "
                            "`torch` output type")
        backend_legal_ops = list(sorted(set(backend_legal_ops)))
    else:
        backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])

    def compiled_callable(*inputs):
        def refine_result_type(_result):
            if isinstance(_result, tuple):
                return tuple(refine_result_type(x) for x in _result)
            elif isinstance(_result, numpy.ndarray):
                return torch.from_numpy(_result)
            elif isinstance(_result, (bool, int, float)):
                return _result
            else:
                raise ValueError(f"Unhandled return type {type(_result)}")
        inputs = [x.numpy() for x in inputs]
        result = loaded.forward(*inputs)
        return refine_result_type(result)
    return compiled_callable
    @make_boxed_compiler
    def my_aot_autograd_backend(gm: torch.fx.GraphModule,
                                _example_inputs: List[torch.Tensor]):
        # Torch-MLIR does not support returning an empty tuple. The reason is
        # that both returning an empty tuple and returning `None` results in
        # MLIR functions that have as a return type `()`. In other words,
        # there is no way of differentiating between the two.
        assert not _returns_empty_tuple(gm), "encountered graph that does not return anything"

        nonlocal mlir_module
        *_, model_name, nth_graph = get_aot_compilation_context()
        mlir_module = import_fx_graph_as_func(gm.graph, model_name)
        return gm

    my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend,
                              decompositions=_get_decomposition_table)

    with torch.no_grad():
        set_model_name(model.__class__.__name__)
        torch._dynamo.reset()
        dynamo_f = dynamo.optimize(my_backend, nopython=True)(
            lambda method, *inputs: method(*inputs))
        dynamo_f(lambda *inputs: model(*[x.clone() for x in inputs]),
                 *example_args)
    option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) +
                     " extra-library=" + extra_library_file_name + "}")
    assert mlir_module is not None
    run_pipeline_with_repro_report(
        mlir_module,
        # f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
        f"builtin.module(torch-lower-to-backend-contract)",
        "Lowering TorchFX IR -> Torch Backend IR",
    )

    return _lower_mlir_module(verbose, output_type, mlir_module)


class TorchDynamoTestConfig(TestConfig):
    """TestConfig that runs the torch.nn.Module with TorchDynamo"""

    def __init__(self):
    def __init__(self, backend):
        super().__init__()
        self.backend = backend

    def compile(self, program: torch.nn.Module) -> torch.nn.Module:
        return program

    def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
        def item_symbol_that_clones_inputs(*inputs):
            cloned_inputs = [x.clone() for x in inputs]
            result = getattr(artifact, item.symbol)(*cloned_inputs)
            return result
        # TODO: Deepcopy the torch.nn.Module, so that if the program is
        # stateful then it does not mutate the original compiled program.
        result: Trace = []
        for item in trace:
            f = lambda method, *inputs: method(*inputs)
            torch._dynamo.reset()
            dynamo_f = dynamo.optimize(_refbackend_torchdynamo_backend, nopython=True)(f)
            output = dynamo_f(item_symbol_that_clones_inputs, *item.inputs)
            module = jit(artifact,
                         item.inputs,
                         output_type="linalg-on-tensors")
            module = self.backend.compile(module)
            backend_module = self.backend.load(module)
            params = {
                **dict(artifact.named_parameters(remove_duplicate=False)),
                **dict(artifact.named_buffers(remove_duplicate=False)),
            }
            params_flat, params_spec = pytree.tree_flatten(params)
            params_flat = list(params_flat)
            with torch.no_grad():
                numpy_inputs = recursively_convert_to_numpy(params_flat +
                                                            item.inputs)
            outputs = getattr(backend_module,
                              artifact.__class__.__name__)(*numpy_inputs)
            output = refine_result_type(outputs)
            result.append(
                TraceItem(symbol=item.symbol,
                          inputs=item.inputs,
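Net effect of this hunk: capture moves from the old RefBackend-specific dynamo backend to jit(), which traces the module with TorchDynamo, imports the FX graph directly via the new importer, lowers to the requested output type, and lets the config drive whichever backend was injected. A hedged sketch of standalone use (MyModule is a stand-in for any torch.nn.Module):

    # Sketch: using the new jit() outside the e2e harness.
    from torch_mlir_e2e_test.configs.torchdynamo import jit
    from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

    module = jit(MyModule(), [torch.randn(3, 4)],
                 output_type="linalg-on-tensors")
    backend = refbackend.RefBackendLinalgOnTensorsBackend()
    loaded = backend.load(backend.compile(module))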
@@ -818,7 +818,7 @@ def GatherNegativeDimModule_basic(module, tu: TestUtils):
class GatherRandomIndexModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
@@ -839,7 +839,7 @@ def GatherRandomIndexModule_basic(module, tu: TestUtils):
class Gather2DInputModdule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    @export
    @annotate_args([
        None,
@@ -1914,7 +1914,7 @@ class TensorLiteralModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
         torch.manual_seed(0)
-        self.t = torch.randint(-5, 5, (2, 3))
+        self.register_buffer("t", torch.randint(-5, 5, (2, 3)))

     @export
     @annotate_args([
@@ -1937,7 +1937,7 @@ class TensorOpaqueLiteralModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
         torch.manual_seed(0)
-        self.t = torch.randint(-5, 5, (256, 1024))
+        self.register_buffer("t", torch.randint(-5, 5, (256, 1024)))

     @export
     @annotate_args([
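The register_buffer change matters for the new TorchDynamo config: its run() method flattens named_parameters() plus named_buffers() and passes them to the compiled artifact, and a tensor stored as a plain attribute is invisible to that flattening. A quick check of the distinction:

    # Plain attributes do not appear in named_buffers(); registered buffers do.
    m = TensorLiteralModule()
    assert "t" in dict(m.named_buffers())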