[MLIR][TORCH] Add TorchConversionToMLProgram and MLProgramBufferize pass

This commit changes the `InsertRngGlobalsPass` to `TorchConversionToMLProgram`
pass. This commit also adds the `MLProgramBufferize` pass for the
bufferization of ml_program dialect ops to run on refbackend.

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/1672/head snapshot-20221202.675
Vivek Khandelwal 2022-11-24 10:03:47 +05:30
parent 3fc27cf6ca
commit f416953600
16 changed files with 394 additions and 68 deletions

View File

@ -125,6 +125,14 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()"; let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
} }
def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "func::FuncOp"> {
let summary = "Convert recognized TorchConversion ops to MLProgram ops";
let description = [{
Convert TorchConversion ops to mlprogram ops.
}];
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
}
#ifdef TORCH_MLIR_ENABLE_MHLO #ifdef TORCH_MLIR_ENABLE_MHLO
def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> { def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
let summary = "Convert Torch ops to MHLO ops"; let summary = "Convert Torch ops to MHLO ops";

View File

@ -0,0 +1,23 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
#define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchConversionToMLProgramPass();
}
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H

View File

@ -27,7 +27,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass(); std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass(); std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();
std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass(); std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();

View File

@ -18,9 +18,9 @@ def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleO
let dependentDialects = ["memref::MemRefDialect"]; let dependentDialects = ["memref::MemRefDialect"];
} }
def InsertRngGlobals: Pass<"refback-insert-rng-globals", "ModuleOp"> { def MLProgramBufferize: Pass<"refback-mlprogram-bufferize", "ModuleOp"> {
let summary = "Insert global variables and sequence to get the next global seed for RNG ops"; let summary = "Bufferize the MLProgram dialect ops";
let constructor = "mlir::torch::RefBackend::createInsertRngGlobalsPass();"; let constructor = "mlir::torch::RefBackend::createMLProgramBufferizePass();";
let dependentDialects = ["memref::MemRefDialect"]; let dependentDialects = ["memref::MemRefDialect"];
} }

View File

@ -6,6 +6,7 @@ if(TORCH_MLIR_ENABLE_MHLO)
add_subdirectory(TorchToMhlo) add_subdirectory(TorchToMhlo)
endif() endif()
add_subdirectory(TorchToTMTensor) add_subdirectory(TorchToTMTensor)
add_subdirectory(TorchConversionToMLProgram)
add_subdirectory(Utils) add_subdirectory(Utils)
# TODO: Automate this with add_torch_mlir_conversion_library. # TODO: Automate this with add_torch_mlir_conversion_library.
@ -14,6 +15,7 @@ set(linked_libs TorchMLIRTorchToLinalg
TorchMLIRTorchToArith TorchMLIRTorchToArith
TorchMLIRTorchToTosa TorchMLIRTorchToTosa
TorchMLIRTorchToTMTensor TorchMLIRTorchToTMTensor
TorchMLIRTorchConversionToMLProgram
TorchMLIRConversionUtils) TorchMLIRConversionUtils)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_MHLO)
list(APPEND linked_libs TorchMLIRTorchToMhlo) list(APPEND linked_libs TorchMLIRTorchToMhlo)

View File

@ -19,6 +19,7 @@
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pass registration // Pass registration

View File

@ -0,0 +1,21 @@
add_mlir_conversion_library(TorchMLIRTorchConversionToMLProgram
TorchConversionToMLProgram.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchConversionToMLProgram
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRLinalgDialect
MLIRMathDialect
TorchMLIRTorchDialect
)
torch_mlir_target_includes(TorchMLIRTorchConversionToMLProgram)

View File

