From 4f173c6e0f1ea35ccee07970f514f9da90e7db8b Mon Sep 17 00:00:00 2001 From: Xiafei Qiu Date: Thu, 10 Nov 2022 18:39:28 +0800 Subject: [PATCH] update llvm tag to a2620e00. (#1567) - also update MHLO to 57ba12a2(branch greencommit/2022-11-07-a2620e00) - change -pass-pipeline format to make tests pass. --- externals/llvm-project | 2 +- externals/mlir-hlo | 2 +- python/torch_mlir/__init__.py | 8 ++++---- .../torch/importer/jit_ir/build_tools/shape_lib_gen.py | 2 +- python/torch_mlir_e2e_test/eager_backends/refbackend.py | 2 +- .../linalg_on_tensors_backends/refbackend.py | 4 ++-- .../mhlo_backends/linalg_on_tensors.py | 2 +- .../tosa_backends/linalg_on_tensors.py | 6 +++--- .../Torch/torch-function-to-torch-backend-pipeline.mlir | 2 +- 9 files changed, 15 insertions(+), 15 deletions(-) diff --git a/externals/llvm-project b/externals/llvm-project index 74fb770de..a2620e00f 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 74fb770de9399d7258a8eda974c93610cfde698e +Subproject commit a2620e00ffa232a406de3a1d8634beeda86956fd diff --git a/externals/mlir-hlo b/externals/mlir-hlo index 36238f164..57ba12a2a 160000 --- a/externals/mlir-hlo +++ b/externals/mlir-hlo @@ -1 +1 @@ -Subproject commit 36238f16441cd1a884af988d4400d2ebb0c75bbc +Subproject commit 57ba12a2a1934c3c9fc3cd1580f28f0c233f41d4 diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 41708c63b..6a379752b 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -346,7 +346,7 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}" run_pipeline_with_repro_report( mb.module, - f"torchscript-module-to-torch-backend-pipeline{option_string}", + f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})", "Lowering TorchScript IR -> Torch Backend IR", ) @@ -361,7 +361,7 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: if output_type == OutputType.TOSA: run_pipeline_with_repro_report( mb.module, - "torch-backend-to-tosa-backend-pipeline", + "builtin.module(torch-backend-to-tosa-backend-pipeline)", "Lowering Torch Backend IR -> TOSA Backend IR") if verbose: print("\n====================") @@ -372,7 +372,7 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: if output_type == OutputType.LINALG_ON_TENSORS: run_pipeline_with_repro_report( mb.module, - "torch-backend-to-linalg-on-tensors-backend-pipeline", + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR") if verbose: print("\n====================") @@ -383,7 +383,7 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: elif output_type == OutputType.MHLO: run_pipeline_with_repro_report( mb.module, - "torch-backend-to-mhlo-backend-pipeline", + "builtin.module(torch-backend-to-mhlo-backend-pipeline)", "Lowering Torch Backend IR -> MHLO Backend IR") if verbose: print("\n====================") diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index 1c43e8624..18853df3e 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -1261,7 +1261,7 @@ def main(args): for function in torch.jit._state._python_cu.get_functions(): mb.import_function(function) # Clean up the IR a bit before writing it out. - pm = PassManager.parse("canonicalize", context=mb.module.context) + pm = PassManager.parse("builtin.module(canonicalize)", context=mb.module.context) pm.run(mb.module) # Munge the IR a bit to make it more systematically accessible. asm = mb.module.operation.get_asm() diff --git a/python/torch_mlir_e2e_test/eager_backends/refbackend.py b/python/torch_mlir_e2e_test/eager_backends/refbackend.py index 816e8beba..85f4b8c03 100644 --- a/python/torch_mlir_e2e_test/eager_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/eager_backends/refbackend.py @@ -67,7 +67,7 @@ class EagerModeRefBackend(TorchMLIREagerBackend): if module_hash not in self.module_to_refbackend_invoker: run_pipeline_with_repro_report( imported_module, - "torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline", + "builtin.module(torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline)", "EagerMode", ) self.module_to_refbackend_invoker[module_hash] = _ref_backend.load( diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index d7641384c..be59ee907 100644 --- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -114,7 +114,7 @@ class RefBackendInvoker: return invoke -LOWERING_PIPELINE = ",".join([ +LOWERING_PIPELINE = "builtin.module(" + ",".join([ "func.func(refback-generalize-tensor-pad)", # Bufferize. "func.func(scf-bufferize)", @@ -152,7 +152,7 @@ LOWERING_PIPELINE = ",".join([ "convert-func-to-llvm", "convert-cf-to-llvm", "reconcile-unrealized-casts", -]) +]) + ")" class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend): diff --git a/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py index 0a467ef5f..3ac1d6cd6 100644 --- a/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py +++ b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py @@ -36,7 +36,7 @@ class LinalgOnTensorsMhloBackend(MhloBackend): """ run_pipeline_with_repro_report( imported_module, - "func.func(symbolic-shape-optimization),func.func(hlo-legalize-to-linalg),func.func(canonicalize)", + "builtin.module(func.func(symbolic-shape-optimization),func.func(hlo-legalize-to-linalg),func.func(canonicalize))", "Lowering MLIR-HLO to Linalg-on-Tensors") return self.refbackend.compile(imported_module) diff --git a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py index bd5b2db1e..c5dc6b7e9 100644 --- a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py +++ b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py @@ -43,20 +43,20 @@ class LinalgOnTensorsTosaBackend(TosaBackend): # that depend on TOSA as well as TOSA-to-Standard. run_pipeline_with_repro_report( imported_module, - "func.func(tosa-to-arith)", + "builtin.module(func.func(tosa-to-arith))", "Lowering TOSA to Arith") # Named ops must be legalized prior to general tosa-to-linalg run_pipeline_with_repro_report( imported_module, - "func.func(tosa-to-linalg-named)", + "builtin.module(func.func(tosa-to-linalg-named))", "Lowering TOSA to Linalg-on-Tensors for Named Ops") # TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them # to arith.constants here before proceeding further. run_pipeline_with_repro_report( imported_module, - "func.func(tosa-to-tensor),func.func(tosa-to-linalg),func.func(tosa-to-arith)", + "builtin.module(func.func(tosa-to-tensor),func.func(tosa-to-linalg),func.func(tosa-to-arith))", "Lowering TOSA to Linalg-on-Tensors") return self.refbackend.compile(imported_module) diff --git a/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir b/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir index e281ac732..4bcbae306 100644 --- a/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir +++ b/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt -pass-pipeline='torch-function-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax}' -split-input-file %s | FileCheck %s +// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax})' -split-input-file %s | FileCheck %s // CHECK-LABEL: func.func @torch.aten.square func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {