mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add TorchConversionToMLProgram and MLProgramBufferize pass
This commit replaces the `InsertRngGlobalsPass` with the `TorchConversionToMLProgram` pass. It also adds the `MLProgramBufferize` pass, which bufferizes ml_program dialect ops so that they can run on the refbackend. Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com> pull/1672/head snapshot-20221202.675
parent
3fc27cf6ca
commit
f416953600
|
@ -125,6 +125,14 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
|
|||
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
|
||||
}
|
||||
|
||||
// Pass record for `convert-torch-conversion-to-mlprogram`. The pass is
// anchored on `func::FuncOp` and is instantiated through
// `mlir::torch::createConvertTorchConversionToMLProgramPass()`.
def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "func::FuncOp"> {
|
||||
let summary = "Convert recognized TorchConversion ops to MLProgram ops";
|
||||
let description = [{
|
||||
Convert TorchConversion ops to mlprogram ops.
|
||||
}];
|
||||
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
|
||||
let summary = "Convert Torch ops to MHLO ops";
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
/// Creates the `convert-torch-conversion-to-mlprogram` pass, which runs on
/// `func::FuncOp` and converts recognized TorchConversion ops to MLProgram ops.
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createConvertTorchConversionToMLProgramPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
|
|
@ -27,7 +27,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
|
|||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();
|
||||
|
||||
|
|
|
@ -18,9 +18,9 @@ def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleO
|
|||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def InsertRngGlobals: Pass<"refback-insert-rng-globals", "ModuleOp"> {
|
||||
let summary = "Insert global variables and sequence to get the next global seed for RNG ops";
|
||||
let constructor = "mlir::torch::RefBackend::createInsertRngGlobalsPass();";
|
||||
def MLProgramBufferize: Pass<"refback-mlprogram-bufferize", "ModuleOp"> {
|
||||
let summary = "Bufferize the MLProgram dialect ops";
|
||||
let constructor = "mlir::torch::RefBackend::createMLProgramBufferizePass();";
|
||||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ if(TORCH_MLIR_ENABLE_MHLO)
|
|||
add_subdirectory(TorchToMhlo)
|
||||
endif()
|
||||
add_subdirectory(TorchToTMTensor)
|
||||
add_subdirectory(TorchConversionToMLProgram)
|
||||
add_subdirectory(Utils)
|
||||
|
||||
# TODO: Automate this with add_torch_mlir_conversion_library.
|
||||
|
@ -14,6 +15,7 @@ set(linked_libs TorchMLIRTorchToLinalg
|
|||
TorchMLIRTorchToArith
|
||||
TorchMLIRTorchToTosa
|
||||
TorchMLIRTorchToTMTensor
|
||||
TorchMLIRTorchConversionToMLProgram
|
||||
TorchMLIRConversionUtils)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
list(APPEND linked_libs TorchMLIRTorchToMhlo)
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass registration
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Conversion library hosting the `convert-torch-conversion-to-mlprogram` pass.
#
# TorchConversionToMLProgram.cpp includes headers from the Arith, MLProgram and
# Tensor dialects as well as the TorchConversion dialect, so their libraries
# are linked explicitly instead of being picked up transitively.
add_mlir_conversion_library(TorchMLIRTorchConversionToMLProgram
  TorchConversionToMLProgram.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchConversionToMLProgram

  DEPENDS
  TorchMLIRConversionPassIncGen

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRPass
  MLIRArithDialect
  MLIRLinalgDialect
  MLIRMathDialect
  MLIRMLProgramDialect
  MLIRTensorDialect
  TorchMLIRTorchDialect
  TorchMLIRTorchConversionDialect
)

torch_mlir_target_includes(TorchMLIRTorchConversionToMLProgram)
|
|
@ -0,0 +1,125 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::TorchConversion;
|
||||
|
||||
// Symbol name of the module-level global that holds the RNG seed; it is used
// both when declaring the `ml_program.global` and when loading/storing it.
// NOTE(review): "Gobal" is a typo for "Global"; renaming requires updating
// every caller in this file at the same time.
static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }
|
||||
|
||||
// Declare a tensor<i64> global variable for the seed.
|
||||
static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
|
||||
b.setInsertionPointToStart(module.getBody());
|
||||
Type elemTy = b.getI64Type();
|
||||
auto tensorType = RankedTensorType::get({}, elemTy);
|
||||
b.create<ml_program::GlobalOp>(
|
||||
UnknownLoc::get(b.getContext()),
|
||||
/*sym_name=*/getSeedGobalVarName(),
|
||||
/*type=*/tensorType,
|
||||
/*is_mutable=*/true,
|
||||
/*value=*/DenseIntElementsAttr::get(tensorType, {APInt(64, 0)}),
|
||||
/*sym_visibility=*/b.getStringAttr("private"));
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Rewrites `GetNextSeedOp` into a load / update / store sequence on the
// `global_seed` ml_program global. The update is one step of a linear
// congruential generator:
//   nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64
// where the modulus comes for free from i64 wrap-around arithmetic.
// See https://en.wikipedia.org/wiki/Linear_congruential_generator.
class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetNextSeedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto seedSymbol =
        SymbolRefAttr::get(op->getContext(), getSeedGobalVarName());
    auto seedTensorType = RankedTensorType::get({}, rewriter.getI64Type());

    // Read the current seed out of the rank-0 global tensor.
    Value loadedTensor = rewriter.create<ml_program::GlobalLoadOp>(
        loc, seedTensorType, seedSymbol);
    Value previousSeed = rewriter.create<tensor::ExtractOp>(loc, loadedTensor);

    // Multiplier and increment are the 2^64 parameters listed on
    // https://en.wikipedia.org/wiki/Linear_congruential_generator.
    Value lcgMultiplier = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI64IntegerAttr(6364136223846793005));
    Value lcgIncrement = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI64IntegerAttr(1442695040888963407));

    // nextSeed = multiplier * previousSeed + increment
    Value scaled =
        rewriter.create<arith::MulIOp>(loc, previousSeed, lcgMultiplier);
    Value nextSeed = rewriter.create<arith::AddIOp>(loc, scaled, lcgIncrement);

    // Write the new seed back into the global so the next call advances it.
    Value updatedTensor =
        rewriter.create<tensor::InsertOp>(loc, nextSeed, loadedTensor);
    rewriter.create<ml_program::GlobalStoreOp>(loc, seedSymbol, updatedTensor);

    rewriter.replaceOp(op, nextSeed);
    return success();
  }
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
class ConvertTorchConversionToMLProgram
|
||||
: public ConvertTorchConversionToMLProgramBase<
|
||||
ConvertTorchConversionToMLProgram> {
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
registry.insert<arith::ArithDialect>();
|
||||
registry.insert<ml_program::MLProgramDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<tensor::TensorDialect, arith::ArithDialect,
|
||||
ml_program::MLProgramDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
auto module = getOperation()->getParentOfType<ModuleOp>();
|
||||
OpBuilder b(module.getBodyRegion());
|
||||
createGlobalVariableForSeed(b, module);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
target.addIllegalOp<GetNextSeedOp>();
|
||||
patterns.add<ConvertGetNextSeedOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Factory for the pass: returns a new, heap-allocated
/// `ConvertTorchConversionToMLProgram` instance owned by the caller.
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchConversionToMLProgramPass() {
|
||||
return std::make_unique<ConvertTorchConversionToMLProgram>();
|
||||
}
|
|
@ -8,6 +8,7 @@ set(LinkedLibs MLIRIR
|
|||
TorchMLIRTorchToTMTensor
|
||||
TorchMLIRTorchToArith
|
||||
TorchMLIRTorchToSCF
|
||||
TorchMLIRTorchConversionToMLProgram
|
||||
MLIRMemRefTransforms)
|
||||
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
#include "mhlo/transforms/passes.h"
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
|
@ -71,6 +72,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
|||
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchConversionToMLProgramPass());
|
||||
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
|
||||
|
||||
// Clean up any non-canonical code introduced above.
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
@ -80,6 +81,8 @@ class VerifyLinalgOnTensorsBackendContractPass
|
|||
target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<TMTensorDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
|
||||
opHasLegalTypes);
|
||||
|
||||
// ConstantOp is used for tensors and for scalars.
|
||||
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
|
||||
|
|
|
@ -16,12 +16,14 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Approximation.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
|
@ -67,7 +69,6 @@ static bool isArgMemRefTypeValid(Type type) {
|
|||
return true;
|
||||
if (integerTy.isSignlessInteger(1))
|
||||
return true;
|
||||
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
@ -219,74 +220,129 @@ mlir::torch::RefBackend::createMungeCallingConventionsPass() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertRngGlobals
|
||||
// MLProgramBufferize
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }
|
||||
static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp,
|
||||
OpBuilder &b) {
|
||||
if (!globalOp.getValue().has_value())
|
||||
return globalOp.emitError("global op must have a value");
|
||||
|
||||
// Declare a memref<i64> global variable for the seed.
|
||||
static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
|
||||
b.setInsertionPointToStart(module.getBody());
|
||||
Type elemTy = b.getI64Type();
|
||||
auto memref0D = MemRefType::get({}, elemTy);
|
||||
auto tensor0D = RankedTensorType::get({}, elemTy);
|
||||
RankedTensorType tensorType = globalOp.getType().cast<RankedTensorType>();
|
||||
MemRefType memrefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
|
||||
b.setInsertionPointToStart(globalOp->getParentOfType<ModuleOp>().getBody());
|
||||
b.create<memref::GlobalOp>(
|
||||
UnknownLoc::get(b.getContext()), getSeedGobalVarName(),
|
||||
/*sym_visibility=*/b.getStringAttr("private"),
|
||||
/*type=*/memref0D,
|
||||
/*initial_value=*/DenseIntElementsAttr::get(tensor0D, {APInt(64, 0)}),
|
||||
/*constant=*/false,
|
||||
UnknownLoc::get(b.getContext()), globalOp.getSymName(),
|
||||
/*sym_visibility=*/globalOp.getSymVisibilityAttr(),
|
||||
/*type=*/memrefType,
|
||||
/*initial_value=*/globalOp.getValue().value(),
|
||||
/*constant=*/globalOp.getIsMutable() ? false : true,
|
||||
/*alignment=*/nullptr);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Generate sequence for getting the next seed with LCG step:
|
||||
// nextSeed = (multiplier * currentSeed + incrementStep) mod 64.
|
||||
// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
|
||||
static Value lowerGetNextSeed(OpBuilder &b, Location loc) {
|
||||
// Get the current seed value.
|
||||
auto memref1DType = MemRefType::get({}, b.getI64Type());
|
||||
Value globalVar =
|
||||
b.create<memref::GetGlobalOp>(loc, memref1DType, getSeedGobalVarName());
|
||||
Value currentSeed = b.create<memref::LoadOp>(loc, globalVar);
|
||||
static LogicalResult
|
||||
bufferizeMLProgramGlobaLoadOp(ml_program::GlobalLoadOp globalLoadOp,
|
||||
OpBuilder &b, SmallVector<Operation *> &toErase) {
|
||||
RankedTensorType tensorType = globalLoadOp.getType().cast<RankedTensorType>();
|
||||
MemRefType memrefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
|
||||
// The value of multiplier and incrementStep are referenced from
|
||||
// https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64.
|
||||
Value multiplier = b.create<arith::ConstantOp>(
|
||||
loc, b.getI64IntegerAttr(6364136223846793005));
|
||||
Value incrementStep = b.create<arith::ConstantOp>(
|
||||
loc, b.getI64IntegerAttr(1442695040888963407));
|
||||
// temp = multiplier * currentSeed + incrementStep
|
||||
Value mul = b.create<arith::MulIOp>(loc, currentSeed, multiplier);
|
||||
Value nextSeed = b.create<arith::AddIOp>(loc, mul, incrementStep);
|
||||
b.create<memref::StoreOp>(loc, nextSeed, globalVar);
|
||||
return nextSeed;
|
||||
b.setInsertionPoint(globalLoadOp);
|
||||
Value globalVal = b.create<memref::GetGlobalOp>(
|
||||
globalLoadOp.getLoc(), memrefType,
|
||||
globalLoadOp.getGlobalAttr().getLeafReference());
|
||||
globalVal = b.create<bufferization::ToTensorOp>(globalLoadOp->getLoc(),
|
||||
tensorType, globalVal);
|
||||
globalLoadOp->getResult(0).replaceAllUsesWith(globalVal);
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
bufferizeMLProgramGlobaStoreOp(ml_program::GlobalStoreOp globalStoreOp,
|
||||
OpBuilder &b,
|
||||
SmallVector<Operation *> &toErase) {
|
||||
RankedTensorType tensorType =
|
||||
globalStoreOp.getValue().getType().cast<RankedTensorType>();
|
||||
MemRefType memrefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
|
||||
b.setInsertionPoint(globalStoreOp);
|
||||
Value memref = b.create<memref::GetGlobalOp>(
|
||||
globalStoreOp.getLoc(), memrefType,
|
||||
globalStoreOp.getGlobalAttr().getLeafReference());
|
||||
Value copyValue = b.create<bufferization::ToMemrefOp>(
|
||||
globalStoreOp->getLoc(), memrefType, globalStoreOp.getValue());
|
||||
b.create<memref::CopyOp>(globalStoreOp->getLoc(), copyValue, memref);
|
||||
return success();
|
||||
}
|
||||
|
||||
// The global seed is stored into a memref<i64> global variable as the only
|
||||
// element.
|
||||
namespace {
|
||||
class InsertRngGlobals : public InsertRngGlobalsBase<InsertRngGlobals> {
|
||||
/// Converts MLProgram operations that work on tensor-type operands or results
|
||||
/// to work on buffers.
|
||||
class MLProgramBufferize : public MLProgramBufferizeBase<MLProgramBufferize> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry
|
||||
.insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
OpBuilder b(module.getBodyRegion());
|
||||
createGlobalVariableForSeed(b, module);
|
||||
SmallVector<Operation *> toErase;
|
||||
module.walk([&](TorchConversion::GetNextSeedOp op) {
|
||||
b.setInsertionPoint(op);
|
||||
Value seed = lowerGetNextSeed(b, op.getLoc());
|
||||
op.replaceAllUsesWith(seed);
|
||||
|
||||
auto walkResult = module.walk([&](ml_program::GlobalOp op) {
|
||||
if (auto type = op.getType().dyn_cast<RankedTensorType>()) {
|
||||
if (!type.hasStaticShape()) {
|
||||
// If the ml_program.global has dynamically shaped tensor.
|
||||
op.emitError(
|
||||
"unimplemented: global op bufferization with dynamic shape");
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
} else {
|
||||
// If the ml_program.global is of non-tensor type.
|
||||
op.emitError("unsupported global op type");
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
if (failed(bufferizeMLProgramGlobalOp(op, b))) {
|
||||
op.emitError("bufferization for this op failed");
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
toErase.push_back(op);
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (walkResult.wasInterrupted())
|
||||
return signalPassFailure();
|
||||
|
||||
module.walk([&](ml_program::GlobalLoadOp op) {
|
||||
if (failed(bufferizeMLProgramGlobaLoadOp(op, b, toErase))) {
|
||||
op.emitError("bufferization for this op failed");
|
||||
return;
|
||||
}
|
||||
toErase.push_back(op);
|
||||
});
|
||||
|
||||
for (auto op : toErase)
|
||||
module.walk([&](ml_program::GlobalStoreOp op) {
|
||||
if (failed(bufferizeMLProgramGlobaStoreOp(op, b, toErase))) {
|
||||
op.emitError("bufferization for this op failed");
|
||||
return;
|
||||
}
|
||||
toErase.push_back(op);
|
||||
});
|
||||
|
||||
for (auto op : llvm::reverse(toErase))
|
||||
op->erase();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::torch::RefBackend::createInsertRngGlobalsPass() {
|
||||
return std::make_unique<InsertRngGlobals>();
|
||||
mlir::torch::RefBackend::createMLProgramBufferizePass() {
|
||||
return std::make_unique<MLProgramBufferize>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -123,6 +123,7 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
|
|||
"func.func(linalg-bufferize)",
|
||||
"func-bufferize",
|
||||
"arith-bufferize",
|
||||
"refback-mlprogram-bufferize",
|
||||
"func.func(tensor-bufferize)",
|
||||
"func.func(finalizing-bufferize)",
|
||||
# Munge to make it ExecutionEngine compatible.
|
||||
|
@ -134,7 +135,6 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
|
|||
"refback-munge-calling-conventions",
|
||||
# Insert global variable and instruction sequence for getting the next
|
||||
# global seed used in stateful rng.
|
||||
"refback-insert-rng-globals",
|
||||
# Lower to LLVM
|
||||
"func.func(tm-tensor-to-loops)",
|
||||
"func.func(refback-munge-memref-copy)",
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
// RUN: torch-mlir-opt %s -convert-torch-conversion-to-mlprogram -split-input-file | FileCheck %s

// Checks that `torch_c.get_next_seed` is lowered to a load / LCG step / store
// sequence on the `@global_seed` ml_program global. The final `return` is
// matched through the captured `NEXT_SEED` value rather than a printed SSA
// name, so the test does not depend on value numbering.

// CHECK-LABEL: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
// CHECK-LABEL: func.func @f() -> i64 {
// CHECK: %[[GLOBAL:.*]] = ml_program.global_load @global_seed : tensor<i64>
// CHECK: %[[SEED:.*]] = tensor.extract %[[GLOBAL]][] : tensor<i64>
// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64
// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
// CHECK: %[[INSERTED:.*]] = tensor.insert %[[NEXT_SEED]] into %[[GLOBAL]][] : tensor<i64>
// CHECK: ml_program.global_store @global_seed = %[[INSERTED]] : tensor<i64>
// CHECK: return %[[NEXT_SEED]] : i64
module {
  func.func @f() -> i64 {
    %seed = torch_c.get_next_seed : () -> i64
    return %seed : i64
  }
}
|
|
@ -1,18 +0,0 @@
|
|||
// RUN: torch-mlir-opt %s -refback-insert-rng-globals -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: memref.global "private" @global_seed : memref<i64> = dense<0>
|
||||
// CHECK-LABEL: func.func @f() -> i64 {
|
||||
// CHECK: %[[MEMREF:.*]] = memref.get_global @global_seed : memref<i64>
|
||||
// CHECK: %[[SEED:.*]] = memref.load %[[MEMREF]][] : memref<i64>
|
||||
// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
|
||||
// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64
|
||||
// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
|
||||
// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
|
||||
// CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64>
|
||||
// CHECK: return %[[NEXT_SEED]] : i64
|
||||
module {
|
||||
func.func @f() -> i64 {
|
||||
%seed = torch_c.get_next_seed : () -> i64
|
||||
return %seed : i64
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
// RUN: torch-mlir-opt %s -refback-mlprogram-bufferize -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: memref.global "private" @global_seed : memref<i64> = dense<0>
|
||||
// CHECK-LABEL: func.func @forward() -> i64 {
|
||||
// CHECK: %[[CST127:.*]] = arith.constant 127 : i64
|
||||
// CHECK: %[[GLOBAL_SEED:.*]] = memref.get_global @global_seed : memref<i64>
|
||||
// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[GLOBAL_SEED]] : memref<i64>
|
||||
// CHECK: %[[SEED:.*]] = tensor.extract %[[TENSOR]][] : tensor<i64>
|
||||
// CHECK: %[[NEXT_SEED:.*]] = arith.muli %[[SEED]], %[[CST127]] : i64
|
||||
// CHECK: %[[INSERTED:.*]] = tensor.insert %[[NEXT_SEED]] into %[[TENSOR]][] : tensor<i64>
|
||||
// CHECK: %[[GLOBAL_SEED_1:.*]] = memref.get_global @global_seed : memref<i64>
|
||||
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[INSERTED]] : memref<i64>
|
||||
// CHECK: memref.copy %[[MEMREF]], %[[GLOBAL_SEED_1]] : memref<i64> to memref<i64>
|
||||
// CHECK: return %[[NEXT_SEED]] : i64
|
||||
module {
|
||||
ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
|
||||
func.func @forward() -> i64 {
|
||||
%c127_i64 = arith.constant 127 : i64
|
||||
%0 = ml_program.global_load @global_seed : tensor<i64>
|
||||
%extracted = tensor.extract %0[] : tensor<i64>
|
||||
%1 = arith.muli %extracted, %c127_i64 : i64
|
||||
%inserted = tensor.insert %1 into %0[] : tensor<i64>
|
||||
ml_program.global_store @global_seed = %inserted : tensor<i64>
|
||||
return %1 : i64
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
// expected-error @below {{unsupported global op type}}
|
||||
ml_program.global private mutable @global_seed(0 : i64) : i64
|
||||
func.func @forward() -> i64 {
|
||||
%c127_i64 = arith.constant 127 : i64
|
||||
%0 = ml_program.global_load @global_seed : i64
|
||||
%1 = arith.muli %0, %c127_i64 : i64
|
||||
ml_program.global_store @global_seed = %1 : i64
|
||||
return %1 : i64
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
// expected-error @below {{unsupported global op type}}
|
||||
ml_program.global private mutable @global_seed(dense<0> : memref<i64>) : memref<i64>
|
||||
func.func @forward() -> i64 {
|
||||
%c127_i64 = arith.constant 127 : i64
|
||||
%0 = ml_program.global_load @global_seed : memref<i64>
|
||||
%extracted = memref.load %0[] : memref<i64>
|
||||
%1 = arith.muli %extracted, %c127_i64 : i64
|
||||
memref.store %1, %0[] : memref<i64>
|
||||
ml_program.global_store @global_seed = %0 : memref<i64>
|
||||
return %1 : i64
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
// expected-error @below {{invalid tensor element type}}
|
||||
ml_program.global private mutable @global_seed(dense<0> : tensor<memref<i64>>) : tensor<memref<i64>>
|
||||
func.func @forward() -> i64 {
|
||||
%c127_i64 = arith.constant 127 : i64
|
||||
return %c127_i64 : i64
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
module {
|
||||
// expected-error @below {{unimplemented: global op bufferization with dynamic shape}}
|
||||
ml_program.global private mutable @global_seed(dense<0> : tensor<1xi64>) : tensor<?xi64>
|
||||
func.func @forward() -> i64 {
|
||||
%c127_i64 = arith.constant 127 : i64
|
||||
%c0 = arith.constant 0 : index
|
||||
%0 = ml_program.global_load @global_seed : tensor<?xi64>
|
||||
%extracted = tensor.extract %0[%c0] : tensor<?xi64>
|
||||
%1 = arith.muli %extracted, %c127_i64 : i64
|
||||
%inserted = tensor.insert %1 into %0[%c0] : tensor<?xi64>
|
||||
ml_program.global_store @global_seed = %inserted : tensor<?xi64>
|
||||
return %1 : i64
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue