From 57d8ec151f8274da641c9d4dfda59e3e238c7205 Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Thu, 1 Sep 2022 17:08:17 +0800 Subject: [PATCH] [MHLO] add VerifyMhloBackendContract (#1321) * [MHLO] add VerifyMhloBackendContract * guard with macro --- .../TorchConversion/Transforms/CMakeLists.txt | 6 +- .../TorchConversion/Transforms/Passes.h | 1 + .../TorchConversion/Transforms/Passes.td | 6 ++ .../TorchConversion/Transforms/CMakeLists.txt | 2 +- .../TorchConversion/Transforms/Passes.cpp | 4 + .../Transforms/VerifyMhloBackendContract.cpp | 73 +++++++++++++++++++ utils/bazel/torch-mlir-overlay/BUILD.bazel | 6 +- 7 files changed, 95 insertions(+), 3 deletions(-) create mode 100644 lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt b/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt index 18bb94fa2..00818899f 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -1,5 +1,9 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +if(TORCH_MLIR_ENABLE_MHLO) + mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO) +else() + mlir_tablegen(Passes.h.inc -gen-pass-decls) +endif() add_public_tablegen_target(TorchMLIRTorchConversionPassIncGen) add_mlir_doc(Passes TorchMLIRTorchConversionTransforms ./ -gen-pass-doc) diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index c4008dec4..f8b749768 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -39,6 +39,7 @@ void createTorchBackendToTosaBackendPipeline( void createTorchBackendToMhloBackendPipeline( OpPassManager &pm, const torch::Torch::TorchLoweringPipelineOptions &options); +std::unique_ptr> createVerifyMhloBackendContractPass(); #endif std::unique_ptr> createFuncBackendTypeConversionPass(); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 8afd9850b..4ce7cdadb 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -42,4 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()"; } +#ifdef TORCH_MLIR_ENABLE_MHLO +def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> { + let summary = "Verifies conformity to the mhlo backend contract"; + let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()"; +} +#endif // TORCH_MLIR_ENABLE_MHLO #endif // TORCHMLIR_TORCHCONVERSION_PASSES diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index f8c3373cb..5412c9ffa 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -20,7 +20,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses Passes.cpp VerifyLinalgOnTensorsBackendContract.cpp VerifyTosaBackendContract.cpp - + VerifyMhloBackendContract.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index b2f012f03..f7a914164 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -141,5 +141,9 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline( pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); pm.addNestedPass( TorchConversion::createFinalizingBackendTypeConversionPass()); + // Verify that we have lowered to the form that MHLO backends + // expect. This fails compilation (signalPassFailure) if the IR is not in the + // correct form. + pm.addPass(TorchConversion::createVerifyMhloBackendContractPass()); } #endif diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp new file mode 100644 index 000000000..8bc19645d --- /dev/null +++ b/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +#ifdef TORCH_MLIR_ENABLE_MHLO +#include "PassDetail.h" + +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/ChloOps.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::TorchConversion; + +namespace { +class VerifyMhloBackendContractPass + : public VerifyMhloBackendContractBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto module = getOperation(); + TypeConverter converter; + converter.addConversion([](Type type) -> Type { + auto elemTy = type; + if (isa(type)) { + elemTy = type.cast().getElementType(); + } + if (BaseMemRefType::isValidElementType(elemTy)) + return type; + return nullptr; + }); + + auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); }; + + ConversionTarget target(*context); + + // Structural operations. + target.addDynamicallyLegalOp( + opHasLegalTypes); + // Basic scalar operations. + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + RewritePatternSet patterns(context); + if (failed(applyFullConversion(module, target, std::move(patterns)))) { + // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics + // doesn't unnecessarily spew out the entire module. + emitError(module.getLoc()) + << "Module does not conform to the MHLO backend contract. " + "See dialect conversion legality information above."; + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() { + return std::make_unique(); +} +#endif // TORCH_MLIR_ENABLE_MHLO diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 6704c54bf..6a8cc1e38 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -314,7 +314,10 @@ gentbl_cc_library( strip_include_prefix = "include", tbl_outs = [ ( - ["-gen-pass-decls"], + [ + "-gen-pass-decls", + "-DTORCH_MLIR_ENABLE_MHLO", + ], "include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc", ), ], @@ -502,6 +505,7 @@ cc_library( "lib/Dialect/TorchConversion/Transforms/Passes.cpp", "lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp", "lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp", + "lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp", ], hdrs = [ "include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h",