mirror of https://github.com/llvm/torch-mlir
[MHLO] add VerifyMhloBackendContract (#1321)
* [MHLO] add VerifyMhloBackendContract * guard with macropull/1333/head snapshot-20220901.583
parent
729609831c
commit
57d8ec151f
|
@ -1,5 +1,9 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
|
||||
else()
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
endif()
|
||||
add_public_tablegen_target(TorchMLIRTorchConversionPassIncGen)
|
||||
|
||||
add_mlir_doc(Passes TorchMLIRTorchConversionTransforms ./ -gen-pass-doc)
|
||||
|
|
|
@ -39,6 +39,7 @@ void createTorchBackendToTosaBackendPipeline(
|
|||
void createTorchBackendToMhloBackendPipeline(
|
||||
OpPassManager &pm,
|
||||
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass();
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
|
||||
|
|
|
@ -42,4 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu
|
|||
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> {
|
||||
let summary = "Verifies conformity to the mhlo backend contract";
|
||||
let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()";
|
||||
}
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||
#endif // TORCHMLIR_TORCHCONVERSION_PASSES
|
||||
|
|
|
@ -20,7 +20,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
|
|||
Passes.cpp
|
||||
VerifyLinalgOnTensorsBackendContract.cpp
|
||||
VerifyTosaBackendContract.cpp
|
||||
|
||||
VerifyMhloBackendContract.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms
|
||||
|
|
|
@ -141,5 +141,9 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
|
|||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||
// Verify that we have lowered to the form that MHLO backends
|
||||
// expect. This fails compilation (signalPassFailure) if the IR is not in the
|
||||
// correct form.
|
||||
pm.addPass(TorchConversion::createVerifyMhloBackendContractPass());
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::TorchConversion;
|
||||
|
||||
namespace {
|
||||
class VerifyMhloBackendContractPass
|
||||
: public VerifyMhloBackendContractBase<VerifyMhloBackendContractPass> {
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
auto module = getOperation();
|
||||
TypeConverter converter;
|
||||
converter.addConversion([](Type type) -> Type {
|
||||
auto elemTy = type;
|
||||
if (isa<TensorType>(type)) {
|
||||
elemTy = type.cast<TensorType>().getElementType();
|
||||
}
|
||||
if (BaseMemRefType::isValidElementType(elemTy))
|
||||
return type;
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
|
||||
|
||||
ConversionTarget target(*context);
|
||||
|
||||
// Structural operations.
|
||||
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
|
||||
opHasLegalTypes);
|
||||
// Basic scalar operations.
|
||||
target.addLegalDialect<mhlo::MhloDialect>();
|
||||
target.addLegalDialect<chlo::ChloDialect>();
|
||||
target.addLegalDialect<tensor::TensorDialect>();
|
||||
target.addLegalDialect<arith::ArithmeticDialect>();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
||||
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
|
||||
// doesn't unnecessarily spew out the entire module.
|
||||
emitError(module.getLoc())
|
||||
<< "Module does not conform to the MHLO backend contract. "
|
||||
"See dialect conversion legality information above.";
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() {
|
||||
return std::make_unique<VerifyMhloBackendContractPass>();
|
||||
}
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
|
@ -314,7 +314,10 @@ gentbl_cc_library(
|
|||
strip_include_prefix = "include",
|
||||
tbl_outs = [
|
||||
(
|
||||
["-gen-pass-decls"],
|
||||
[
|
||||
"-gen-pass-decls",
|
||||
"-DTORCH_MLIR_ENABLE_MHLO",
|
||||
],
|
||||
"include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc",
|
||||
),
|
||||
],
|
||||
|
@ -502,6 +505,7 @@ cc_library(
|
|||
"lib/Dialect/TorchConversion/Transforms/Passes.cpp",
|
||||
"lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp",
|
||||
"lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp",
|
||||
"lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp",
|
||||
],
|
||||
hdrs = [
|
||||
"include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h",
|
||||
|
|
Loading…
Reference in New Issue