Minor fixes for `ConvertTorchConversionToMLProgram`. (#1991)

* Only create the global seed variable if it does not exist already.
* Make the pass a module pass. A func pass may not modify its parent op.
pull/2001/head
Alexandre Rames 2023-04-04 09:09:58 -07:00 committed by GitHub
parent e7d4771403
commit d24fa71368
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 47 additions and 12 deletions

View File

@ -125,7 +125,7 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
}
def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "func::FuncOp"> {
def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> {
let summary = "Convert recognized TorchConversion ops to MLProgram ops";
let description = [{
Convert TorchConversion ops to mlprogram ops.

View File

@ -10,12 +10,12 @@
#ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
#define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>>
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTorchConversionToMLProgramPass();
}
} // namespace mlir

View File

@ -11,6 +11,7 @@
#define TORCHMLIR_CONVERSION_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {

View File

@ -28,10 +28,22 @@ using namespace mlir::torch::TorchConversion;
static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }
// Declare a tensor<i64> global variable for the seed.
static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
b.setInsertionPointToStart(module.getBody());
static LogicalResult getOrCreateGlobalVariableForSeed(OpBuilder &b,
ModuleOp module) {
auto globalSeedSymbol =
SymbolTable::lookupSymbolIn(module, getSeedGobalVarName());
Type elemTy = b.getI64Type();
auto tensorType = RankedTensorType::get({}, elemTy);
if (globalSeedSymbol) {
auto globalSeed = dyn_cast<ml_program::GlobalOp>(globalSeedSymbol);
if (!globalSeed || globalSeed.getType() != tensorType)
return module.emitError("Unexpected type for global seed.");
return success();
}
b.setInsertionPointToStart(module.getBody());
b.create<ml_program::GlobalOp>(
UnknownLoc::get(b.getContext()),
/*sym_name=*/getSeedGobalVarName(),
@ -39,6 +51,8 @@ static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
/*is_mutable=*/true,
/*value=*/DenseIntElementsAttr::get(tensorType, {APInt(64, 0)}),
/*sym_visibility=*/b.getStringAttr("private"));
return success();
}
namespace {
@ -104,22 +118,27 @@ public:
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
auto module = getOperation()->getParentOfType<ModuleOp>();
auto module = getOperation();
OpBuilder b(module.getBodyRegion());
createGlobalVariableForSeed(b, module);
if (failed(getOrCreateGlobalVariableForSeed(b, module)))
signalPassFailure();
RewritePatternSet patterns(context);
target.addIllegalOp<GetNextSeedOp>();
patterns.add<ConvertGetNextSeedOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
getOperation()->walk(
[this, &target, &frozenPatterns](func::FuncOp function) {
if (failed(applyPartialConversion(function, target, frozenPatterns)))
return signalPassFailure();
});
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::createConvertTorchConversionToMLProgramPass() {
return std::make_unique<ConvertTorchConversionToMLProgram>();
}

View File

@ -72,7 +72,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchConversionToMLProgramPass());
pm.addPass(createConvertTorchConversionToMLProgramPass());
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
// Clean up any non-canonical code introduced above..

View File

@ -0,0 +1,15 @@
// RUN: torch-mlir-opt %s -convert-torch-conversion-to-mlprogram -split-input-file | FileCheck %s
module {
func.func private @f0() -> i64
func.func private @f1() -> i64
func.func private @f2() -> i64
func.func private @f3() -> i64
func.func private @f4() -> i64
func.func private @f5() -> i64
func.func private @f6() -> i64
func.func private @f7() -> i64
}
// CHECK: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
// CHECK-NOT: @global_seed