@ -0,0 +1,125 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }
// Declare a tensor<i64> global variable for the seed.
static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
b.setInsertionPointToStart(module.getBody());
Type elemTy = b.getI64Type();
auto tensorType = RankedTensorType::get({}, elemTy);
b.create<ml_program::GlobalOp>(
UnknownLoc::get(b.getContext()),
/*sym_name=*/getSeedGobalVarName(),
/*type=*/tensorType,
/*is_mutable=*/true,
/*value=*/DenseIntElementsAttr::get(tensorType, {APInt(64, 0)}),
/*sym_visibility=*/b.getStringAttr("private"));
}
namespace {
class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(GetNextSeedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// Generate sequence for getting the next seed with LCG step:
// nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64.
// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
// Get the current seed value.
auto tensorType = RankedTensorType::get({}, rewriter.getI64Type());
Value globalVar = rewriter.create<ml_program::GlobalLoadOp>(
loc, tensorType,
SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()));
Value currentSeed = rewriter.create<tensor::ExtractOp>(loc, globalVar);
// The value of multiplier and incrementStep are referenced from
// https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64.
Value multiplier = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(6364136223846793005));
Value incrementStep = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(1442695040888963407));
// temp = multiplier * currentSeed + incrementStep
Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar);
rewriter.create<ml_program::GlobalStoreOp>(
loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
globalVar);
rewriter.replaceOp(op, seed);
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
namespace {
class ConvertTorchConversionToMLProgram
: public ConvertTorchConversionToMLProgramBase<
ConvertTorchConversionToMLProgram> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithDialect>();
registry.insert<ml_program::MLProgramDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<tensor::TensorDialect, arith::ArithDialect,
ml_program::MLProgramDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
auto module = getOperation()->getParentOfType<ModuleOp>();
OpBuilder b(module.getBodyRegion());
createGlobalVariableForSeed(b, module);
RewritePatternSet patterns(context);
target.addIllegalOp<GetNextSeedOp>();
patterns.add<ConvertGetNextSeedOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchConversionToMLProgramPass() {
return std::make_unique<ConvertTorchConversionToMLProgram>();
}

View File

@ -8,6 +8,7 @@ set(LinkedLibs MLIRIR
TorchMLIRTorchToTMTensor TorchMLIRTorchToTMTensor
TorchMLIRTorchToArith TorchMLIRTorchToArith
TorchMLIRTorchToSCF TorchMLIRTorchToSCF
TorchMLIRTorchConversionToMLProgram
MLIRMemRefTransforms) MLIRMemRefTransforms)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_MHLO)

View File

@ -20,6 +20,7 @@
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#ifdef TORCH_MLIR_ENABLE_MHLO #ifdef TORCH_MLIR_ENABLE_MHLO
#include "mhlo/transforms/passes.h" #include "mhlo/transforms/passes.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
@ -71,6 +72,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchConversionToMLProgramPass());
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass()); pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
// Clean up any non-canonical code introduced above.. // Clean up any non-canonical code introduced above..

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -80,6 +81,8 @@ class VerifyLinalgOnTensorsBackendContractPass
target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(opHasLegalTypes); target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<TMTensorDialect>(opHasLegalTypes); target.addDynamicallyLegalDialect<TMTensorDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes); target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
opHasLegalTypes);
// ConstantOp is used for tensors and for scalars. // ConstantOp is used for tensors and for scalars.
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes); target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);

View File

