[MHLO] add new options to pipeline (#1331)

pull/1363/head
Tanyo Kwok 2022-09-13 01:27:41 +08:00 committed by GitHub
parent 71b1f0dd7a
commit 7f63a17a46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 5 deletions

View File

@ -36,9 +36,22 @@ void createTorchBackendToTosaBackendPipeline(
// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
#ifdef TORCH_MLIR_ENABLE_MHLO
struct MhloBackendPipelineOptions
: public PassPipelineOptions<MhloBackendPipelineOptions> {
Option<bool> enableStaticShape{
*this, "enable-static-shape",
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
// are unlikely to exceed the range of i32(4GiB)
Option<bool> enableI32Index{
*this, "enable-i32-index",
llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"),
llvm::cl::init(false)};
};
void createTorchBackendToMhloBackendPipeline(
OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options);
OpPassManager &pm, const MhloBackendPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass();
#endif

View File

@ -53,7 +53,7 @@ void mlir::torch::registerTorchConversionPasses() {
"contract.",
TorchConversion::createTorchBackendToTosaBackendPipeline);
#ifdef TORCH_MLIR_ENABLE_MHLO
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
mlir::PassPipelineRegistration<TorchConversion::MhloBackendPipelineOptions>(
"torch-backend-to-mhlo-backend-pipeline",
"Pipeline lowering torch backend contract to MHLO backend "
"contract.",
@ -121,8 +121,10 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
#ifdef TORCH_MLIR_ENABLE_MHLO
void TorchConversion::createTorchBackendToMhloBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
OpPassManager &pm,
const TorchConversion::MhloBackendPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass(
options.enableStaticShape, options.enableI32Index));
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());