diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index b5f30bfbe..3a130f472 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -125,7 +125,7 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToTMTensorPass()"; } -def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "func::FuncOp"> { +def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> { let summary = "Convert recognized TorchConversion ops to MLProgram ops"; let description = [{ Convert TorchConversion ops to mlprogram ops. diff --git a/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h b/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h index 79d962492..6d14ec927 100644 --- a/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h +++ b/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h @@ -10,12 +10,12 @@ #ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H #define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" namespace mlir { namespace torch { -std::unique_ptr> +std::unique_ptr> createConvertTorchConversionToMLProgramPass(); } } // namespace mlir diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h index 2e98b37d4..aa832141f 100644 --- a/lib/Conversion/PassDetail.h +++ b/lib/Conversion/PassDetail.h @@ -11,6 +11,7 @@ #define TORCHMLIR_CONVERSION_PASSDETAIL_H #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp index 839bae364..eab81c2be 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -28,10 +28,22 @@ using namespace mlir::torch::TorchConversion; static constexpr StringRef getSeedGobalVarName() { return "global_seed"; } // Declare a tensor global variable for the seed. -static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) { - b.setInsertionPointToStart(module.getBody()); +static LogicalResult getOrCreateGlobalVariableForSeed(OpBuilder &b, + ModuleOp module) { + auto globalSeedSymbol = + SymbolTable::lookupSymbolIn(module, getSeedGobalVarName()); + Type elemTy = b.getI64Type(); auto tensorType = RankedTensorType::get({}, elemTy); + + if (globalSeedSymbol) { + auto globalSeed = dyn_cast(globalSeedSymbol); + if (!globalSeed || globalSeed.getType() != tensorType) + return module.emitError("Unexpected type for global seed."); + return success(); + } + + b.setInsertionPointToStart(module.getBody()); b.create( UnknownLoc::get(b.getContext()), /*sym_name=*/getSeedGobalVarName(), @@ -39,6 +51,8 @@ static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) { /*is_mutable=*/true, /*value=*/DenseIntElementsAttr::get(tensorType, {APInt(64, 0)}), /*sym_visibility=*/b.getStringAttr("private")); + + return success(); } namespace { @@ -104,22 +118,27 @@ public: typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - auto module = getOperation()->getParentOfType(); + auto module = getOperation(); OpBuilder b(module.getBodyRegion()); - createGlobalVariableForSeed(b, module); + if (failed(getOrCreateGlobalVariableForSeed(b, module))) + signalPassFailure(); RewritePatternSet patterns(context); target.addIllegalOp(); patterns.add(typeConverter, context); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) - return signalPassFailure(); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + getOperation()->walk( + [this, &target, &frozenPatterns](func::FuncOp function) { + if (failed(applyPartialConversion(function, target, frozenPatterns))) + return signalPassFailure(); + }); } }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::createConvertTorchConversionToMLProgramPass() { return std::make_unique(); } diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 14d8f360b..51d917329 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -72,7 +72,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( pm.addNestedPass(createConvertTorchToLinalgPass()); pm.addNestedPass(createConvertTorchToSCFPass()); pm.addNestedPass(createConvertTorchToArithPass()); - pm.addNestedPass(createConvertTorchConversionToMLProgramPass()); + pm.addPass(createConvertTorchConversionToMLProgramPass()); pm.addNestedPass(memref::createExpandOpsPass()); // Clean up any non-canonical code introduced above.. diff --git a/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir b/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir new file mode 100644 index 000000000..8ef04d951 --- /dev/null +++ b/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir @@ -0,0 +1,15 @@ +// RUN: torch-mlir-opt %s -convert-torch-conversion-to-mlprogram -split-input-file | FileCheck %s + +module { + func.func private @f0() -> i64 + func.func private @f1() -> i64 + func.func private @f2() -> i64 + func.func private @f3() -> i64 + func.func private @f4() -> i64 + func.func private @f5() -> i64 + func.func private @f6() -> i64 + func.func private @f7() -> i64 +} + +// CHECK: ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor +// CHECK-NOT: @global_seed