@ -16,12 +16,14 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h" #include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
@ -67,7 +69,6 @@ static bool isArgMemRefTypeValid(Type type) {
return true; return true;
if (integerTy.isSignlessInteger(1)) if (integerTy.isSignlessInteger(1))
return true; return true;
} }
} }
return false; return false;
@ -219,74 +220,129 @@ mlir::torch::RefBackend::createMungeCallingConventionsPass() {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// InsertRngGlobals // MLProgramBufferize
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static constexpr StringRef getSeedGobalVarName() { return "global_seed"; } static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp,
OpBuilder &b) {
if (!globalOp.getValue().has_value())
return globalOp.emitError("global op must have a value");
// Declare a memref<i64> global variable for the seed. RankedTensorType tensorType = globalOp.getType().cast<RankedTensorType>();
static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) { MemRefType memrefType =
b.setInsertionPointToStart(module.getBody()); MemRefType::get(tensorType.getShape(), tensorType.getElementType());
Type elemTy = b.getI64Type();
auto memref0D = MemRefType::get({}, elemTy); b.setInsertionPointToStart(globalOp->getParentOfType<ModuleOp>().getBody());
auto tensor0D = RankedTensorType::get({}, elemTy);
b.create<memref::GlobalOp>( b.create<memref::GlobalOp>(
UnknownLoc::get(b.getContext()), getSeedGobalVarName(), UnknownLoc::get(b.getContext()), globalOp.getSymName(),
/*sym_visibility=*/b.getStringAttr("private"), /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
/*type=*/memref0D, /*type=*/memrefType,
/*initial_value=*/DenseIntElementsAttr::get(tensor0D, {APInt(64, 0)}), /*initial_value=*/globalOp.getValue().value(),
/*constant=*/false, /*constant=*/globalOp.getIsMutable() ? false : true,
/*alignment=*/nullptr); /*alignment=*/nullptr);
return success();
} }
// Generate sequence for getting the next seed with LCG step: static LogicalResult
// nextSeed = (multiplier * currentSeed + incrementStep) mod 64. bufferizeMLProgramGlobaLoadOp(ml_program::GlobalLoadOp globalLoadOp,
// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator. OpBuilder &b, SmallVector<Operation *> &toErase) {
static Value lowerGetNextSeed(OpBuilder &b, Location loc) { RankedTensorType tensorType = globalLoadOp.getType().cast<RankedTensorType>();
// Get the current seed value. MemRefType memrefType =
auto memref1DType = MemRefType::get({}, b.getI64Type()); MemRefType::get(tensorType.getShape(), tensorType.getElementType());
Value globalVar =
b.create<memref::GetGlobalOp>(loc, memref1DType, getSeedGobalVarName());
Value currentSeed = b.create<memref::LoadOp>(loc, globalVar);
// The value of multiplier and incrementStep are referenced from b.setInsertionPoint(globalLoadOp);
// https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64. Value globalVal = b.create<memref::GetGlobalOp>(
Value multiplier = b.create<arith::ConstantOp>( globalLoadOp.getLoc(), memrefType,
loc, b.getI64IntegerAttr(6364136223846793005)); globalLoadOp.getGlobalAttr().getLeafReference());
Value incrementStep = b.create<arith::ConstantOp>( globalVal = b.create<bufferization::ToTensorOp>(globalLoadOp->getLoc(),
loc, b.getI64IntegerAttr(1442695040888963407)); tensorType, globalVal);
// temp = multiplier * currentSeed + incrementStep globalLoadOp->getResult(0).replaceAllUsesWith(globalVal);
Value mul = b.create<arith::MulIOp>(loc, currentSeed, multiplier); return success();
Value nextSeed = b.create<arith::AddIOp>(loc, mul, incrementStep); }
b.create<memref::StoreOp>(loc, nextSeed, globalVar);
return nextSeed; static LogicalResult
bufferizeMLProgramGlobaStoreOp(ml_program::GlobalStoreOp globalStoreOp,
OpBuilder &b,
SmallVector<Operation *> &toErase) {
RankedTensorType tensorType =
globalStoreOp.getValue().getType().cast<RankedTensorType>();
MemRefType memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
b.setInsertionPoint(globalStoreOp);
Value memref = b.create<memref::GetGlobalOp>(
globalStoreOp.getLoc(), memrefType,
globalStoreOp.getGlobalAttr().getLeafReference());
Value copyValue = b.create<bufferization::ToMemrefOp>(
globalStoreOp->getLoc(), memrefType, globalStoreOp.getValue());
b.create<memref::CopyOp>(globalStoreOp->getLoc(), copyValue, memref);
return success();
} }
// The global seed is stored into a memref<i64> global variable as the only
// element.
namespace { namespace {
class InsertRngGlobals : public InsertRngGlobalsBase<InsertRngGlobals> { /// Converts MLProgram operations that work on tensor-type operands or results
/// to work on buffers.
class MLProgramBufferize : public MLProgramBufferizeBase<MLProgramBufferize> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
}
void runOnOperation() override { void runOnOperation() override {
auto module = getOperation(); auto module = getOperation();
OpBuilder b(module.getBodyRegion()); OpBuilder b(module.getBodyRegion());
createGlobalVariableForSeed(b, module);
SmallVector<Operation *> toErase; SmallVector<Operation *> toErase;
module.walk([&](TorchConversion::GetNextSeedOp op) {
b.setInsertionPoint(op); auto walkResult = module.walk([&](ml_program::GlobalOp op) {
Value seed = lowerGetNextSeed(b, op.getLoc()); if (auto type = op.getType().dyn_cast<RankedTensorType>()) {
op.replaceAllUsesWith(seed); if (!type.hasStaticShape()) {
// If the ml_program.global has dynamically shaped tensor.
op.emitError(
"unimplemented: global op bufferization with dynamic shape");
return WalkResult::interrupt();
}
} else {
// If the ml_program.global is of non-tensor type.
op.emitError("unsupported global op type");
return WalkResult::interrupt();
}
if (failed(bufferizeMLProgramGlobalOp(op, b))) {
op.emitError("bufferization for this op failed");
return WalkResult::interrupt();
}
toErase.push_back(op);
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
return signalPassFailure();
module.walk([&](ml_program::GlobalLoadOp op) {
if (failed(bufferizeMLProgramGlobaLoadOp(op, b, toErase))) {
op.emitError("bufferization for this op failed");
return;
}
toErase.push_back(op); toErase.push_back(op);
}); });
for (auto op : toErase) module.walk([&](ml_program::GlobalStoreOp op) {
if (failed(bufferizeMLProgramGlobaStoreOp(op, b, toErase))) {
op.emitError("bufferization for this op failed");
return;
}
toErase.push_back(op);
});
for (auto op : llvm::reverse(toErase))
op->erase(); op->erase();
} }
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::RefBackend::createInsertRngGlobalsPass() { mlir::torch::RefBackend::createMLProgramBufferizePass() {
return std::make_unique<InsertRngGlobals>(); return std::make_unique<MLProgramBufferize>();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -123,6 +123,7 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
"func.func(linalg-bufferize)", "func.func(linalg-bufferize)",
"func-bufferize", "func-bufferize",
"arith-bufferize", "arith-bufferize",
"refback-mlprogram-bufferize",
"func.func(tensor-bufferize)", "func.func(tensor-bufferize)",
"func.func(finalizing-bufferize)", "func.func(finalizing-bufferize)",
# Munge to make it ExecutionEngine compatible. # Munge to make it ExecutionEngine compatible.
@ -134,7 +135,6 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
"refback-munge-calling-conventions", "refback-munge-calling-conventions",
# Insert global variable and instruction sequence for getting the next # Insert global variable and instruction sequence for getting the next
# global seed used in stateful rng. # global seed used in stateful rng.
"refback-insert-rng-globals",
# Lower to LLVM # Lower to LLVM
"func.func(tm-tensor-to-loops)", "func.func(tm-tensor-to-loops)",
"func.func(refback-munge-memref-copy)", "func.func(refback-munge-memref-copy)",

View File

@ -0,0 +1,19 @@
// RUN: torch-mlir-opt %s -convert-torch-conversion-to-mlprogram -split-input-file | FileCheck %s
// CHECK-LABEL: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
// CHECK-LABEL: func.func @f() -> i64 {
// CHECK: %[[GLOBAL:.*]] = ml_program.global_load @global_seed : tensor<i64>
// CHECK: %[[SEED:.*]] = tensor.extract %[[GLOBAL]][] : tensor<i64>
// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64
// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
// CHECK: %[[INSERTED:.*]] = tensor.insert %[[NEXT_SEED]] into %[[GLOBAL]][] : tensor<i64>
// CHECK: ml_program.global_store @global_seed = %[[INSERTED]] : tensor<i64>
// CHECK: return %2 : i64
module {
func.func @f() -> i64 {
%seed = torch_c.get_next_seed : () -> i64
return %seed : i64
}
}

View File

@ -1,18 +0,0 @@
// RUN: torch-mlir-opt %s -refback-insert-rng-globals -split-input-file | FileCheck %s
// CHECK-LABEL: memref.global "private" @global_seed : memref<i64> = dense<0>
// CHECK-LABEL: func.func @f() -> i64 {
// CHECK: %[[MEMREF:.*]] = memref.get_global @global_seed : memref<i64>
// CHECK: %[[SEED:.*]] = memref.load %[[MEMREF]][] : memref<i64>
// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64
// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
// CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64>
// CHECK: return %[[NEXT_SEED]] : i64
module {
func.func @f() -> i64 {
%seed = torch_c.get_next_seed : () -> i64
return %seed : i64
}
}

View File

@ -0,0 +1,83 @@
// RUN: torch-mlir-opt %s -refback-mlprogram-bufferize -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: memref.global "private" @global_seed : memref<i64> = dense<0>
// CHECK-LABEL: func.func @forward() -> i64 {
// CHECK: %[[CST127:.*]] = arith.constant 127 : i64
// CHECK: %[[GLOBAL_SEED:.*]] = memref.get_global @global_seed : memref<i64>
// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[GLOBAL_SEED]] : memref<i64>
// CHECK: %[[SEED:.*]] = tensor.extract %[[TENSOR]][] : tensor<i64>
// CHECK: %[[NEXT_SEED:.*]] = arith.muli %[[SEED]], %[[CST127]] : i64
// CHECK: %[[INSERTED:.*]] = tensor.insert %[[NEXT_SEED]] into %[[TENSOR]][] : tensor<i64>
// CHECK: %[[GLOBAL_SEED_1:.*]] = memref.get_global @global_seed : memref<i64>
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[INSERTED]] : memref<i64>
// CHECK: memref.copy %[[MEMREF]], %[[GLOBAL_SEED_1]] : memref<i64> to memref<i64>
// CHECK: return %[[NEXT_SEED]] : i64
module {
ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
func.func @forward() -> i64 {
%c127_i64 = arith.constant 127 : i64
%0 = ml_program.global_load @global_seed : tensor<i64>
%extracted = tensor.extract %0[] : tensor<i64>
%1 = arith.muli %extracted, %c127_i64 : i64
%inserted = tensor.insert %1 into %0[] : tensor<i64>
ml_program.global_store @global_seed = %inserted : tensor<i64>
return %1 : i64
}
}
// -----
module {
// expected-error @below {{unsupported global op type}}
ml_program.global private mutable @global_seed(0 : i64) : i64
func.func @forward() -> i64 {
%c127_i64 = arith.constant 127 : i64
%0 = ml_program.global_load @global_seed : i64
%1 = arith.muli %0, %c127_i64 : i64
ml_program.global_store @global_seed = %1 : i64
return %1 : i64
}
}
// -----
module {
// expected-error @below {{unsupported global op type}}
ml_program.global private mutable @global_seed(dense<0> : memref<i64>) : memref<i64>
func.func @forward() -> i64 {
%c127_i64 = arith.constant 127 : i64
%0 = ml_program.global_load @global_seed : memref<i64>
%extracted = memref.load %0[] : memref<i64>
%1 = arith.muli %extracted, %c127_i64 : i64
memref.store %1, %0[] : memref<i64>
ml_program.global_store @global_seed = %0 : memref<i64>
return %1 : i64
}
}
// -----
module {
// expected-error @below {{invalid tensor element type}}
ml_program.global private mutable @global_seed(dense<0> : tensor<memref<i64>>) : tensor<memref<i64>>
func.func @forward() -> i64 {
%c127_i64 = arith.constant 127 : i64
return %c127_i64 : i64
}
}
// -----
module {
// expected-error @below {{unimplemented: global op bufferization with dynamic shape}}
ml_program.global private mutable @global_seed(dense<0> : tensor<1xi64>) : tensor<?xi64>
func.func @forward() -> i64 {
%c127_i64 = arith.constant 127 : i64
%c0 = arith.constant 0 : index
%0 = ml_program.global_load @global_seed : tensor<?xi64>
%extracted = tensor.extract %0[%c0] : tensor<?xi64>
%1 = arith.muli %extracted, %c127_i64 : i64
%inserted = tensor.insert %1 into %0[%c0] : tensor<?xi64>
ml_program.global_store @global_seed = %inserted : tensor<?xi64>
return %1 : i64
}
}