mirror of https://github.com/llvm/torch-mlir
Minor fixes for `ConvertTorchConversionToMLProgram`. (#1991)
* Only create the global seed variable if it does not exist already. * Make the pass a module pass. A func pass may not modify its parent op.pull/2001/head
parent
e7d4771403
commit
d24fa71368
|
@ -125,7 +125,7 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
|
|||
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
|
||||
}
|
||||
|
||||
def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "func::FuncOp"> {
|
||||
def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> {
|
||||
let summary = "Convert recognized TorchConversion ops to MLProgram ops";
|
||||
let description = [{
|
||||
Convert TorchConversion ops to mlprogram ops.
|
||||
|
|
|
@ -10,12 +10,12 @@
|
|||
#ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTorchConversionToMLProgramPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#define TORCHMLIR_CONVERSION_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -28,10 +28,22 @@ using namespace mlir::torch::TorchConversion;
|
|||
static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }
|
||||
|
||||
// Declare a tensor<i64> global variable for the seed.
|
||||
static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
|
||||
b.setInsertionPointToStart(module.getBody());
|
||||
static LogicalResult getOrCreateGlobalVariableForSeed(OpBuilder &b,
|
||||
ModuleOp module) {
|
||||
auto globalSeedSymbol =
|
||||
SymbolTable::lookupSymbolIn(module, getSeedGobalVarName());
|
||||
|
||||
Type elemTy = b.getI64Type();
|
||||
auto tensorType = RankedTensorType::get({}, elemTy);
|
||||
|
||||
if (globalSeedSymbol) {
|
||||
auto globalSeed = dyn_cast<ml_program::GlobalOp>(globalSeedSymbol);
|
||||
if (!globalSeed || globalSeed.getType() != tensorType)
|
||||
return module.emitError("Unexpected type for global seed.");
|
||||
return success();
|
||||
}
|
||||
|
||||
b.setInsertionPointToStart(module.getBody());
|
||||
b.create<ml_program::GlobalOp>(
|
||||
UnknownLoc::get(b.getContext()),
|
||||
/*sym_name=*/getSeedGobalVarName(),
|
||||
|
@ -39,6 +51,8 @@ static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
|
|||
/*is_mutable=*/true,
|
||||
/*value=*/DenseIntElementsAttr::get(tensorType, {APInt(64, 0)}),
|
||||
/*sym_visibility=*/b.getStringAttr("private"));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -104,22 +118,27 @@ public:
|
|||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
auto module = getOperation()->getParentOfType<ModuleOp>();
|
||||
auto module = getOperation();
|
||||
OpBuilder b(module.getBodyRegion());
|
||||
createGlobalVariableForSeed(b, module);
|
||||
if (failed(getOrCreateGlobalVariableForSeed(b, module)))
|
||||
signalPassFailure();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
target.addIllegalOp<GetNextSeedOp>();
|
||||
patterns.add<ConvertGetNextSeedOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
||||
|
||||
getOperation()->walk(
|
||||
[this, &target, &frozenPatterns](func::FuncOp function) {
|
||||
if (failed(applyPartialConversion(function, target, frozenPatterns)))
|
||||
return signalPassFailure();
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::torch::createConvertTorchConversionToMLProgramPass() {
|
||||
return std::make_unique<ConvertTorchConversionToMLProgram>();
|
||||
}
|
||||
|
|
|
@ -72,7 +72,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
|||
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchConversionToMLProgramPass());
|
||||
pm.addPass(createConvertTorchConversionToMLProgramPass());
|
||||
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
|
||||
|
||||
// Clean up any non-canonical code introduced above..
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
// RUN: torch-mlir-opt %s -convert-torch-conversion-to-mlprogram -split-input-file | FileCheck %s
|
||||
|
||||
module {
|
||||
func.func private @f0() -> i64
|
||||
func.func private @f1() -> i64
|
||||
func.func private @f2() -> i64
|
||||
func.func private @f3() -> i64
|
||||
func.func private @f4() -> i64
|
||||
func.func private @f5() -> i64
|
||||
func.func private @f6() -> i64
|
||||
func.func private @f7() -> i64
|
||||
}
|
||||
|
||||
// CHECK: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
|
||||
// CHECK-NOT: @global_seed
|
Loading…
Reference in New Issue