Pure-Python FX importer. (#2098)

Co-authored-by: Sean Silva <silvasean@google.com>
Maksim Levental 2023-05-12 00:46:33 -05:00 committed by GitHub
parent e161f2511a
commit c3cd7471b4
8 changed files with 836 additions and 159 deletions
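The core of the change is a new `torch_mlir/_dynamo_fx_importer.py` that imports a restricted subset of FX IR (as produced by TorchDynamo plus AOT Autograd, with FakeTensor shape/dtype metadata on every node) directly into the `torch` dialect, plus a rework of the TorchDynamo e2e test config to use it. A minimal sketch of how the importer is driven, closely following the FileCheck test added below (the backend function name here is illustrative, not part of the importer API):

from typing import List

import torch
import torch._dynamo as dynamo
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import (make_boxed_compiler,
                                            get_aot_compilation_context,
                                            set_model_name)

from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func


@make_boxed_compiler
def print_mlir_backend(gm: torch.fx.GraphModule,
                       example_inputs: List[torch.Tensor]):
    # The AOT compilation context carries the name registered via set_model_name().
    *_, model_name, nth_graph = get_aot_compilation_context()
    # Import the functionalized, shape-annotated FX graph as a func.func.
    mlir_module = import_fx_graph_as_func(gm.graph, model_name)
    print(mlir_module)
    return gm  # keep running the original graph eagerly


my_backend = aot_autograd(fw_compiler=print_mlir_backend)


@dynamo.optimize(my_backend)
def basic(x):
    return torch.tanh(x)


set_model_name("basic")
basic(torch.randn(3, 4))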

View File

@ -106,7 +106,7 @@ def main():
xfail_set = LTC_XFAIL_SET
crashing_set = set()
elif args.config == "torchdynamo":
config = TorchDynamoTestConfig()
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
xfail_set = TORCHDYNAMO_XFAIL_SET
crashing_set = TORCHDYNAMO_CRASHING_SET
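With this change the TorchDynamo config takes the backend to run against as a constructor argument instead of hard-coding the reference backend internally. A one-line sketch of the new construction, assuming the usual e2e-suite import paths:

from torch_mlir_e2e_test.configs import TorchDynamoTestConfig
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend

config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())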

View File

@ -56,20 +56,7 @@ TORCHDYNAMO_XFAIL_SET = {
"ElementwiseWhereScalarSelfModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
# %7 = torch.operator "aten._index_put_impl_.hacked_twin"(%1, %6, %5, %true, %false) : (!torch.tensor<*,f32>, !torch.list<tensor>, !torch.tensor<*,f32>, !torch.bool, !torch.bool) -> !torch.tensor
"IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
"IndexPutImpl2DFloatAccumulateModule_basic",
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl2DIndexModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
# https://github.com/llvm/torch-mlir/issues/1611
# error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible
"Aten_EmbeddingBagExample_basic",
# error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal
"BernoulliFloatModule_basic",
"BernoulliPModule_basic",
@ -77,8 +64,6 @@ TORCHDYNAMO_XFAIL_SET = {
"ElementwiseFlattenBroadcastModule_basic",
"FlattenRank0Module_basic",
"UniformModule_basic",
# error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion
"TModuleRank1_basic",
# error: unsupported by backend contract: tensor with unknown rank
# note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
"ElementwisePreluModule_basic",
@ -107,7 +92,7 @@ TORCHDYNAMO_XFAIL_SET = {
'TensorToBoolZeroRank_basic',
'TensorToBool_basic',
# torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
# START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
'AtenSubFloatModule_basic',
'BoolFloatFalseModule_basic',
'BoolFloatTrueModule_basic',
@ -120,8 +105,10 @@ TORCHDYNAMO_XFAIL_SET = {
'SubFloatModule_basic',
'TensorToFloatZeroRank_basic',
'TensorToFloat_basic',
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
# torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
# START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
'AddIntModule_basic',
'AtenIntTensorCharDtypeModule_basic',
'BoolIntFalseModule_basic',
@ -138,78 +125,153 @@ TORCHDYNAMO_XFAIL_SET = {
'TensorToInt_basic',
'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic',
'ViewCollapseDynamicWithAtenSizeIntModule_basic',
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
# torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)}
# ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)}
'SortIntListReverse_basic',
# torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {}
# ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {}
'SortIntList_basic',
# torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
# START tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
'AtenFloatScalarModule_basic',
'AtenIntBoolOpModule_basic',
'OneHotModule_basic',
'QuantizedMLP_basic',
'ScalarImplicitFloatModule_basic',
'ScalarImplicitIntModule_basic',
# END tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
# torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
# START tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
'BincountMinlengthModule_basic',
'BincountModule_basic',
'BincountStaticSizeModule_basic',
# END tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
# torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool
'BoolFloatConstantModule_basic',
'BoolIntConstantModule_basic',
# torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__
'ContainsIntList_False',
'ContainsIntList_True',
# torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all
'AllBoolFalseModule_basic',
'AllBoolTrueModule_basic',
# torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any
'AnyBoolFalseModule_basic',
'AnyBoolTrueModule_basic',
# torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt
'SqrtIntConstantModule_basic',
# torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
# START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
'AtenIntBoolOpConstFalseModule_basic',
'AtenIntBoolOpConstTrueModule_basic',
'IntFloatModule_basic',
'PowIntFloatModule_basic',
# END tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
# torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len
'LenStrModule_basic',
# torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel
'NumelModule_basic',
'NumelZeroRankModule_basic',
# torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max
'PrimMaxIntModule_basic',
# torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min
'PrimMinIntModule_basic',
# empty graph
# START tests failing due to: empty graph in dynamo
'IsFloatingPointFloat_True',
'IsFloatingPointInt_False',
'TorchPrimLoopForLikeModule_basic',
'TorchPrimLoopWhileLikeModule_basic',
# END tests failing due to: empty graph in dynamo
# Forming aten.view_as_real and aten.view_as_imag instead of aten.real and aten.imag op.
# Complex ops
# ERROR due to: backend never runs because of empty frame
'ConstantBoolParameterModule_basic',
# START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"AddCDivModule_basic",
"ElementwiseMulScalarModule_basic",
"ElementwiseMulScalarModule_float",
"NativeGroupNormBackwardModule_basic",
"UpSampleNearest2dDynamicSize_basic",
"UpSampleNearest2dStaticFactor_basic",
"UpSampleNearest2dStaticSize_basic",
"UpSampleNearest2d_basic",
# END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
# START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt64Module_basic",
"ElementwiseAddScalarIntModule_basic",
"MobilenetV3Module_basic",
"NativeBatchNorm1DModule_basic",
"NativeBatchNorm2DModule_basic",
"NativeBatchNorm3DModule_basic",
"NativeBatchNormNoneWeightModule_basic",
"NativeGroupNormModule_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
# END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
# ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"HBC_basic",
# ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"ElementwiseDivScalarModule_basic",
# ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
"ElementwiseMulScalarModule_int",
# ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"ElementwiseSubScalarFloatModule_basic",
"ElementwiseSubScalarIntModule_basic",
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
"ElementwiseDivRoundingModeFloorModule_basic",
"ElementwiseDivRoundingModeTruncModule_basic",
# ERROR: Exception: Unsupported op: get_attr
"NumToTensorFloatModule_basic",
"NumToTensorIntModule_basic",
"TensorFloatModule_basic",
"TensorIntModule_basic",
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.randn.generator
"RandnGeneratorF64Module_basic",
"RandnGeneratorModule_basic",
# START tests failing due to: complex floating point ops
"AtenComplexImagModule_basic",
"AtenComplexRealModule_basic",
# END tests failing due to: complex floating point ops
}
# See https://github.com/llvm/torch-mlir/issues/2050
TORCHDYNAMO_CRASHING_SET = {
# No upstream decompositions.
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
# See also: https://github.com/pytorch/torchdynamo/issues/327
"Aten_EmbeddingBagExample_basic",
# https://github.com/pytorch/pytorch/issues/100838
"BaddbmmDifferentDtypesModule_basic",
"FullModuleInt3D_basic",
"ThresholdBackward1dIntModule_basic",
"ThresholdBackward2dIntModule_basic",
"ThresholdBackward3dIntModule_basic",
# See https://github.com/llvm/torch-mlir/issues/2050
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic",
@ -226,11 +288,7 @@ TORCHDYNAMO_CRASHING_SET = {
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SliceModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceStaticModule_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"TModuleRank2_basic",
"ToCopyModule_basic",
"TransposeIntModule_basic",

View File

@ -51,6 +51,7 @@ if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES
__init__.py
_dynamo_fx_importer.py
compiler_utils.py
dynamo.py
)

View File

@ -0,0 +1,116 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
from typing import List
import torch
import torch.fx
import torch._dynamo as dynamo
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_compilation_context, set_model_name
from torch_mlir.compiler_utils import TorchMlirCompilerError
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
from torch_mlir_e2e_test.configs.torchdynamo import jit
@make_boxed_compiler
def my_aot_autograd_backend(gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor]):
print(gm.graph)
*_, model_name, nth_graph = get_aot_compilation_context()
mlir_module = import_fx_graph_as_func(gm.graph, model_name)
print(mlir_module.operation.get_asm(enable_debug_info=True))
return gm
my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend)
# CHECK: module attributes {torch.debug_module_name = "basic"} {
# CHECK-NEXT: func.func @basic(%[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK-NEXT: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> loc(#[[LOC:.*]])
# CHECK-NEXT: return %[[TANH]] : !torch.vtensor<[3,4],f32>
# CHECK-NEXT: }
# CHECK-NEXT: }
# CHECK-NEXT: #[[LOC]] = loc("{{.*}}/dynamo_fx_importer/basic.py":{{[0-9]+}}:{{[0-9]+}})
@dynamo.optimize(my_backend)
def basic(x):
return torch.tanh(x)
set_model_name("basic")
basic(torch.randn(3, 4))
# CHECK-LABEL: func.func @literals_list_device_int_none_dtype() -> !torch.vtensor<[3,4],f16> {
# CHECK: %[[INT3:.*]] = torch.constant.int 3
# CHECK: %[[INT4:.*]] = torch.constant.int 4
# CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list<int>
# CHECK: %[[INT5:.*]] = torch.constant.int 5
# CHECK: %[[NONE0:.*]] = torch.constant.none
# CHECK: %[[DEVICE_CPU:.*]] = torch.constant.device "cpu"
# CHECK: %[[NONE1:.*]] = torch.constant.none
# CHECK: %[[RANDN:.*]] = torch.aten.randn %[[LIST]], %[[INT5]], %[[NONE0]], %[[DEVICE_CPU]], %[[NONE1]] : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.none -> !torch.vtensor<[3,4],f16>
# CHECK: return %[[RANDN]] : !torch.vtensor<[3,4],f16>
@dynamo.optimize(my_backend)
def literals_list_device_int_none_dtype():
return torch.ops.aten.randn([3, 4],
device=torch.device("cpu"),
dtype=torch.float16)
set_model_name("literals_list_device_int_none_dtype")
literals_list_device_int_none_dtype()
# CHECK-LABEL: func.func @literals_bool(
# CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK: %[[NONE0:.*]] = torch.constant.none
# CHECK: %[[NONE1:.*]] = torch.constant.none
# CHECK: %[[NONE2:.*]] = torch.constant.none
# CHECK: %[[BOOL_FALSE:.*]] = torch.constant.bool false
# CHECK: %[[NONE3:.*]] = torch.constant.none
# CHECK: %[[EMPTY_LIKE:.*]] = torch.aten.empty_like %[[ARG0]], %[[NONE0]], %[[NONE1]], %[[NONE2]], %[[BOOL_FALSE]], %[[NONE3]] : !torch.vtensor<[3,4],f32>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
# CHECK: return %[[EMPTY_LIKE]] : !torch.vtensor<[3,4],f32>
@dynamo.optimize(my_backend)
def literals_bool(x):
return torch.ops.aten.empty_like(x, pin_memory=False)
set_model_name("literals_bool")
literals_bool(torch.randn(3, 4))
# CHECK-LABEL: func.func @literals_float(
# CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
# CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
# CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[UNIFORM:.*]] = torch.aten.uniform %[[ARG0]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f32>
# CHECK: return %[[UNIFORM]] : !torch.vtensor<[3,4],f32>
@dynamo.optimize(my_backend)
def literals_float(x):
return torch.ops.aten.uniform(x, 0.0, 1.0)
set_model_name("literals_float")
literals_float(torch.randn(3, 4))
# CHECK-LABEL: func.func @literals_str(
# CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK: %[[STR_TANH:.*]] = torch.constant.str "tanh"
# CHECK: %[[GELU:.*]] = torch.aten.gelu %[[ARG0]], %[[STR_TANH]] : !torch.vtensor<[3,4],f32>, !torch.str -> !torch.vtensor<[3,4],f32>
# CHECK: return %[[GELU]] : !torch.vtensor<[3,4],f32>
@dynamo.optimize(my_backend)
def literals_str(x):
return torch.ops.aten.gelu(x, approximate="tanh")
set_model_name("literals_str")
literals_str(torch.randn(3, 4))

View File

@ -12,8 +12,8 @@ import tempfile
from torch._functorch.compile_utils import strip_overloads
import torch
import torch.fx
from torch_mlir.passmanager import PassManager
from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library
@ -248,6 +248,64 @@ BACKEND_LEGAL_OPS = {
}
def _canon_extra_library(extra_library):
extra_library_file_name = ""
if len(extra_library) != 0:
extra_library_dict = {}
for library_func in extra_library:
extra_library_dict[library_func.__name__] = library_func
mlir_library = generate_library(extra_library_dict)
extra_library_file_name = \
tempfile.gettempdir() + "/custom_op_extra_library.mlir"
with open(extra_library_file_name, "w") as f:
f.write(mlir_library)
return extra_library_file_name
def _lower_mlir_module(verbose, output_type, module):
if verbose:
print("\n====================")
print("Torch Backend IR")
print(module)
if output_type == OutputType.TORCH:
return module
if output_type == OutputType.TOSA:
run_pipeline_with_repro_report(
module, "builtin.module(torch-backend-to-tosa-backend-pipeline)",
"Lowering Torch Backend IR -> TOSA Backend IR")
if verbose:
print("\n====================")
print("TOSA Backend IR")
print(module)
return module
if output_type == OutputType.LINALG_ON_TENSORS:
run_pipeline_with_repro_report(
module,
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
if verbose:
print("\n====================")
print("LINALG Backend IR")
print(module)
return module
elif output_type == OutputType.STABLEHLO:
run_pipeline_with_repro_report(
module,
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
"Lowering Torch Backend IR -> StableHLO Backend IR")
if verbose:
print("\n====================")
print("StableHLO Backend IR")
print(module)
return module
raise Exception(f"Unknown OutputType: {output_type}")
def compile(model: torch.nn.Module,
example_args: _example_args,
output_type: Union[str, "OutputType"] = OutputType.TORCH,
@ -290,18 +348,7 @@ def compile(model: torch.nn.Module,
An MLIR module that contains the converted model in the specified
output type.
"""
extra_library_file_name = ""
if len(extra_library) != 0:
extra_library_dict = {}
for library_func in extra_library:
extra_library_dict[library_func.__name__] = library_func
mlir_library = generate_library(extra_library_dict)
extra_library_file_name = \
tempfile.gettempdir() + "/custom_op_extra_library.mlir"
with open(extra_library_file_name, "w") as f:
f.write(mlir_library)
extra_library_file_name = _canon_extra_library(extra_library)
output_type = OutputType.get(output_type)
example_args = ExampleArgs.get(example_args)
if ignore_traced_shapes and not use_tracing:
@ -394,44 +441,4 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
"Lowering TorchScript IR -> Torch Backend IR",
)
if verbose:
print("\n====================")
print("Torch Backend IR")
print(mb.module)
if output_type == OutputType.TORCH:
return mb.module
if output_type == OutputType.TOSA:
run_pipeline_with_repro_report(
mb.module,
"builtin.module(torch-backend-to-tosa-backend-pipeline)",
"Lowering Torch Backend IR -> TOSA Backend IR")
if verbose:
print("\n====================")
print("TOSA Backend IR")
print(mb.module)
return mb.module
if output_type == OutputType.LINALG_ON_TENSORS:
run_pipeline_with_repro_report(
mb.module,
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
if verbose:
print("\n====================")
print("LINALG Backend IR")
print(mb.module)
return mb.module
elif output_type == OutputType.STABLEHLO:
run_pipeline_with_repro_report(
mb.module,
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
"Lowering Torch Backend IR -> StableHLO Backend IR")
if verbose:
print("\n====================")
print("StableHLO Backend IR")
print(mb.module)
return mb.module
raise Exception(f"Unknown OutputType: {output_type}")
return _lower_mlir_module(verbose, output_type, mb.module)
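The net effect of this hunk is that the extra-library canonicalization and the Torch-backend-IR to TOSA/Linalg-on-Tensors/StableHLO lowering become reusable helpers (`_canon_extra_library`, `_lower_mlir_module`) rather than inline code in `compile()`, so the new TorchDynamo path below can call them too. A minimal usage sketch, assuming `module` is an MLIR module that already satisfies the Torch backend contract:

from torch_mlir import OutputType, _lower_mlir_module

module = ...  # an ir.Module holding Torch backend contract IR (assumed)
# Lower Torch backend IR to the Linalg-on-Tensors backend contract,
# printing intermediate IR when verbose is set.
linalg_module = _lower_mlir_module(verbose=False,
                                   output_type=OutputType.LINALG_ON_TENSORS,
                                   module=module)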

View File

@ -0,0 +1,442 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# This file implements a pure-Python importer from a restricted subset of
# FX IR into MLIR.
#
# As described in the
# [long-term roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md#refactoring-the-frontend),
# the goal is to import directly in the Torch-MLIR backend contract by
# using the available PyTorch infra for doing functionalization,
# shape inference, etc. Thus, this importer imports a very specific subset
# of the possible FX IR that is co-designed with the PyTorch infra that produces
# the FX graph -- see the `torch_mlir.dynamo` module for that, and see the
# `_verify_fx_graph_conforms_to_subset` function for the operational definition.
#
# In fact, because of the generality of FX IR (e.g. the use of raw Python
# callables as node.target), there is really no well-defined way to implement a
# general FX -> MLIR importer. Reuse or extension of this code for other
# FX -> MLIR use cases should be done carefully, and likely will involve
# introducing new concepts or abstractions into the import process.
from typing import Dict, Tuple
import operator
import re
import torch
import torch_mlir.ir as ir
import torch_mlir.dialects.func as func_dialect
import torch_mlir.dialects.torch as torch_dialect
def _is_valid_meta_val(val):
# We currently allow only FakeTensor's or lists of FakeTensor's
# as meta['val']. However, this can potentially change to also hold a SymInt
# in the future. See:
# https://github.com/pytorch/pytorch/issues/90839#issuecomment-1352856661
if isinstance(val, torch._subclasses.FakeTensor):
return True
if isinstance(val, (tuple, list)):
return all(isinstance(x, torch._subclasses.FakeTensor) for x in val)
return False
def _verify_fx_graph_conforms_to_subset(g: torch.fx.Graph):
# TODO: Report errors with source locations if possible.
def _check_meta_val(node):
if "val" not in node.meta:
raise Exception(f"Unsupported: missing node.meta['val']: {node}")
if not _is_valid_meta_val(node.meta["val"]):
raise Exception(
f"Unsupported: node.meta['val'] is not a FakeTensor or list of FakeTensor's: {node}; {node.meta['val']}"
)
for node in g.nodes:
if node.op not in ("placeholder", "call_function", "output"):
raise Exception(f"Unsupported op: {node.op}")
if node.op == "placeholder":
_check_meta_val(node)
if node.op == "call_function":
_check_meta_val(node)
# We only support OpOverload for computations because the `torch`
# dialect ops model the full qualified op name, including overload.
# We also support operator.getitem because that is how multiple
# results are modeled.
if isinstance(node.target, torch._ops.OpOverload):
for type_ in (r.type for r in node.target._schema.returns):
if isinstance(type_, torch.TensorType):
continue
raise Exception(
f"Unsupported: return type {type_} in schema for {node.target}"
)
if len(node.args) != len(node.target._schema.arguments):
assert len(node.args) < len(node.target._schema.arguments)
for i, argument in enumerate(
node.target._schema.arguments[len(node.args):]):
if not argument.has_default_value():
raise Exception(
f"Unsupported: missing default value for argument {i} in schema for {node.target}"
)
continue
if node.target is operator.getitem:
continue
raise Exception(f"Unsupported call_function target: {node.target}")
# ==============================================================================
# Type import
# ==============================================================================
def _torch_type_to_mlir_type_string(t: torch.Type) -> str:
# This is weird -- for Node's, since they are untyped, we use the
# node.meta['val'] to get the type (which is a tensor type with sizes and
# dtype).
# But for things that are associated with a schema, we use the schema to get
# the type. This creates problems for things like a list<tensor> because
# then we don't have sizes or dtypes available.
if isinstance(t, torch.ListType):
return f"list<{_torch_type_to_mlir_type_string(t.getElementType())}>"
if isinstance(t, torch.BoolType):
return "bool"
if isinstance(t, torch.IntType):
return "int"
if isinstance(t, torch.FloatType):
return "float"
if isinstance(t, torch.StringType):
return "string"
if isinstance(t, torch.TensorType):
return "vtensor"
if isinstance(t, torch.OptionalType):
return f"optional<{_torch_type_to_mlir_type_string(t.getElementType())}>"
raise Exception(f"Unsupported type: {t}")
def _torch_type_to_mlir_type(t: torch.Type):
return ir.Type.parse(f"!torch.{_torch_type_to_mlir_type_string(t)}")
def _convert_dtype_to_mlir_type(dtype: torch.dtype) -> str:
# See the table in TorchTypes.td:AnyTorchTensorType's documentation.
if dtype == torch.float16:
return "f16"
if dtype == torch.bfloat16:
return "bf16"
if dtype == torch.float32:
return "f32"
if dtype == torch.float64:
return "f64"
if dtype == torch.uint8:
return "ui8"
if dtype == torch.int8:
return "si8"
if dtype == torch.int16:
return "si16"
if dtype == torch.int32:
return "si32"
if dtype == torch.int64:
return "si64"
if dtype == torch.bool:
return "i1"
if dtype == torch.qint8:
return "!torch.qint8"
if dtype == torch.quint8:
return "!torch.quint8"
if dtype == torch.complex64:
return "complex<f64>"
raise Exception(f"Unsupported dtype: {dtype}")
def _import_fake_tensor_as_mlir_type(
fake_tensor: torch._subclasses.FakeTensor) -> ir.Type:
# TODO: Find story for how to get dynamically shaped tensors here.
shape = ",".join(str(d) for d in fake_tensor.shape)
dtype = _convert_dtype_to_mlir_type(fake_tensor.dtype)
return ir.Type.parse(f"!torch.vtensor<[{shape}],{dtype}>")
def _mlir_types_for_node(node: torch.fx.Node) -> ir.Type:
if isinstance(node.meta["val"], (tuple, list)):
return [_import_fake_tensor_as_mlir_type(v) for v in node.meta["val"]]
return [_import_fake_tensor_as_mlir_type(node.meta["val"])]
def _extract_function_type_from_graph(g: torch.fx.Graph) -> ir.FunctionType:
input_types = []
for node in g.nodes:
if node.op == "placeholder":
input_types.append(_mlir_types_for_node(node)[0])
if node.op == "output":
# TODO(DNS): Test this or add verifier that it can't happen.
result_types = torch.fx.map_arg(
node.args[0], lambda n: _mlir_types_for_node(n)[0])
# Note: We import directly to the backend contract -- multiple results
# are modeled with func.func native multiple results rather than as a
# singleton value / tuple.
return ir.FunctionType.get(input_types, result_types)
# ==============================================================================
# FX Graph import
# ==============================================================================
DTYPE_TO_INT = {
    # TODO(DNS): Fill in from AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
    torch.uint8: 0,
    torch.int8: 1,
    torch.int16: 2,
    torch.int32: 3,
    torch.int64: 4,
    torch.float16: 5,
    torch.float32: 6,
    torch.float64: 7,
    torch.complex32: 8,
    torch.complex64: 9,
    # torch.complex128: 10
    torch.bool: 11,
    torch.qint8: 12,
    torch.quint8: 13,
    # torch.qint32: 14
    torch.bfloat16: 15,
}
MEMORY_FORMAT_TO_INT = {
    # https://github.com/pytorch/pytorch/c10/core/MemoryFormat.h#L28
    torch.contiguous_format: 0,
    torch.preserve_format: 1,
    torch.channels_last: 2,
    torch.channels_last_3d: 3,
}
LAYOUT_TO_INT = {
    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_layouts.cpp
    torch.strided: 0,
    torch.sparse_coo: 1,
    torch.sparse_csr: 2,
    torch.sparse_csc: 3,
    torch.sparse_bsr: 4,
    torch.sparse_bsc: 5,
}
def _mlir_location_for_node(node: torch.fx.Node) -> ir.Location:
stack_trace = node.stack_trace
if stack_trace is None:
return ir.Location.unknown()
# TODO: Avoid needing to regex match this.
# https://github.com/pytorch/pytorch/issues/91000
m = re.search(r"""File "([^"]+)", line ([0-9]+),""", node.stack_trace)
filename, line = m.group(1), int(m.group(2))
return ir.Location.file(filename, line, col=0)
class _FXGraphImporter:
def __init__(self, g: torch.fx.Graph, func_name: str):
self._g = g
self._func_name = func_name
# For each node, we track a mapping to MLIR Value's.
# Technically all Node's have a single output (which can be a tuple of
# values in case of multiple returns), but we treat them as having
# multiple returns directly. This matches how the Node's
# node.meta['val'] is set up, since it contains a list with multiple
# FakeTensor's in case of a tuple return with multiple elements.
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
self._module = ir.Module.create(ir.Location.unknown())
self._module.operation.attributes[
"torch.debug_module_name"] = ir.StringAttr.get(func_name)
function_type = _extract_function_type_from_graph(g)
func = func_dialect.FuncOp(
func_name,
function_type,
loc=ir.Location.unknown(), # TODO: Can we do better?
ip=ir.InsertionPoint(self._module.body),
)
self._body_block = ir.Block.create_at_start(func.body,
function_type.inputs)
def import_graph(self) -> ir.Module:
with ir.InsertionPoint(self._body_block):
num_placeholders_seen = 0
for node in self._g.nodes:
with _mlir_location_for_node(node):
if node.op == "placeholder":
self._env[(
node, 0
)] = self._body_block.arguments[num_placeholders_seen]
num_placeholders_seen += 1
if node.op == "call_function":
if node.target is operator.getitem:
self._env[(node, 0)] = self._env[(node.args[0],
node.args[1])]
else:
self._import_op_overload_call(node)
if node.op == "output":
# Note that the output node is a singleton tuple holding
# a tuple of return values (without the single-element special
# case)
# DNS: Test or verify no literals as results.
operands = [
self._import_argument(arg) for arg in node.args[0]
]
func_dialect.ReturnOp(operands)
return self._module
def _import_op_overload_call(self, node: torch.fx.Node):
assert node.op == "call_function"
assert isinstance(node.target, torch._ops.OpOverload)
schema = node.target._schema
# Extract the `torch` dialect op name.
namespace, _, unqualified_name = schema.name.partition("::")
mlir_op_name = f"torch.{namespace}.{unqualified_name}"
if schema.overload_name != "":
mlir_op_name += f".{schema.overload_name}"
# DNS: Unregistered ops
assert ir.Context.current.is_registered_operation(
mlir_op_name), f"Unregistered operation: {mlir_op_name}"
# Construct the Operation.
result_types = _mlir_types_for_node(node)
operands = []
# `schema.arguments` is a bit confusing in this context, since
# `Argument` is the term FX uses for what is analogous to an MLIR "Value".
# "Formal parameters" is the more precise term for these.
for i, parameter in enumerate(node.target._schema.arguments):
if parameter.kwarg_only and parameter.name in node.kwargs:
arg = node.kwargs[parameter.name]
elif i < len(node.args):
arg = node.args[i]
else:
arg = parameter.default_value
operands.append(self._import_argument(arg, parameter.type))
operation = ir.Operation.create(
mlir_op_name,
results=result_types,
operands=operands,
)
for i, value in enumerate(operation.results):
self._env[(node, i)] = value
def _import_argument(self,
arg: torch.fx.node.Argument,
expected_type_for_literal=None) -> ir.Value:
"""Import an FX `Argument`, which is analogous to an MLIR `Value`.
Args:
arg: The FX `Argument` to import.
expected_type_for_literal: If `arg` is a literal (such as a Python
`int` or `float` object), this is the expected JIT IR type. This
allows disambiguating certain cases, such as importing an optional
type.
Returns:
The imported MLIR `Value`.
"""
if isinstance(arg, torch.fx.Node):
return self._env[(arg, 0)]
assert expected_type_for_literal is not None
return self._import_literal(arg, expected_type_for_literal)
def _import_literal(self, arg: torch.fx.node.Argument,
expected_type) -> ir.Value:
if arg is None:
return torch_dialect.ConstantNoneOp().result
if isinstance(expected_type, torch.OptionalType):
return self._import_argument(arg, expected_type.getElementType())
if isinstance(arg, bool):
return torch_dialect.ConstantBoolOp(
ir.IntegerAttr.get(ir.IntegerType.get_signless(1), arg)).result
if isinstance(arg, int):
return torch_dialect.ConstantIntOp(
ir.IntegerAttr.get(ir.IntegerType.get_signless(64),
arg)).result
if isinstance(arg, float):
return torch_dialect.ConstantFloatOp(
ir.FloatAttr.get_f64(arg)).result
if isinstance(arg, str):
return torch_dialect.ConstantStrOp(ir.StringAttr.get(arg)).result
if isinstance(arg, torch.dtype):
assert isinstance(expected_type, torch.IntType)
return self._import_argument(DTYPE_TO_INT[arg], expected_type)
if isinstance(arg, torch.device):
# TODO(DNS): Device index? arg.index
return torch_dialect.ConstantDeviceOp(ir.StringAttr.get(
arg.type)).result
if isinstance(arg, torch.memory_format):
assert isinstance(expected_type, torch.IntType)
return self._import_argument(MEMORY_FORMAT_TO_INT[arg],
expected_type)
if isinstance(arg, torch.layout):
assert isinstance(expected_type, torch.IntType)
return self._import_argument(LAYOUT_TO_INT[arg], expected_type)
if isinstance(arg, list):
assert isinstance(expected_type, torch.ListType)
element_type = expected_type.getElementType()
if isinstance(element_type, torch.TensorType):
assert all(
torch.fx.node.map_aggregate(
arg, lambda a: _is_valid_meta_val(a.meta.get("val"))))
els = [self._env[e, 0] for e in arg]
else:
element_type = _torch_type_to_mlir_type(element_type)
els = [
self._import_argument(e, element_type) for e in arg
]
return torch_dialect.PrimListConstructOp(
_torch_type_to_mlir_type(expected_type),
els,
).result
raise Exception(f"Unsupported literal: {arg}")
def import_fx_graph_as_func(g: torch.fx.Graph, func_name: str) -> ir.Module:
"""Imports the given FX graph as a function in a new MLIR module.
Args:
g: The FX graph to import.
func_name: The sym_name of the `func.func` to import the graph into.
Returns:
A new MLIR module containing the imported function.
"""
# Note that this function imports a fx.Graph instead of an fx.GraphModule.
# The reason is that the supported subset only involves stateless
# fx.Graph's, so the state held on the fx.GraphModule is not necessary.
_verify_fx_graph_conforms_to_subset(g)
with ir.Context() as context:
torch_dialect.register_dialect(context)
return _FXGraphImporter(g, func_name).import_graph()
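As the header comment stresses, this is deliberately not a general FX importer: `_verify_fx_graph_conforms_to_subset` is the operational definition of what is accepted (only placeholder/call_function/output nodes, `OpOverload` or `operator.getitem` targets, and FakeTensor `meta['val']` on every value). A quick illustrative sketch of what that rules out: a graph from plain `torch.fx.symbolic_trace` is rejected because its nodes carry no `meta['val']` and its targets are ordinary Python callables; conforming graphs come from the Dynamo + AOT Autograd pipeline exercised by the test above.

import torch
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func

gm = torch.fx.symbolic_trace(lambda x: torch.tanh(x))
try:
    import_fx_graph_as_func(gm.graph, "rejected")
except Exception as e:
    # e.g. "Unsupported: missing node.meta['val']: x"
    print(e)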

View File

@ -2,96 +2,149 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import List
import numpy
from typing import List, Union, Optional, Sequence
import numpy as np
import torch
import torch._dynamo as dynamo
import torch_mlir
import torch_mlir.dynamo
from torch_mlir.dynamo import make_simple_dynamo_backend
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
import torch.utils._pytree as pytree
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import (
make_boxed_compiler,
get_aot_compilation_context,
set_model_name,
)
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
from torch_mlir.dynamo import _get_decomposition_table
from torch_mlir import (
_example_args,
OutputType,
BACKEND_LEGAL_OPS,
run_pipeline_with_repro_report,
_lower_mlir_module,
_canon_extra_library,
)
from torch_mlir_e2e_test.configs.utils import (
recursively_convert_to_numpy,
recursively_convert_from_numpy,
)
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
def refine_result_type(_result):
if isinstance(_result, tuple):
return tuple(refine_result_type(x) for x in _result)
elif isinstance(_result, np.ndarray):
return torch.from_numpy(_result)
elif isinstance(_result, (bool, int, float)):
return _result
else:
raise ValueError(f"Unhandled return type {type(_result)}")
def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
for node in fx_graph.graph.nodes:
if node.op == "output":
assert len(node.args) == 1, "Output node must have a single argument"
assert len(
node.args) == 1, "Output node must have a single argument"
node_arg = node.args[0]
if node_arg != ():
return False
return True
@make_simple_dynamo_backend
def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
example_inputs: List[torch.Tensor]):
# Use the LinalgOnTensors backend, since it is the most complete.
# In theory we could mix and match TorchDynamo with the other backends,
# since they all lower through the same backend contract.
# For now, testing-wise, it doesn't make sense to test those configurations.
# We really just want to check the TorchDynamo frontend.
#
# Longer-term we will need to do something more sophisticated here.
# As per the long-term roadmap:
# https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md#refactoring-the-frontend
# We will eventually have a configuration that uses new PyTorch infra and
# skips the entire "frontend" part. We currently don't have any code
# for that right now since it is still very early stages, but eventually
# this Config should test that path (and maybe the current behavior can
# be moved to a `legacy_frontend_via_torchdynamo` config).
def jit(
model: torch.nn.Module,
example_args: _example_args,
output_type: Union[str, "OutputType"] = OutputType.TORCH,
backend_legal_ops: Optional[Sequence[str]] = None,
extra_library=None,
verbose: bool = False,
):
if extra_library is None:
extra_library = []
import torch._dynamo as dynamo
mlir_module = None
extra_library_file_name = _canon_extra_library(extra_library)
output_type = OutputType.get(output_type)
if backend_legal_ops is not None:
if output_type != OutputType.TORCH:
raise Exception("`backend_legal_ops` is only valid with the "
"`torch` output type")
backend_legal_ops = list(sorted(set(backend_legal_ops)))
else:
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
@make_boxed_compiler
def my_aot_autograd_backend(gm: torch.fx.GraphModule,
_example_inputs: List[torch.Tensor]):
# Torch-MLIR does not support returning an empty tuple. The reason is
# that both returning an empty tuple and returning `None` results in MLIR
# functions that have as a return type `()`. In other words, there is no
# way of differentiating between the two.
assert not _returns_empty_tuple(fx_graph), "encountered graph that does not return anything"
assert not _returns_empty_tuple(gm), "encountered graph that does not return anything"
mlir_module = torch_mlir.compile(
fx_graph, example_inputs, output_type="linalg-on-tensors")
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(mlir_module)
loaded = backend.load(compiled)
nonlocal mlir_module
*_, model_name, nth_graph = get_aot_compilation_context()
mlir_module = import_fx_graph_as_func(gm.graph, model_name)
return gm
def compiled_callable(*inputs):
def refine_result_type(_result):
if isinstance(_result, tuple):
return tuple(refine_result_type(x) for x in _result)
elif isinstance(_result, numpy.ndarray):
return torch.from_numpy(_result)
elif isinstance(_result, (bool, int, float)):
return _result
else:
raise ValueError(f"Unhandled return type {type(_result)}")
inputs = [x.numpy() for x in inputs]
result = loaded.forward(*inputs)
return refine_result_type(result)
return compiled_callable
my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend,
decompositions=_get_decomposition_table)
with torch.no_grad():
set_model_name(model.__class__.__name__)
torch._dynamo.reset()
dynamo_f = dynamo.optimize(my_backend, nopython=True)(
lambda method, *inputs: method(*inputs))
dynamo_f(lambda *inputs: model(*[x.clone() for x in inputs]),
*example_args)
option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) +
" extra-library=" + extra_library_file_name + "}")
assert mlir_module is not None
run_pipeline_with_repro_report(
mlir_module,
# f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
f"builtin.module(torch-lower-to-backend-contract)",
"Lowering TorchFX IR -> Torch Backend IR",
)
return _lower_mlir_module(verbose, output_type, mlir_module)
class TorchDynamoTestConfig(TestConfig):
"""TestConfig that runs the torch.nn.Module with TorchDynamo"""
def __init__(self):
def __init__(self, backend):
super().__init__()
self.backend = backend
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
return program
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
def item_symbol_that_clones_inputs(*inputs):
cloned_inputs = [x.clone() for x in inputs]
result = getattr(artifact, item.symbol)(*cloned_inputs)
return result
# TODO: Deepcopy the torch.nn.Module, so that if the program is
# stateful then it does not mutate the original compiled program.
result: Trace = []
for item in trace:
f = lambda method, *inputs: method(*inputs)
torch._dynamo.reset()
dynamo_f = dynamo.optimize(_refbackend_torchdynamo_backend, nopython=True)(f)
output = dynamo_f(item_symbol_that_clones_inputs, *item.inputs)
module = jit(artifact,
item.inputs,
output_type="linalg-on-tensors")
module = self.backend.compile(module)
backend_module = self.backend.load(module)
params = {
**dict(artifact.named_parameters(remove_duplicate=False)),
**dict(artifact.named_buffers(remove_duplicate=False)),
}
params_flat, params_spec = pytree.tree_flatten(params)
params_flat = list(params_flat)
with torch.no_grad():
numpy_inputs = recursively_convert_to_numpy(params_flat +
item.inputs)
outputs = getattr(backend_module,
artifact.__class__.__name__)(*numpy_inputs)
output = refine_result_type(outputs)
result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
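Putting the config rework together: `jit()` traces the module with TorchDynamo, imports each FX graph with the pure-Python importer, lowers it with `torch-lower-to-backend-contract` and `_lower_mlir_module`, and `run()` then compiles the result with whatever backend was injected into the config. A rough end-to-end sketch outside the harness (the module and inputs here are illustrative placeholders; as in `run()` above, the compiled entry point is named after the model class):

import torch
from torch_mlir_e2e_test.configs.torchdynamo import jit
from torch_mlir_e2e_test.configs.utils import recursively_convert_to_numpy
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend

class TanhModule(torch.nn.Module):  # illustrative stand-in for a test module
    def forward(self, x):
        return torch.tanh(x)

model = TanhModule()
example_inputs = [torch.randn(3, 4)]
module = jit(model, example_inputs, output_type="linalg-on-tensors")
backend = RefBackendLinalgOnTensorsBackend()
loaded = backend.load(backend.compile(module))
# AOT Autograd lifts parameters/buffers to graph inputs, hence the params_flat
# handling in run() above; TanhModule has none, so the inputs pass through as-is.
numpy_inputs = recursively_convert_to_numpy(example_inputs)
print(loaded.TanhModule(*numpy_inputs))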

View File

@ -1914,7 +1914,7 @@ class TensorLiteralModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.t = torch.randint(-5, 5, (2, 3))
self.register_buffer("t", torch.randint(-5, 5, (2, 3)))
@export
@annotate_args([
@ -1937,7 +1937,7 @@ class TensorOpaqueLiteralModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.t = torch.randint(-5, 5, (256, 1024))
self.register_buffer("t", torch.randint(-5, 5, (256, 1024)))
@export
@annotate_args([