[MHLO] add VerifyMhloBackendContract (#1321)

* [MHLO] add VerifyMhloBackendContract

* guard with macro
pull/1333/head snapshot-20220901.583
Tanyo Kwok 2022-09-01 17:08:17 +08:00 committed by GitHub
parent 729609831c
commit 57d8ec151f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 95 additions and 3 deletions

View File

@ -1,5 +1,9 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
if(TORCH_MLIR_ENABLE_MHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()
add_public_tablegen_target(TorchMLIRTorchConversionPassIncGen)
add_mlir_doc(Passes TorchMLIRTorchConversionTransforms ./ -gen-pass-doc)

View File

@ -39,6 +39,7 @@ void createTorchBackendToTosaBackendPipeline(
void createTorchBackendToMhloBackendPipeline(
OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass();
#endif
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();

View File

@ -42,4 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
}
#ifdef TORCH_MLIR_ENABLE_MHLO
def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the mhlo backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()";
}
#endif // TORCH_MLIR_ENABLE_MHLO
#endif // TORCHMLIR_TORCHCONVERSION_PASSES

View File

@ -20,7 +20,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
Passes.cpp
VerifyLinalgOnTensorsBackendContract.cpp
VerifyTosaBackendContract.cpp
VerifyMhloBackendContract.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms

View File

@ -141,5 +141,9 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that MHLO backends
// expect. This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(TorchConversion::createVerifyMhloBackendContractPass());
}
#endif

View File

@ -0,0 +1,73 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifdef TORCH_MLIR_ENABLE_MHLO
#include "PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;
namespace {
class VerifyMhloBackendContractPass
: public VerifyMhloBackendContractBase<VerifyMhloBackendContractPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto module = getOperation();
TypeConverter converter;
converter.addConversion([](Type type) -> Type {
auto elemTy = type;
if (isa<TensorType>(type)) {
elemTy = type.cast<TensorType>().getElementType();
}
if (BaseMemRefType::isValidElementType(elemTy))
return type;
return nullptr;
});
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
ConversionTarget target(*context);
// Structural operations.
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
opHasLegalTypes);
// Basic scalar operations.
target.addLegalDialect<mhlo::MhloDialect>();
target.addLegalDialect<chlo::ChloDialect>();
target.addLegalDialect<tensor::TensorDialect>();
target.addLegalDialect<arith::ArithmeticDialect>();
RewritePatternSet patterns(context);
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
// doesn't unnecessarily spew out the entire module.
emitError(module.getLoc())
<< "Module does not conform to the MHLO backend contract. "
"See dialect conversion legality information above.";
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() {
return std::make_unique<VerifyMhloBackendContractPass>();
}
#endif // TORCH_MLIR_ENABLE_MHLO

View File

@ -314,7 +314,10 @@ gentbl_cc_library(
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-pass-decls"],
[
"-gen-pass-decls",
"-DTORCH_MLIR_ENABLE_MHLO",
],
"include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc",
),
],
@ -502,6 +505,7 @@ cc_library(
"lib/Dialect/TorchConversion/Transforms/Passes.cpp",
"lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp",
"lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp",
"lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp",
],
hdrs = [
"include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h",