From c3cd7471b4831d101780fdd9527feb7cce907bc3 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 12 May 2023 00:46:33 -0500 Subject: [PATCH] Pure-Python FX importer. (#2098) Co-authored-by: Sean Silva --- e2e_testing/main.py | 2 +- e2e_testing/xfail_sets.py | 138 ++++-- python/CMakeLists.txt | 1 + python/test/dynamo_fx_importer/basic.py | 116 +++++ python/torch_mlir/__init__.py | 115 ++--- python/torch_mlir/_dynamo_fx_importer.py | 442 ++++++++++++++++++ .../configs/torchdynamo.py | 173 ++++--- .../torch_mlir_e2e_test/test_suite/basic.py | 8 +- 8 files changed, 836 insertions(+), 159 deletions(-) create mode 100644 python/test/dynamo_fx_importer/basic.py create mode 100644 python/torch_mlir/_dynamo_fx_importer.py diff --git a/e2e_testing/main.py b/e2e_testing/main.py index 91ca0c85f..d0d56fc67 100644 --- a/e2e_testing/main.py +++ b/e2e_testing/main.py @@ -106,7 +106,7 @@ def main(): xfail_set = LTC_XFAIL_SET crashing_set = set() elif args.config == "torchdynamo": - config = TorchDynamoTestConfig() + config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = TORCHDYNAMO_XFAIL_SET crashing_set = TORCHDYNAMO_CRASHING_SET diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 8ba20cbe1..1fc89ca5f 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -56,20 +56,7 @@ TORCHDYNAMO_XFAIL_SET = { "ElementwiseWhereScalarSelfModule_basic", "ElementwiseWhereScalarOtherStaticModule_basic", "ElementwiseWhereScalarSelfStaticModule_basic", - # %7 = torch.operator "aten._index_put_impl_.hacked_twin"(%1, %6, %5, %true, %false) : (!torch.tensor<*,f32>, !torch.list, !torch.tensor<*,f32>, !torch.bool, !torch.bool) -> !torch.tensor - "IndexPutImpl1DFloatAccumulateModule_basic", - "IndexPutImpl1DFloatNonAccumulateModule_basic", - "IndexPutImpl1DIntAccumulateModule_basic", - "IndexPutImpl1DIntNonAccumulateModule_basic", - "IndexPutImpl2DFloatAccumulateModule_basic", - "IndexPutImpl2DFloatNonAccumulateModule_basic", - "IndexPutImpl2DIndexModule_basic", - "IndexPutImpl3DFloatAccumulateModule_basic", - "IndexPutImpl3DFloatNonAccumulateModule_basic", - # https://github.com/llvm/torch-mlir/issues/1611 - # error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible - "Aten_EmbeddingBagExample_basic", # error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal "BernoulliFloatModule_basic", "BernoulliPModule_basic", @@ -77,8 +64,6 @@ TORCHDYNAMO_XFAIL_SET = { "ElementwiseFlattenBroadcastModule_basic", "FlattenRank0Module_basic", "UniformModule_basic", - # error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion - "TModuleRank1_basic", # error: unsupported by backend contract: tensor with unknown rank # note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32> "ElementwisePreluModule_basic", @@ -107,7 +92,7 @@ TORCHDYNAMO_XFAIL_SET = { 'TensorToBoolZeroRank_basic', 'TensorToBool_basic', - # torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} + # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} 'AtenSubFloatModule_basic', 'BoolFloatFalseModule_basic', 'BoolFloatTrueModule_basic', @@ -120,8 +105,10 @@ TORCHDYNAMO_XFAIL_SET = { 'SubFloatModule_basic', 'TensorToFloatZeroRank_basic', 'TensorToFloat_basic', + # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} - # torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} + + # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} 'AddIntModule_basic', 'AtenIntTensorCharDtypeModule_basic', 'BoolIntFalseModule_basic', @@ -138,78 +125,153 @@ TORCHDYNAMO_XFAIL_SET = { 'TensorToInt_basic', 'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic', 'ViewCollapseDynamicWithAtenSizeIntModule_basic', + # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} - # torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)} + # ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)} 'SortIntListReverse_basic', - # torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {} + # ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {} 'SortIntList_basic', - # torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default + # START tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default 'AtenFloatScalarModule_basic', 'AtenIntBoolOpModule_basic', 'OneHotModule_basic', 'QuantizedMLP_basic', 'ScalarImplicitFloatModule_basic', 'ScalarImplicitIntModule_basic', + # END tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default - # torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default + # START tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default 'BincountMinlengthModule_basic', 'BincountModule_basic', 'BincountStaticSizeModule_basic', + # END tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default - # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool + # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool 'BoolFloatConstantModule_basic', 'BoolIntConstantModule_basic', - # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__ + # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__ 'ContainsIntList_False', 'ContainsIntList_True', - # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all + # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all 'AllBoolFalseModule_basic', 'AllBoolTrueModule_basic', - # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any + # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any 'AnyBoolFalseModule_basic', 'AnyBoolTrueModule_basic', - # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt + # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt 'SqrtIntConstantModule_basic', - # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int + # START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int 'AtenIntBoolOpConstFalseModule_basic', 'AtenIntBoolOpConstTrueModule_basic', 'IntFloatModule_basic', 'PowIntFloatModule_basic', + # END tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int - # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len + # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len 'LenStrModule_basic', - # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel + # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel 'NumelModule_basic', 'NumelZeroRankModule_basic', - # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max + # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max 'PrimMaxIntModule_basic', - # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min + # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min 'PrimMinIntModule_basic', - # empty graph + # START tests failing due to: empty graph in dynamo 'IsFloatingPointFloat_True', 'IsFloatingPointInt_False', 'TorchPrimLoopForLikeModule_basic', 'TorchPrimLoopWhileLikeModule_basic', + # END tests failing due to: empty graph in dynamo - # Forming aten.view_as_real and aten.view_as_imag instead of aten.real and aten.imag op. - # Complex ops + # ERROR due to: backend never runs because of empty frame + 'ConstantBoolParameterModule_basic', + + # START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' + "AddCDivModule_basic", + "ElementwiseMulScalarModule_basic", + "ElementwiseMulScalarModule_float", + "NativeGroupNormBackwardModule_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticFactor_basic", + "UpSampleNearest2dStaticSize_basic", + "UpSampleNearest2d_basic", + # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' + + # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' + "BatchNorm1DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "ElementwiseAddScalarFloatModule_basic", + "ElementwiseAddScalarInt64Module_basic", + "ElementwiseAddScalarIntModule_basic", + "MobilenetV3Module_basic", + "NativeBatchNorm1DModule_basic", + "NativeBatchNorm2DModule_basic", + "NativeBatchNorm3DModule_basic", + "NativeBatchNormNoneWeightModule_basic", + "NativeGroupNormModule_basic", + "ResNet18Module_basic", + "ResNet18StaticModule_basic", + # END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' + + # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' + "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", + "HBC_basic", + + # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' + "ElementwiseDivScalarModule_basic", + + # ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' + "ElementwiseMulScalarModule_int", + + # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' + "ElementwiseSubScalarFloatModule_basic", + "ElementwiseSubScalarIntModule_basic", + + # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode + "ElementwiseDivRoundingModeFloorModule_basic", + "ElementwiseDivRoundingModeTruncModule_basic", + + # ERROR: Exception: Unsupported op: get_attr + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "TensorFloatModule_basic", + "TensorIntModule_basic", + + # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.randn.generator + "RandnGeneratorF64Module_basic", + "RandnGeneratorModule_basic", + + # START tests failing due to: complex floating point ops "AtenComplexImagModule_basic", "AtenComplexRealModule_basic", + # END tests failing due to: complex floating point ops } -# See https://github.com/llvm/torch-mlir/issues/2050 TORCHDYNAMO_CRASHING_SET = { + # No upstream decompositions. + # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor) + # See also: https://github.com/pytorch/torchdynamo/issues/327 + "Aten_EmbeddingBagExample_basic", + # https://github.com/pytorch/pytorch/issues/100838 + "BaddbmmDifferentDtypesModule_basic", + "FullModuleInt3D_basic", + "ThresholdBackward1dIntModule_basic", + "ThresholdBackward2dIntModule_basic", + "ThresholdBackward3dIntModule_basic", + # See https://github.com/llvm/torch-mlir/issues/2050 "ElementwiseCloneChannelsLastMemoryFormatModule_basic", "ElementwiseCloneContiguousModule_basic", "ElementwiseCloneModule_basic", @@ -226,11 +288,7 @@ TORCHDYNAMO_CRASHING_SET = { "PermuteModule_basic", "PermuteNegativeIndexModule_basic", "SelectIntNegativeDimAndIndexStaticModule_basic", - "SliceModule_basic", - "SliceNegIdxModule_basic", - "SliceOutOfLowerBoundStartIndexModule_basic", - "SliceSizeTwoStepModule_basic", - "SliceStaticModule_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", "TModuleRank2_basic", "ToCopyModule_basic", "TransposeIntModule_basic", diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 0ba37a8ec..3c914df09 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -51,6 +51,7 @@ if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) ADD_TO_PARENT TorchMLIRPythonSources SOURCES __init__.py + _dynamo_fx_importer.py compiler_utils.py dynamo.py ) diff --git a/python/test/dynamo_fx_importer/basic.py b/python/test/dynamo_fx_importer/basic.py new file mode 100644 index 000000000..cea2f639f --- /dev/null +++ b/python/test/dynamo_fx_importer/basic.py @@ -0,0 +1,116 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s + +from typing import List + +import torch +import torch.fx +import torch._dynamo as dynamo +from torch._dynamo.backends.common import aot_autograd +from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_compilation_context, set_model_name + +from torch_mlir.compiler_utils import TorchMlirCompilerError +from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func +from torch_mlir_e2e_test.configs.torchdynamo import jit + + +@make_boxed_compiler +def my_aot_autograd_backend(gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor]): + print(gm.graph) + *_, model_name, nth_graph = get_aot_compilation_context() + mlir_module = import_fx_graph_as_func(gm.graph, model_name) + print(mlir_module.operation.get_asm(enable_debug_info=True)) + return gm + + +my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend) + + +# CHECK: module attributes {torch.debug_module_name = "basic"} { +# CHECK-NEXT: func.func @basic(%[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> { +# CHECK-NEXT: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> loc(#[[LOC:.*]]) +# CHECK-NEXT: return %[[TANH]] : !torch.vtensor<[3,4],f32> +# CHECK-NEXT: } +# CHECK-NEXT: } +# CHECK-NEXT: #[[LOC]] = loc("{{.*}}/dynamo_fx_importer/basic.py":{{[0-9]+}}:{{[0-9]+}}) +@dynamo.optimize(my_backend) +def basic(x): + return torch.tanh(x) + + +set_model_name("basic") +basic(torch.randn(3, 4)) + + +# CHECK-LABEL: func.func @literals_list_device_int_none_dtype() -> !torch.vtensor<[3,4],f16> { +# CHECK: %[[INT3:.*]] = torch.constant.int 3 +# CHECK: %[[INT4:.*]] = torch.constant.int 4 +# CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list +# CHECK: %[[INT5:.*]] = torch.constant.int 5 +# CHECK: %[[NONE0:.*]] = torch.constant.none +# CHECK: %[[DEVICE_CPU:.*]] = torch.constant.device "cpu" +# CHECK: %[[NONE1:.*]] = torch.constant.none +# CHECK: %[[RANDN:.*]] = torch.aten.randn %[[LIST]], %[[INT5]], %[[NONE0]], %[[DEVICE_CPU]], %[[NONE1]] : !torch.list, !torch.int, !torch.none, !torch.Device, !torch.none -> !torch.vtensor<[3,4],f16> +# CHECK: return %[[RANDN]] : !torch.vtensor<[3,4],f16> +@dynamo.optimize(my_backend) +def literals_list_device_int_none_dtype(): + return torch.ops.aten.randn([3, 4], + device=torch.device("cpu"), + dtype=torch.float16) + + +set_model_name("literals_list_device_int_none_dtype") +literals_list_device_int_none_dtype() + + +# CHECK-LABEL: func.func @literals_bool( +# CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> { +# CHECK: %[[NONE0:.*]] = torch.constant.none +# CHECK: %[[NONE1:.*]] = torch.constant.none +# CHECK: %[[NONE2:.*]] = torch.constant.none +# CHECK: %[[BOOL_FALSE:.*]] = torch.constant.bool false +# CHECK: %[[NONE3:.*]] = torch.constant.none +# CHECK: %[[EMPTY_LIKE:.*]] = torch.aten.empty_like %[[ARG0]], %[[NONE0]], %[[NONE1]], %[[NONE2]], %[[BOOL_FALSE]], %[[NONE3]] : !torch.vtensor<[3,4],f32>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> +# CHECK: return %[[EMPTY_LIKE]] : !torch.vtensor<[3,4],f32> +@dynamo.optimize(my_backend) +def literals_bool(x): + return torch.ops.aten.empty_like(x, pin_memory=False) + + +set_model_name("literals_bool") +literals_bool(torch.randn(3, 4)) + + +# CHECK-LABEL: func.func @literals_float( +# CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> { +# CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00 +# CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 +# CHECK: %[[NONE:.*]] = torch.constant.none +# CHECK: %[[UNIFORM:.*]] = torch.aten.uniform %[[ARG0]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f32> +# CHECK: return %[[UNIFORM]] : !torch.vtensor<[3,4],f32> +@dynamo.optimize(my_backend) +def literals_float(x): + return torch.ops.aten.uniform(x, 0.0, 1.0) + + +set_model_name("literals_float") +literals_float(torch.randn(3, 4)) + + +# CHECK-LABEL: func.func @literals_str( +# CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> { +# CHECK: %[[STR_TANH:.*]] = torch.constant.str "tanh" +# CHECK: %[[GELU:.*]] = torch.aten.gelu %[[ARG0]], %[[STR_TANH]] : !torch.vtensor<[3,4],f32>, !torch.str -> !torch.vtensor<[3,4],f32> +# CHECK: return %[[GELU]] : !torch.vtensor<[3,4],f32> +@dynamo.optimize(my_backend) +def literals_str(x): + return torch.ops.aten.gelu(x, approximate="tanh") + + +set_model_name("literals_str") +literals_str(torch.randn(3, 4)) diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 2d8b9e882..836d3fdfc 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -12,8 +12,8 @@ import tempfile from torch._functorch.compile_utils import strip_overloads import torch +import torch.fx -from torch_mlir.passmanager import PassManager from .compiler_utils import run_pipeline_with_repro_report from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library @@ -248,6 +248,64 @@ BACKEND_LEGAL_OPS = { } +def _canon_extra_library(extra_library): + extra_library_file_name = "" + if len(extra_library) != 0: + extra_library_dict = {} + for library_func in extra_library: + extra_library_dict[library_func.__name__] = library_func + mlir_library = generate_library(extra_library_dict) + + extra_library_file_name = \ + tempfile.gettempdir() + "/custom_op_extra_library.mlir" + with open(extra_library_file_name, "w") as f: + f.write(mlir_library) + return extra_library_file_name + + +def _lower_mlir_module(verbose, output_type, module): + if verbose: + print("\n====================") + print("Torch Backend IR") + print(module) + + if output_type == OutputType.TORCH: + return module + + if output_type == OutputType.TOSA: + run_pipeline_with_repro_report( + module, "builtin.module(torch-backend-to-tosa-backend-pipeline)", + "Lowering Torch Backend IR -> TOSA Backend IR") + if verbose: + print("\n====================") + print("TOSA Backend IR") + print(module) + return module + + if output_type == OutputType.LINALG_ON_TENSORS: + run_pipeline_with_repro_report( + module, + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR") + if verbose: + print("\n====================") + print("LINALG Backend IR") + print(module) + return module + + elif output_type == OutputType.STABLEHLO: + run_pipeline_with_repro_report( + module, + "builtin.module(torch-backend-to-stablehlo-backend-pipeline)", + "Lowering Torch Backend IR -> StableHLO Backend IR") + if verbose: + print("\n====================") + print("StableHLO Backend IR") + print(module) + return module + raise Exception(f"Unknown OutputType: {output_type}") + + def compile(model: torch.nn.Module, example_args: _example_args, output_type: Union[str, "OutputType"] = OutputType.TORCH, @@ -290,18 +348,7 @@ def compile(model: torch.nn.Module, An MLIR module that contains the converted model in the specified output type. """ - extra_library_file_name = "" - if len(extra_library) != 0: - extra_library_dict = {} - for library_func in extra_library: - extra_library_dict[library_func.__name__] = library_func - mlir_library = generate_library(extra_library_dict) - - extra_library_file_name = \ - tempfile.gettempdir() + "/custom_op_extra_library.mlir" - with open(extra_library_file_name, "w") as f: - f.write(mlir_library) - + extra_library_file_name = _canon_extra_library(extra_library) output_type = OutputType.get(output_type) example_args = ExampleArgs.get(example_args) if ignore_traced_shapes and not use_tracing: @@ -394,44 +441,4 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: "Lowering TorchScript IR -> Torch Backend IR", ) - if verbose: - print("\n====================") - print("Torch Backend IR") - print(mb.module) - - if output_type == OutputType.TORCH: - return mb.module - - if output_type == OutputType.TOSA: - run_pipeline_with_repro_report( - mb.module, - "builtin.module(torch-backend-to-tosa-backend-pipeline)", - "Lowering Torch Backend IR -> TOSA Backend IR") - if verbose: - print("\n====================") - print("TOSA Backend IR") - print(mb.module) - return mb.module - - if output_type == OutputType.LINALG_ON_TENSORS: - run_pipeline_with_repro_report( - mb.module, - "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", - "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR") - if verbose: - print("\n====================") - print("LINALG Backend IR") - print(mb.module) - return mb.module - - elif output_type == OutputType.STABLEHLO: - run_pipeline_with_repro_report( - mb.module, - "builtin.module(torch-backend-to-stablehlo-backend-pipeline)", - "Lowering Torch Backend IR -> StableHLO Backend IR") - if verbose: - print("\n====================") - print("StableHLO Backend IR") - print(mb.module) - return mb.module - raise Exception(f"Unknown OutputType: {output_type}") + return _lower_mlir_module(verbose, output_type, mb.module) diff --git a/python/torch_mlir/_dynamo_fx_importer.py b/python/torch_mlir/_dynamo_fx_importer.py new file mode 100644 index 000000000..5755b5118 --- /dev/null +++ b/python/torch_mlir/_dynamo_fx_importer.py @@ -0,0 +1,442 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. +import pdb +# This file implements a pure-Python importer from a restricted subset of +# FX IR into MLIR. +# +# As described in the +# [long-term roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md#refactoring-the-frontend), +# the goal is to import directly in the Torch-MLIR backend contract by +# using the available PyTorch infra for doing functionalization, +# shape inference, etc. Thus, this importer imports a very specific subset +# of the possible FX IR that is co-designed with the PyTorch infra that produces +# the FX graph -- see the `torch_mlir.dynamo` module for that, and see the +# `_verify_fx_graph_conforms_to_subset` function for the operational definition. +# +# In fact, because of the generality of FX IR (e.g. the use of raw Python +# callables as node.target), there is really no well-defined way to implement a +# general FX -> MLIR importer. Reuse or extension of this code for other +# FX -> MLIR use cases should be done carefully, and likely will involve +# introducing new concepts or abstractions into the import process. + +from typing import Dict, Tuple + +import operator +import re + +import torch + +import torch_mlir.ir as ir +import torch_mlir.dialects.func as func_dialect +import torch_mlir.dialects.torch as torch_dialect + + +def _is_valid_meta_val(val): + # We currently allow only FakeTensor's or lists of FakeTensor's + # as meta['val']. However, this can potentially change also hold a SymInt + # in the future. See: + # https://github.com/pytorch/pytorch/issues/90839#issuecomment-1352856661 + if isinstance(val, torch._subclasses.FakeTensor): + return True + if isinstance(val, (tuple, list)): + return all(isinstance(x, torch._subclasses.FakeTensor) for x in val) + return False + + +def _verify_fx_graph_conforms_to_subset(g: torch.fx.Graph): + # TODO: Report errors with source locations if possible. + def _check_meta_val(node): + if "val" not in node.meta: + raise Exception(f"Unsupported: missing node.meta['val']: {node}") + if not _is_valid_meta_val(node.meta["val"]): + raise Exception( + f"Unsupported: node.meta['val'] is not a FakeTensor or list of FakeTensor's: {node}; {node.meta['val']}" + ) + + for node in g.nodes: + if node.op not in ("placeholder", "call_function", "output"): + raise Exception(f"Unsupported op: {node.op}") + if node.op == "placeholder": + _check_meta_val(node) + if node.op == "call_function": + _check_meta_val(node) + # We only support OpOverload for computations because the `torch` + # dialect ops model the full qualified op name, including overload. + # We also support operator.getitem because that is how multiple + # results are modeled. + if isinstance(node.target, torch._ops.OpOverload): + for type_ in (r.type for r in node.target._schema.returns): + if isinstance(type_, torch.TensorType): + continue + raise Exception( + f"Unsupported: return type {type_} in schema for {node.target}" + ) + if len(node.args) != len(node.target._schema.arguments): + assert len(node.args) < len(node.target._schema.arguments) + for i, argument in enumerate( + node.target._schema.arguments[len(node.args):]): + if not argument.has_default_value(): + raise Exception( + f"Unsupported: missing default value for argument {i} in schema for {node.target}" + ) + continue + if node.target is operator.getitem: + continue + raise Exception(f"Unsupported call_function target: {node.target}") + + +# ============================================================================== +# Type import +# ============================================================================== + + +def _torch_type_to_mlir_type_string(t: torch.Type) -> str: + # This is weird -- for Node's, since they are untyped, we use the + # node.meta['val'] to get the type (which is a tensor type with sizes and + # dtype). + # But for things that are associated with a schema, we use the schema to get + # the type. This creates problems for things like a list because + # then we don't have sizes or dtypes available. + if isinstance(t, torch.ListType): + return f"list<{_torch_type_to_mlir_type_string(t.getElementType())}>" + if isinstance(t, torch.BoolType): + return "bool" + if isinstance(t, torch.IntType): + return "int" + if isinstance(t, torch.FloatType): + return "float" + if isinstance(t, torch.StringType): + return "string" + if isinstance(t, torch.TensorType): + return "vtensor" + if isinstance(t, torch.OptionalType): + return f"optional<{_torch_type_to_mlir_type_string(t.getElementType())}>" + raise Exception(f"Unsupported type: {t}") + + +def _torch_type_to_mlir_type(t: torch.Type): + return ir.Type.parse(f"!torch.{_torch_type_to_mlir_type_string(t)}") + + +def _convert_dtype_to_mlir_type(dtype: torch.dtype) -> str: + # See the table in TorchTypes.td:AnyTorchTensorType's documentation. + if dtype == torch.float16: + return "f16" + if dtype == torch.bfloat16: + return "bf16" + if dtype == torch.float32: + return "f32" + if dtype == torch.float64: + return "f64" + if dtype == torch.uint8: + return "ui8" + if dtype == torch.int8: + return "si8" + if dtype == torch.int16: + return "si16" + if dtype == torch.int32: + return "si32" + if dtype == torch.int64: + return "si64" + if dtype == torch.bool: + return "i1" + if dtype == torch.qint8: + return "!torch.qint8" + if dtype == torch.quint8: + return "!torch.quint8" + if dtype == torch.complex64: + return "complex" + + + raise Exception(f"Unsupported dtype: {dtype}") + + +def _import_fake_tensor_as_mlir_type( + fake_tensor: torch._subclasses.FakeTensor) -> ir.Type: + # TODO: Find story for how to get dynamically shaped tensors here. + shape = ",".join(str(d) for d in fake_tensor.shape) + dtype = _convert_dtype_to_mlir_type(fake_tensor.dtype) + return ir.Type.parse(f"!torch.vtensor<[{shape}],{dtype}>") + + +def _mlir_types_for_node(node: torch.fx.Node) -> ir.Type: + if isinstance(node.meta["val"], (tuple, list)): + return [_import_fake_tensor_as_mlir_type(v) for v in node.meta["val"]] + return [_import_fake_tensor_as_mlir_type(node.meta["val"])] + + +def _extract_function_type_from_graph(g: torch.fx.Graph) -> ir.FunctionType: + input_types = [] + for node in g.nodes: + if node.op == "placeholder": + input_types.append(_mlir_types_for_node(node)[0]) + if node.op == "output": + # TODO(DNS): Test this or add verifier that it can't happen. + result_types = torch.fx.map_arg( + node.args[0], lambda n: _mlir_types_for_node(n)[0]) + # Note: We import directly to the backend contract -- multiple results + # are modeled with func.func native multiple results rather than as a + # singleton value / tuple. + return ir.FunctionType.get(input_types, result_types) + + +# ============================================================================== +# FX Graph import +# ============================================================================== + +DTYPE_TO_INT = { + # TODO(DNS): Fill in from AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS + torch.uint8: + 0, + torch.int8: + 1, + torch.int16: + 2, + torch.int32: + 3, + torch.int64: + 4, + torch.float16: + 5, + torch.float32: + 6, + torch.float64: + 7, + # torch.complex_half 8 + torch.complex32: + 9, + torch.complex64: + 10, + torch.bool: + 11, + torch.qint8: + 12, + torch.quint8: + 13, + # torch.qint32 14 + torch.bfloat16: + 15, +} + +MEMORY_FORMAT_TO_INT = { + # https://github.com/pytorch/pytorch/c10/core/MemoryFormat.h#L28 + torch.contiguous_format: + 0, + torch.preserve_format: + 1, + torch.channels_last: + 2, + torch.channels_last_3d: + 3, +} + +LAYOUT_TO_INT = { + # https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_layouts.cpp + torch.strided: + 0, + torch.sparse_coo: + 1, + torch.sparse_csr: + 2, + torch.sparse_csc: + 3, + torch.sparse_bsr: + 4, + torch.sparse_bsc: + 5, +} + + +def _mlir_location_for_node(node: torch.fx.Node) -> ir.Location: + stack_trace = node.stack_trace + if stack_trace is None: + return ir.Location.unknown() + # TODO: Avoid needing to regex match this. + # https://github.com/pytorch/pytorch/issues/91000 + m = re.search(r"""File "([^"]+)", line ([0-9]+),""", node.stack_trace) + filename, line = m.group(1), int(m.group(2)) + return ir.Location.file(filename, line, col=0) + + +class _FXGraphImporter: + + def __init__(self, g: torch.fx.Graph, func_name: str): + self._g = g + self._func_name = func_name + # For each node, we track a mapping to MLIR Value's. + # Technically all Node's have a single output (which can be a tuple of + # values in case of multiple returns), but we treat them as having + # multiple returns directly. This matches how the Node's + # node.meta['val'] is set up, since it contains a list with multiple + # FakeTensor's in case of a tuple return with multiple elements. + self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {} + self._module = ir.Module.create(ir.Location.unknown()) + self._module.operation.attributes[ + "torch.debug_module_name"] = ir.StringAttr.get(func_name) + function_type = _extract_function_type_from_graph(g) + func = func_dialect.FuncOp( + func_name, + function_type, + loc=ir.Location.unknown(), # TODO: Can we do better? + ip=ir.InsertionPoint(self._module.body), + ) + self._body_block = ir.Block.create_at_start(func.body, + function_type.inputs) + + def import_graph(self) -> ir.Module: + with ir.InsertionPoint(self._body_block): + num_placeholders_seen = 0 + for node in self._g.nodes: + with _mlir_location_for_node(node): + if node.op == "placeholder": + self._env[( + node, 0 + )] = self._body_block.arguments[num_placeholders_seen] + num_placeholders_seen += 1 + if node.op == "call_function": + if node.target is operator.getitem: + self._env[(node, 0)] = self._env[(node.args[0], + node.args[1])] + else: + self._import_op_overload_call(node) + if node.op == "output": + # Note that the output node is a singleton tuple holding + # a tuple of return values (without the single-element special + # case) + # DNS: Test or verify no literals as results. + operands = [ + self._import_argument(arg) for arg in node.args[0] + ] + func_dialect.ReturnOp(operands) + return self._module + + def _import_op_overload_call(self, node: torch.fx.Node): + assert node.op == "call_function" + assert isinstance(node.target, torch._ops.OpOverload) + schema = node.target._schema + + # Extract the `torch` dialect op name. + namespace, _, unqualified_name = schema.name.partition("::") + mlir_op_name = f"torch.{namespace}.{unqualified_name}" + if schema.overload_name != "": + mlir_op_name += f".{schema.overload_name}" + + # DNS: Unregistered ops + assert ir.Context.current.is_registered_operation( + mlir_op_name), f"Unregistered operation: {mlir_op_name}" + + # Construct the Operation. + result_types = _mlir_types_for_node(node) + operands = [] + # `schema.arguments` is a bit confusing in this context, since + # `Argument` is the term that FX uses analogous to mlir "Value". It is + # more precise to call them "formal parameters". + for i, parameter in enumerate(node.target._schema.arguments): + if parameter.kwarg_only and parameter.name in node.kwargs: + arg = node.kwargs[parameter.name] + elif i < len(node.args): + arg = node.args[i] + else: + arg = parameter.default_value + operands.append(self._import_argument(arg, parameter.type)) + operation = ir.Operation.create( + mlir_op_name, + results=result_types, + operands=operands, + ) + for i, value in enumerate(operation.results): + self._env[(node, i)] = value + + def _import_argument(self, + arg: torch.fx.node.Argument, + expected_type_for_literal=None) -> ir.Value: + """Import an FX `Argument`, which is analogous to an MLIR `Value`. + + Args: + arg: The FX `Argument` to import. + expected_type_for_literal: If `arg` is a literal (such as a Python + `int` or `float` object), this is the expected JIT IR type. This + allows disambiguating certain cases, such as importing an optional + type. + Returns: + The imported MLIR `Value`. + """ + if isinstance(arg, torch.fx.Node): + return self._env[(arg, 0)] + assert expected_type_for_literal is not None + return self._import_literal(arg, expected_type_for_literal) + + def _import_literal(self, arg: torch.fx.node.Argument, + expected_type) -> ir.Value: + if arg is None: + return torch_dialect.ConstantNoneOp().result + if isinstance(expected_type, torch.OptionalType): + return self._import_argument(arg, expected_type.getElementType()) + if isinstance(arg, bool): + return torch_dialect.ConstantBoolOp( + ir.IntegerAttr.get(ir.IntegerType.get_signless(1), arg)).result + if isinstance(arg, int): + return torch_dialect.ConstantIntOp( + ir.IntegerAttr.get(ir.IntegerType.get_signless(64), + arg)).result + if isinstance(arg, float): + return torch_dialect.ConstantFloatOp( + ir.FloatAttr.get_f64(arg)).result + if isinstance(arg, str): + return torch_dialect.ConstantStrOp(ir.StringAttr.get(arg)).result + if isinstance(arg, torch.dtype): + assert isinstance(expected_type, torch.IntType) + return self._import_argument(DTYPE_TO_INT[arg], expected_type) + if isinstance(arg, torch.device): + # TODO(DNS): Device index? arg.index + return torch_dialect.ConstantDeviceOp(ir.StringAttr.get( + arg.type)).result + if isinstance(arg, torch.memory_format): + assert isinstance(expected_type, torch.IntType) + return self._import_argument(MEMORY_FORMAT_TO_INT[arg], + expected_type) + if isinstance(arg, torch.layout): + assert isinstance(expected_type, torch.IntType) + return self._import_argument(LAYOUT_TO_INT[arg], expected_type) + if isinstance(arg, list): + assert isinstance(expected_type, torch.ListType) + element_type = expected_type.getElementType() + if isinstance(element_type, torch.TensorType): + assert all( + torch.fx.node.map_aggregate( + arg, lambda a: _is_valid_meta_val(a.meta.get("val")))) + els = [self._env[e, 0] for e in arg] + + else: + element_type = _torch_type_to_mlir_type(element_type) + els = [ + self._import_argument(e, element_type) for e in arg + ] + + # import pydevd_pycharm + # pydevd_pycharm.settrace('localhost', port=8888, stdoutToServer=True, stderrToServer=True) + return torch_dialect.PrimListConstructOp( + _torch_type_to_mlir_type(expected_type), + els, + ).result + raise Exception(f"Unsupported literal: {arg}") + + +def import_fx_graph_as_func(g: torch.fx.Graph, func_name: str) -> ir.Module: + """Imports the given FX graph as a function in a new MLIR module. + + Args: + g: The FX graph to import. + func_name: The sym_name of the `func.func` to import the graph into. + Returns: + A new MLIR module containing the imported function. + """ + # Note that this function imports a fx.Graph instead of an fx.GraphModule. + # The reason is that the supported subset only involves stateless + # fx.Graph's, so the state held on the fx.GraphModule is not necessary. + _verify_fx_graph_conforms_to_subset(g) + with ir.Context() as context: + torch_dialect.register_dialect(context) + return _FXGraphImporter(g, func_name).import_graph() diff --git a/python/torch_mlir_e2e_test/configs/torchdynamo.py b/python/torch_mlir_e2e_test/configs/torchdynamo.py index f22228fc5..c53227acf 100644 --- a/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -2,96 +2,149 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from typing import List -import numpy +from typing import List, Union, Optional, Sequence + +import numpy as np import torch -import torch._dynamo as dynamo -import torch_mlir -import torch_mlir.dynamo -from torch_mlir.dynamo import make_simple_dynamo_backend -from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend +import torch.utils._pytree as pytree +from torch._dynamo.backends.common import aot_autograd +from torch._functorch.aot_autograd import ( + make_boxed_compiler, + get_aot_compilation_context, + set_model_name, +) +from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func +from torch_mlir.dynamo import _get_decomposition_table +from torch_mlir import ( + _example_args, + OutputType, + BACKEND_LEGAL_OPS, + run_pipeline_with_repro_report, + _lower_mlir_module, + _canon_extra_library, +) +from torch_mlir_e2e_test.configs.utils import ( + recursively_convert_to_numpy, + recursively_convert_from_numpy, +) from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem + +def refine_result_type(_result): + if isinstance(_result, tuple): + return tuple(refine_result_type(x) for x in _result) + elif isinstance(_result, np.ndarray): + return torch.from_numpy(_result) + elif isinstance(_result, (bool, int, float)): + return _result + else: + raise ValueError(f"Unhandled return type {type(_result)}") + + def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool: for node in fx_graph.graph.nodes: if node.op == "output": - assert len(node.args) == 1, "Output node must have a single argument" + assert len( + node.args) == 1, "Output node must have a single argument" node_arg = node.args[0] if node_arg != (): return False return True -@make_simple_dynamo_backend -def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule, - example_inputs: List[torch.Tensor]): - # Use the LinalgOnTensors backend, since it is the most complete. - # In theory we could mix and match TorchDynamo with the other backends, - # since they all lower through the same backend contract. - # For now, testing-wise, it doesn't make sense to test those configurations. - # We really just want to check the TorchDynamo frontend. - # - # Longer-term we will need to do something more sophisticated here. - # As per the long-term roadmap: - # https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md#refactoring-the-frontend - # We will eventually have a configuration that uses new PyTorch infra and - # skips the entire "frontend" part. We currently don't have any code - # for that right now since it is still very early stages, but eventually - # this Config should test that path (and maybe the current behavior can - # be moved to a `legacy_frontend_via_torchdynamo` config). +def jit( + model: torch.nn.Module, + example_args: _example_args, + output_type: Union[str, "OutputType"] = OutputType.TORCH, + backend_legal_ops: Optional[Sequence[str]] = None, + extra_library=None, + verbose: bool = False, +): + if extra_library is None: + extra_library = [] + import torch._dynamo as dynamo - # Torch-MLIR does not support returning an empty tuple. The reason is - # that both returning an empty tuple and returning `None` results in MLIR - # functions that have as a return type `()`. In other words, there is no - # way of differentiating between the two. - assert not _returns_empty_tuple(fx_graph), "encountered graph that does not return anything" + mlir_module = None - mlir_module = torch_mlir.compile( - fx_graph, example_inputs, output_type="linalg-on-tensors") - backend = refbackend.RefBackendLinalgOnTensorsBackend() - compiled = backend.compile(mlir_module) - loaded = backend.load(compiled) + extra_library_file_name = _canon_extra_library(extra_library) + output_type = OutputType.get(output_type) + if backend_legal_ops is not None: + if output_type != OutputType.TORCH: + raise Exception("`backend_legal_ops` is only valid with the " + "`torch` output type") + backend_legal_ops = list(sorted(set(backend_legal_ops))) + else: + backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, []) - def compiled_callable(*inputs): - def refine_result_type(_result): - if isinstance(_result, tuple): - return tuple(refine_result_type(x) for x in _result) - elif isinstance(_result, numpy.ndarray): - return torch.from_numpy(_result) - elif isinstance(_result, (bool, int, float)): - return _result - else: - raise ValueError(f"Unhandled return type {type(_result)}") - inputs = [x.numpy() for x in inputs] - result = loaded.forward(*inputs) - return refine_result_type(result) - return compiled_callable + @make_boxed_compiler + def my_aot_autograd_backend(gm: torch.fx.GraphModule, + _example_inputs: List[torch.Tensor]): + # Torch-MLIR does not support returning an empty tuple. The reason is + # that both returning an empty tuple and returning `None` results in MLIR + # functions that have as a return type `()`. In other words, there is no + # way of differentiating between the two. + assert not _returns_empty_tuple(gm), "encountered graph that does not return anything" + + nonlocal mlir_module + *_, model_name, nth_graph = get_aot_compilation_context() + mlir_module = import_fx_graph_as_func(gm.graph, model_name) + return gm + + my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend, + decompositions=_get_decomposition_table) + + with torch.no_grad(): + set_model_name(model.__class__.__name__) + torch._dynamo.reset() + dynamo_f = dynamo.optimize(my_backend, nopython=True)( + lambda method, *inputs: method(*inputs)) + dynamo_f(lambda *inputs: model(*[x.clone() for x in inputs]), + *example_args) + option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) + + " extra-library=" + extra_library_file_name + "}") + assert mlir_module is not None + run_pipeline_with_repro_report( + mlir_module, + # f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})", + f"builtin.module(torch-lower-to-backend-contract)", + "Lowering TorchFX IR -> Torch Backend IR", + ) + + return _lower_mlir_module(verbose, output_type, mlir_module) class TorchDynamoTestConfig(TestConfig): """TestConfig that runs the torch.nn.Module with TorchDynamo""" - def __init__(self): + def __init__(self, backend): super().__init__() + self.backend = backend def compile(self, program: torch.nn.Module) -> torch.nn.Module: return program def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: - def item_symbol_that_clones_inputs(*inputs): - cloned_inputs = [x.clone() for x in inputs] - result = getattr(artifact, item.symbol)(*cloned_inputs) - return result - # TODO: Deepcopy the torch.nn.Module, so that if the program is - # stateful then it does not mutate the original compiled program. result: Trace = [] for item in trace: - f = lambda method, *inputs: method(*inputs) - torch._dynamo.reset() - dynamo_f = dynamo.optimize(_refbackend_torchdynamo_backend, nopython=True)(f) - output = dynamo_f(item_symbol_that_clones_inputs, *item.inputs) + module = jit(artifact, + item.inputs, + output_type="linalg-on-tensors") + module = self.backend.compile(module) + backend_module = self.backend.load(module) + params = { + **dict(artifact.named_parameters(remove_duplicate=False)), + **dict(artifact.named_buffers(remove_duplicate=False)), + } + params_flat, params_spec = pytree.tree_flatten(params) + params_flat = list(params_flat) + with torch.no_grad(): + numpy_inputs = recursively_convert_to_numpy(params_flat + + item.inputs) + outputs = getattr(backend_module, + artifact.__class__.__name__)(*numpy_inputs) + output = refine_result_type(outputs) result.append( TraceItem(symbol=item.symbol, inputs=item.inputs, diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 5e4feb4e2..9b86dfcf1 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -818,7 +818,7 @@ def GatherNegativeDimModule_basic(module, tu: TestUtils): class GatherRandomIndexModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -839,7 +839,7 @@ def GatherRandomIndexModule_basic(module, tu: TestUtils): class Gather2DInputModdule(torch.nn.Module): def __init__(self) -> None: super().__init__() - + @export @annotate_args([ None, @@ -1914,7 +1914,7 @@ class TensorLiteralModule(torch.nn.Module): def __init__(self): super().__init__() torch.manual_seed(0) - self.t = torch.randint(-5, 5, (2, 3)) + self.register_buffer("t", torch.randint(-5, 5, (2, 3))) @export @annotate_args([ @@ -1937,7 +1937,7 @@ class TensorOpaqueLiteralModule(torch.nn.Module): def __init__(self): super().__init__() torch.manual_seed(0) - self.t = torch.randint(-5, 5, (256, 1024)) + self.register_buffer("t", torch.randint(-5, 5, (256, 1024))) @export @annotate_args([