From f41695360019bde71d52ca7548944d5488779e12 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 24 Nov 2022 10:03:47 +0530 Subject: [PATCH] [MLIR][TORCH] Add TorchConversionToMLProgram and MLProgramBufferize pass This commit changes the `InsertRngGlobalsPass` to `TorchConversionToMLProgram` pass. This commit also adds the `MLProgramBufferize` pass for the bufferization of ml_program dialect ops to run on refbackend. Signed-Off By: Vivek Khandelwal --- include/torch-mlir/Conversion/Passes.td | 8 + .../TorchConversionToMLProgram.h | 23 +++ include/torch-mlir/RefBackend/Passes.h | 2 +- include/torch-mlir/RefBackend/Passes.td | 6 +- lib/Conversion/CMakeLists.txt | 2 + lib/Conversion/Passes.cpp | 1 + .../TorchConversionToMLProgram/CMakeLists.txt | 21 +++ .../TorchConversionToMLProgram.cpp | 125 +++++++++++++++ .../TorchConversion/Transforms/CMakeLists.txt | 1 + .../TorchConversion/Transforms/Passes.cpp | 2 + .../VerifyLinalgOnTensorsBackendContract.cpp | 3 + lib/RefBackend/RefBackend.cpp | 146 ++++++++++++------ .../linalg_on_tensors_backends/refbackend.py | 2 +- .../TorchConversionToMLProgram/basic.mlir | 19 +++ test/RefBackend/insert-rng-globals.mlir | 18 --- test/RefBackend/mlprogram-bufferize.mlir | 83 ++++++++++ 16 files changed, 394 insertions(+), 68 deletions(-) create mode 100644 include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h create mode 100644 lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt create mode 100644 lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp create mode 100644 test/Conversion/TorchConversionToMLProgram/basic.mlir delete mode 100644 test/RefBackend/insert-rng-globals.mlir create mode 100644 test/RefBackend/mlprogram-bufferize.mlir diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 28138edcb..7072b8d5f 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -125,6 +125,14 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToTMTensorPass()"; } +def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "func::FuncOp"> { + let summary = "Convert recognized TorchConversion ops to MLProgram ops"; + let description = [{ + Convert TorchConversion ops to mlprogram ops. + }]; + let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()"; +} + #ifdef TORCH_MLIR_ENABLE_MHLO def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> { let summary = "Convert Torch ops to MHLO ops"; diff --git a/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h b/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h new file mode 100644 index 000000000..79d962492 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h @@ -0,0 +1,23 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
+#define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace torch {
+std::unique_ptr<OperationPass<func::FuncOp>>
+createConvertTorchConversionToMLProgramPass();
+}
+} // namespace mlir
+
+#endif // TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
diff --git a/include/torch-mlir/RefBackend/Passes.h b/include/torch-mlir/RefBackend/Passes.h
index 0b749df5c..8f1b2b525 100644
--- a/include/torch-mlir/RefBackend/Passes.h
+++ b/include/torch-mlir/RefBackend/Passes.h
@@ -27,7 +27,7 @@ std::unique_ptr<OperationPass<ModuleOp>>
 createMungeCallingConventionsPass();
 
 std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();
 
-std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();
 
 std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();
diff --git a/include/torch-mlir/RefBackend/Passes.td b/include/torch-mlir/RefBackend/Passes.td
index 518bc62f0..12d182e49 100644
--- a/include/torch-mlir/RefBackend/Passes.td
+++ b/include/torch-mlir/RefBackend/Passes.td
@@ -18,9 +18,9 @@ def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleO
   let dependentDialects = ["memref::MemRefDialect"];
 }
 
-def InsertRngGlobals: Pass<"refback-insert-rng-globals", "ModuleOp"> {
-  let summary = "Insert global variables and sequence to get the next global seed for RNG ops";
-  let constructor = "mlir::torch::RefBackend::createInsertRngGlobalsPass();";
+def MLProgramBufferize: Pass<"refback-mlprogram-bufferize", "ModuleOp"> {
+  let summary = "Bufferize the MLProgram dialect ops";
+  let constructor = "mlir::torch::RefBackend::createMLProgramBufferizePass();";
   let dependentDialects = ["memref::MemRefDialect"];
 }
 
diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt
index 29318e3b6..63a2337c8 100644
--- a/lib/Conversion/CMakeLists.txt
+++ b/lib/Conversion/CMakeLists.txt
@@ -6,6 +6,7 @@ if(TORCH_MLIR_ENABLE_MHLO)
   add_subdirectory(TorchToMhlo)
 endif()
 add_subdirectory(TorchToTMTensor)
+add_subdirectory(TorchConversionToMLProgram)
 add_subdirectory(Utils)
 
 # TODO: Automate this with add_torch_mlir_conversion_library.
@@ -14,6 +15,7 @@ set(linked_libs TorchMLIRTorchToLinalg TorchMLIRTorchToArith TorchMLIRTorchToTosa TorchMLIRTorchToTMTensor + TorchMLIRTorchConversionToMLProgram TorchMLIRConversionUtils) if(TORCH_MLIR_ENABLE_MHLO) list(APPEND linked_libs TorchMLIRTorchToMhlo) diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index e899eb210..8d2117aa4 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -19,6 +19,7 @@ #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" +#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" //===----------------------------------------------------------------------===// // Pass registration diff --git a/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt b/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt new file mode 100644 index 000000000..f819ad018 --- /dev/null +++ b/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt @@ -0,0 +1,21 @@ +add_mlir_conversion_library(TorchMLIRTorchConversionToMLProgram + TorchConversionToMLProgram.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchConversionToMLProgram + + DEPENDS + TorchMLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRLinalgDialect + MLIRMathDialect + TorchMLIRTorchDialect +) + +torch_mlir_target_includes(TorchMLIRTorchConversionToMLProgram) diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp new file mode 100644 index 000000000..60c126b06 --- /dev/null +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -0,0 +1,125 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" + +#include "../PassDetail.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; +using namespace mlir::torch::TorchConversion; + +static constexpr StringRef getSeedGobalVarName() { return "global_seed"; } + +// Declare a tensor global variable for the seed. 
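+// For a 0-d i64 seed this is expected to print roughly as
+//   ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
+// (see test/Conversion/TorchConversionToMLProgram/basic.mlir).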
+static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
+  b.setInsertionPointToStart(module.getBody());
+  Type elemTy = b.getI64Type();
+  auto tensorType = RankedTensorType::get({}, elemTy);
+  b.create<ml_program::GlobalOp>(
+      UnknownLoc::get(b.getContext()),
+      /*sym_name=*/getSeedGobalVarName(),
+      /*type=*/tensorType,
+      /*is_mutable=*/true,
+      /*value=*/DenseIntElementsAttr::get(tensorType, {APInt(64, 0)}),
+      /*sym_visibility=*/b.getStringAttr("private"));
+}
+
+namespace {
+class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(GetNextSeedOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Generate sequence for getting the next seed with LCG step:
+    //   nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64.
+    // Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
+    // Get the current seed value.
+    auto tensorType = RankedTensorType::get({}, rewriter.getI64Type());
+    Value globalVar = rewriter.create<ml_program::GlobalLoadOp>(
+        loc, tensorType,
+        SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()));
+    Value currentSeed = rewriter.create<tensor::ExtractOp>(loc, globalVar);
+
+    // The value of multiplier and incrementStep are referenced from
+    // https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64.
+    Value multiplier = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI64IntegerAttr(6364136223846793005));
+    Value incrementStep = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI64IntegerAttr(1442695040888963407));
+    // temp = multiplier * currentSeed + incrementStep
+    Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
+    Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
+    globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar);
+    rewriter.create<ml_program::GlobalStoreOp>(
+        loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
+        globalVar);
+    rewriter.replaceOp(op, seed);
+    return success();
+  }
+};
+} // namespace
+
+// -----------------------------------------------------------------------------
+// The pass
+// -----------------------------------------------------------------------------
+
+namespace {
+class ConvertTorchConversionToMLProgram
+    : public ConvertTorchConversionToMLProgramBase<
+          ConvertTorchConversionToMLProgram> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<ml_program::MLProgramDialect>();
+    registry.insert<tensor::TensorDialect>();
+    registry.insert<arith::ArithDialect>();
+    TorchConversion::getBackendTypeConversionDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ConversionTarget target(*context);
+    target.addLegalDialect<ml_program::MLProgramDialect, tensor::TensorDialect,
+                           arith::ArithDialect>();
+
+    TypeConverter typeConverter;
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversion(target, typeConverter);
+
+    auto module = getOperation()->getParentOfType<ModuleOp>();
+    OpBuilder b(module.getBodyRegion());
+    createGlobalVariableForSeed(b, module);
+
+    RewritePatternSet patterns(context);
+    target.addIllegalOp<GetNextSeedOp>();
+    patterns.add<ConvertGetNextSeedOp>(typeConverter, context);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::createConvertTorchConversionToMLProgramPass() {
+  return std::make_unique<ConvertTorchConversionToMLProgram>();
+}
diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
index 5412c9ffa..eaa15b00e 100644
--- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
+++ 
b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ set(LinkedLibs MLIRIR TorchMLIRTorchToTMTensor TorchMLIRTorchToArith TorchMLIRTorchToSCF + TorchMLIRTorchConversionToMLProgram MLIRMemRefTransforms) if(TORCH_MLIR_ENABLE_MHLO) diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 30054c13a..ffffce244 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -20,6 +20,7 @@ #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" #ifdef TORCH_MLIR_ENABLE_MHLO #include "mhlo/transforms/passes.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" @@ -71,6 +72,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( pm.addNestedPass(createConvertTorchToLinalgPass()); pm.addNestedPass(createConvertTorchToSCFPass()); pm.addNestedPass(createConvertTorchToArithPass()); + pm.addNestedPass(createConvertTorchConversionToMLProgramPass()); pm.addNestedPass(memref::createExpandOpsPass()); // Clean up any non-canonical code introduced above.. diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp index cae5586b8..00117d895 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -80,6 +81,8 @@ class VerifyLinalgOnTensorsBackendContractPass target.addDynamicallyLegalDialect(opHasLegalTypes); target.addDynamicallyLegalDialect(opHasLegalTypes); target.addDynamicallyLegalDialect(opHasLegalTypes); + target.addDynamicallyLegalDialect( + opHasLegalTypes); // ConstantOp is used for tensors and for scalars. 
target.addDynamicallyLegalOp(opHasLegalTypes); diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index bcd8d6030..9ea0fdecf 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -16,12 +16,14 @@ #include "PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Approximation.h" #include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" @@ -67,7 +69,6 @@ static bool isArgMemRefTypeValid(Type type) { return true; if (integerTy.isSignlessInteger(1)) return true; - } } return false; @@ -219,74 +220,129 @@ mlir::torch::RefBackend::createMungeCallingConventionsPass() { } //===----------------------------------------------------------------------===// -// InsertRngGlobals +// MLProgramBufferize //===----------------------------------------------------------------------===// -static constexpr StringRef getSeedGobalVarName() { return "global_seed"; } +static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp, + OpBuilder &b) { + if (!globalOp.getValue().has_value()) + return globalOp.emitError("global op must have a value"); -// Declare a memref global variable for the seed. -static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) { - b.setInsertionPointToStart(module.getBody()); - Type elemTy = b.getI64Type(); - auto memref0D = MemRefType::get({}, elemTy); - auto tensor0D = RankedTensorType::get({}, elemTy); + RankedTensorType tensorType = globalOp.getType().cast(); + MemRefType memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + b.setInsertionPointToStart(globalOp->getParentOfType().getBody()); b.create( - UnknownLoc::get(b.getContext()), getSeedGobalVarName(), - /*sym_visibility=*/b.getStringAttr("private"), - /*type=*/memref0D, - /*initial_value=*/DenseIntElementsAttr::get(tensor0D, {APInt(64, 0)}), - /*constant=*/false, + UnknownLoc::get(b.getContext()), globalOp.getSymName(), + /*sym_visibility=*/globalOp.getSymVisibilityAttr(), + /*type=*/memrefType, + /*initial_value=*/globalOp.getValue().value(), + /*constant=*/globalOp.getIsMutable() ? false : true, /*alignment=*/nullptr); + return success(); } -// Generate sequence for getting the next seed with LCG step: -// nextSeed = (multiplier * currentSeed + incrementStep) mod 64. -// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator. -static Value lowerGetNextSeed(OpBuilder &b, Location loc) { - // Get the current seed value. - auto memref1DType = MemRefType::get({}, b.getI64Type()); - Value globalVar = - b.create(loc, memref1DType, getSeedGobalVarName()); - Value currentSeed = b.create(loc, globalVar); +static LogicalResult +bufferizeMLProgramGlobaLoadOp(ml_program::GlobalLoadOp globalLoadOp, + OpBuilder &b, SmallVector &toErase) { + RankedTensorType tensorType = globalLoadOp.getType().cast(); + MemRefType memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - // The value of multiplier and incrementStep are referenced from - // https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64. 
- Value multiplier = b.create( - loc, b.getI64IntegerAttr(6364136223846793005)); - Value incrementStep = b.create( - loc, b.getI64IntegerAttr(1442695040888963407)); - // temp = multiplier * currentSeed + incrementStep - Value mul = b.create(loc, currentSeed, multiplier); - Value nextSeed = b.create(loc, mul, incrementStep); - b.create(loc, nextSeed, globalVar); - return nextSeed; + b.setInsertionPoint(globalLoadOp); + Value globalVal = b.create( + globalLoadOp.getLoc(), memrefType, + globalLoadOp.getGlobalAttr().getLeafReference()); + globalVal = b.create(globalLoadOp->getLoc(), + tensorType, globalVal); + globalLoadOp->getResult(0).replaceAllUsesWith(globalVal); + return success(); +} + +static LogicalResult +bufferizeMLProgramGlobaStoreOp(ml_program::GlobalStoreOp globalStoreOp, + OpBuilder &b, + SmallVector &toErase) { + RankedTensorType tensorType = + globalStoreOp.getValue().getType().cast(); + MemRefType memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + + b.setInsertionPoint(globalStoreOp); + Value memref = b.create( + globalStoreOp.getLoc(), memrefType, + globalStoreOp.getGlobalAttr().getLeafReference()); + Value copyValue = b.create( + globalStoreOp->getLoc(), memrefType, globalStoreOp.getValue()); + b.create(globalStoreOp->getLoc(), copyValue, memref); + return success(); } -// The global seed is stored into a memref global variable as the only -// element. namespace { -class InsertRngGlobals : public InsertRngGlobalsBase { +/// Converts MLProgram operations that work on tensor-type operands or results +/// to work on buffers. +class MLProgramBufferize : public MLProgramBufferizeBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + void runOnOperation() override { auto module = getOperation(); OpBuilder b(module.getBodyRegion()); - createGlobalVariableForSeed(b, module); SmallVector toErase; - module.walk([&](TorchConversion::GetNextSeedOp op) { - b.setInsertionPoint(op); - Value seed = lowerGetNextSeed(b, op.getLoc()); - op.replaceAllUsesWith(seed); + + auto walkResult = module.walk([&](ml_program::GlobalOp op) { + if (auto type = op.getType().dyn_cast()) { + if (!type.hasStaticShape()) { + // If the ml_program.global has dynamically shaped tensor. + op.emitError( + "unimplemented: global op bufferization with dynamic shape"); + return WalkResult::interrupt(); + } + } else { + // If the ml_program.global is of non-tensor type. 
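+        // This covers globals such as
+        //   ml_program.global private mutable @global_seed(0 : i64) : i64
+        // (see test/RefBackend/mlprogram-bufferize.mlir), which this
+        // bufferization does not handle.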
+ op.emitError("unsupported global op type"); + return WalkResult::interrupt(); + } + + if (failed(bufferizeMLProgramGlobalOp(op, b))) { + op.emitError("bufferization for this op failed"); + return WalkResult::interrupt(); + } + toErase.push_back(op); + return WalkResult::advance(); + }); + + if (walkResult.wasInterrupted()) + return signalPassFailure(); + + module.walk([&](ml_program::GlobalLoadOp op) { + if (failed(bufferizeMLProgramGlobaLoadOp(op, b, toErase))) { + op.emitError("bufferization for this op failed"); + return; + } toErase.push_back(op); }); - for (auto op : toErase) + module.walk([&](ml_program::GlobalStoreOp op) { + if (failed(bufferizeMLProgramGlobaStoreOp(op, b, toErase))) { + op.emitError("bufferization for this op failed"); + return; + } + toErase.push_back(op); + }); + + for (auto op : llvm::reverse(toErase)) op->erase(); } }; } // namespace std::unique_ptr> -mlir::torch::RefBackend::createInsertRngGlobalsPass() { - return std::make_unique(); +mlir::torch::RefBackend::createMLProgramBufferizePass() { + return std::make_unique(); } //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index d3c5788f0..fa40a3d7e 100644 --- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -123,6 +123,7 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([ "func.func(linalg-bufferize)", "func-bufferize", "arith-bufferize", + "refback-mlprogram-bufferize", "func.func(tensor-bufferize)", "func.func(finalizing-bufferize)", # Munge to make it ExecutionEngine compatible. @@ -134,7 +135,6 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([ "refback-munge-calling-conventions", # Insert global variable and instruction sequence for getting the next # global seed used in stateful rng. 
- "refback-insert-rng-globals", # Lower to LLVM "func.func(tm-tensor-to-loops)", "func.func(refback-munge-memref-copy)", diff --git a/test/Conversion/TorchConversionToMLProgram/basic.mlir b/test/Conversion/TorchConversionToMLProgram/basic.mlir new file mode 100644 index 000000000..cc58ad3ac --- /dev/null +++ b/test/Conversion/TorchConversionToMLProgram/basic.mlir @@ -0,0 +1,19 @@ +// RUN: torch-mlir-opt %s -convert-torch-conversion-to-mlprogram -split-input-file | FileCheck %s + +// CHECK-LABEL: ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor +// CHECK-LABEL: func.func @f() -> i64 { +// CHECK: %[[GLOBAL:.*]] = ml_program.global_load @global_seed : tensor +// CHECK: %[[SEED:.*]] = tensor.extract %[[GLOBAL]][] : tensor +// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64 +// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64 +// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64 +// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64 +// CHECK: %[[INSERTED:.*]] = tensor.insert %[[NEXT_SEED]] into %[[GLOBAL]][] : tensor +// CHECK: ml_program.global_store @global_seed = %[[INSERTED]] : tensor +// CHECK: return %2 : i64 +module { + func.func @f() -> i64 { + %seed = torch_c.get_next_seed : () -> i64 + return %seed : i64 + } +} diff --git a/test/RefBackend/insert-rng-globals.mlir b/test/RefBackend/insert-rng-globals.mlir deleted file mode 100644 index 51d836ee0..000000000 --- a/test/RefBackend/insert-rng-globals.mlir +++ /dev/null @@ -1,18 +0,0 @@ -// RUN: torch-mlir-opt %s -refback-insert-rng-globals -split-input-file | FileCheck %s - -// CHECK-LABEL: memref.global "private" @global_seed : memref = dense<0> -// CHECK-LABEL: func.func @f() -> i64 { -// CHECK: %[[MEMREF:.*]] = memref.get_global @global_seed : memref -// CHECK: %[[SEED:.*]] = memref.load %[[MEMREF]][] : memref -// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64 -// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64 -// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64 -// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64 -// CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref -// CHECK: return %[[NEXT_SEED]] : i64 -module { - func.func @f() -> i64 { - %seed = torch_c.get_next_seed : () -> i64 - return %seed : i64 - } -} diff --git a/test/RefBackend/mlprogram-bufferize.mlir b/test/RefBackend/mlprogram-bufferize.mlir new file mode 100644 index 000000000..bd8c2a6c0 --- /dev/null +++ b/test/RefBackend/mlprogram-bufferize.mlir @@ -0,0 +1,83 @@ +// RUN: torch-mlir-opt %s -refback-mlprogram-bufferize -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: memref.global "private" @global_seed : memref = dense<0> +// CHECK-LABEL: func.func @forward() -> i64 { +// CHECK: %[[CST127:.*]] = arith.constant 127 : i64 +// CHECK: %[[GLOBAL_SEED:.*]] = memref.get_global @global_seed : memref +// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[GLOBAL_SEED]] : memref +// CHECK: %[[SEED:.*]] = tensor.extract %[[TENSOR]][] : tensor +// CHECK: %[[NEXT_SEED:.*]] = arith.muli %[[SEED]], %[[CST127]] : i64 +// CHECK: %[[INSERTED:.*]] = tensor.insert %[[NEXT_SEED]] into %[[TENSOR]][] : tensor +// CHECK: %[[GLOBAL_SEED_1:.*]] = memref.get_global @global_seed : memref +// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[INSERTED]] : memref +// CHECK: memref.copy %[[MEMREF]], %[[GLOBAL_SEED_1]] : memref to memref +// CHECK: return %[[NEXT_SEED]] : i64 +module { + ml_program.global private 
+  func.func @forward() -> i64 {
+    %c127_i64 = arith.constant 127 : i64
+    %0 = ml_program.global_load @global_seed : tensor<i64>
+    %extracted = tensor.extract %0[] : tensor<i64>
+    %1 = arith.muli %extracted, %c127_i64 : i64
+    %inserted = tensor.insert %1 into %0[] : tensor<i64>
+    ml_program.global_store @global_seed = %inserted : tensor<i64>
+    return %1 : i64
+  }
+}
+
+// -----
+
+module {
+  // expected-error @below {{unsupported global op type}}
+  ml_program.global private mutable @global_seed(0 : i64) : i64
+  func.func @forward() -> i64 {
+    %c127_i64 = arith.constant 127 : i64
+    %0 = ml_program.global_load @global_seed : i64
+    %1 = arith.muli %0, %c127_i64 : i64
+    ml_program.global_store @global_seed = %1 : i64
+    return %1 : i64
+  }
+}
+
+// -----
+
+module {
+  // expected-error @below {{unsupported global op type}}
+  ml_program.global private mutable @global_seed(dense<0> : memref<i64>) : memref<i64>
+  func.func @forward() -> i64 {
+    %c127_i64 = arith.constant 127 : i64
+    %0 = ml_program.global_load @global_seed : memref<i64>
+    %extracted = memref.load %0[] : memref<i64>
+    %1 = arith.muli %extracted, %c127_i64 : i64
+    memref.store %1, %0[] : memref<i64>
+    ml_program.global_store @global_seed = %0 : memref<i64>
+    return %1 : i64
+  }
+}
+
+// -----
+
+module {
+  // expected-error @below {{invalid tensor element type}}
+  ml_program.global private mutable @global_seed(dense<0> : tensor>) : tensor>
+  func.func @forward() -> i64 {
+    %c127_i64 = arith.constant 127 : i64
+    return %c127_i64 : i64
+  }
+}
+
+// -----
+module {
+  // expected-error @below {{unimplemented: global op bufferization with dynamic shape}}
+  ml_program.global private mutable @global_seed(dense<0> : tensor<1xi64>) : tensor<?xi64>
+  func.func @forward() -> i64 {
+    %c127_i64 = arith.constant 127 : i64
+    %c0 = arith.constant 0 : index
+    %0 = ml_program.global_load @global_seed : tensor<?xi64>
+    %extracted = tensor.extract %0[%c0] : tensor<?xi64>
+    %1 = arith.muli %extracted, %c127_i64 : i64
+    %inserted = tensor.insert %1 into %0[%c0] : tensor<?xi64>
+    ml_program.global_store @global_seed = %inserted : tensor<?xi64>
+    return %1 : i64
+  }
+}
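Taken together, the two new passes lower the RefBackend RNG seed as follows. A minimal sketch assembled from the RUN lines and CHECK patterns above (SSA value names are illustrative; the exact printed form is whatever the passes emit):

// Input: a function produced by the Torch backend pipeline.
func.func @f() -> i64 {
  %seed = torch_c.get_next_seed : () -> i64
  return %seed : i64
}

// After -convert-torch-conversion-to-mlprogram: the seed lives in an
// ml_program.global and is advanced with the LCG step.
ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
func.func @f() -> i64 {
  %0 = ml_program.global_load @global_seed : tensor<i64>
  %1 = tensor.extract %0[] : tensor<i64>
  %c_mul = arith.constant 6364136223846793005 : i64
  %c_inc = arith.constant 1442695040888963407 : i64
  %2 = arith.muli %1, %c_mul : i64
  %3 = arith.addi %2, %c_inc : i64
  %4 = tensor.insert %3 into %0[] : tensor<i64>
  ml_program.global_store @global_seed = %4 : tensor<i64>
  return %3 : i64
}

// After -refback-mlprogram-bufferize: the global becomes a memref.global, and the
// load/store pair becomes memref.get_global plus bufferization.to_tensor /
// bufferization.to_memref with a memref.copy back into the global.
memref.global "private" @global_seed : memref<i64> = dense<0>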