[MLIR][TORCH] Add TorchConversionToMLProgram and MLProgramBufferize pass

This commit replaces the `InsertRngGlobals` pass with a new `TorchConversionToMLProgram` conversion pass, which lowers recognized TorchConversion ops (currently `torch_c.get_next_seed`) to ml_program dialect ops. It also adds the `MLProgramBufferize` pass, which bufferizes ml_program dialect ops so that the lowered code can run on RefBackend.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
parent 3fc27cf6ca
commit f416953600
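For orientation, here is a minimal sketch of the lowering this commit introduces (not part of the diff; value names are illustrative, and the shape mirrors the new lit test added further down): `torch_c.get_next_seed` is rewritten against a module-level `ml_program.global` seed, and the next seed is computed with a linear congruential generator (LCG) step on i64 values.

// Input (sketch):
module {
  func.func @f() -> i64 {
    %seed = torch_c.get_next_seed : () -> i64
    return %seed : i64
  }
}

// After -convert-torch-conversion-to-mlprogram (sketch):
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @f() -> i64 {
    %0 = ml_program.global_load @global_seed : tensor<i64>
    %seed = tensor.extract %0[] : tensor<i64>
    %multiplier = arith.constant 6364136223846793005 : i64
    %increment = arith.constant 1442695040888963407 : i64
    %mul = arith.muli %seed, %multiplier : i64
    %next = arith.addi %mul, %increment : i64
    %1 = tensor.insert %next into %0[] : tensor<i64>
    ml_program.global_store @global_seed = %1 : tensor<i64>
    return %next : i64
  }
}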
@@ -125,6 +125,14 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
   let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
 }
 
+def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "func::FuncOp"> {
+  let summary = "Convert recognized TorchConversion ops to MLProgram ops";
+  let description = [{
+    Convert TorchConversion ops to mlprogram ops.
+  }];
+  let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
+}
+
 #ifdef TORCH_MLIR_ENABLE_MHLO
 def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
   let summary = "Convert Torch ops to MHLO ops";
@@ -0,0 +1,23 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
+#define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace torch {
+std::unique_ptr<OperationPass<func::FuncOp>>
+createConvertTorchConversionToMLProgramPass();
+}
+} // namespace mlir
+
+#endif // TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
@@ -27,7 +27,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
 
 std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();
 
-std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();
 
 std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();
 
@@ -18,9 +18,9 @@ def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleOp
   let dependentDialects = ["memref::MemRefDialect"];
 }
 
-def InsertRngGlobals: Pass<"refback-insert-rng-globals", "ModuleOp"> {
-  let summary = "Insert global variables and sequence to get the next global seed for RNG ops";
-  let constructor = "mlir::torch::RefBackend::createInsertRngGlobalsPass();";
+def MLProgramBufferize: Pass<"refback-mlprogram-bufferize", "ModuleOp"> {
+  let summary = "Bufferize the MLProgram dialect ops";
+  let constructor = "mlir::torch::RefBackend::createMLProgramBufferizePass();";
   let dependentDialects = ["memref::MemRefDialect"];
 }
 
@@ -6,6 +6,7 @@ if(TORCH_MLIR_ENABLE_MHLO)
   add_subdirectory(TorchToMhlo)
 endif()
 add_subdirectory(TorchToTMTensor)
+add_subdirectory(TorchConversionToMLProgram)
 add_subdirectory(Utils)
 
 # TODO: Automate this with add_torch_mlir_conversion_library.
@@ -14,6 +15,7 @@ set(linked_libs TorchMLIRTorchToLinalg
                 TorchMLIRTorchToArith
                 TorchMLIRTorchToTosa
                 TorchMLIRTorchToTMTensor
+                TorchMLIRTorchConversionToMLProgram
                 TorchMLIRConversionUtils)
 if(TORCH_MLIR_ENABLE_MHLO)
   list(APPEND linked_libs TorchMLIRTorchToMhlo)
@@ -19,6 +19,7 @@
 #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
 #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
+#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
 
 //===----------------------------------------------------------------------===//
 // Pass registration
@@ -0,0 +1,21 @@
+add_mlir_conversion_library(TorchMLIRTorchConversionToMLProgram
+  TorchConversionToMLProgram.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchConversionToMLProgram
+
+  DEPENDS
+  TorchMLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRLinalgDialect
+  MLIRMathDialect
+  TorchMLIRTorchDialect
+)
+
+torch_mlir_target_includes(TorchMLIRTorchConversionToMLProgram)
@@ -0,0 +1,125 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+using namespace mlir::torch::TorchConversion;
+
+static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }
+
+// Declare a tensor<i64> global variable for the seed.
+static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
+  b.setInsertionPointToStart(module.getBody());
+  Type elemTy = b.getI64Type();
+  auto tensorType = RankedTensorType::get({}, elemTy);
+  b.create<ml_program::GlobalOp>(
+      UnknownLoc::get(b.getContext()),
+      /*sym_name=*/getSeedGobalVarName(),
+      /*type=*/tensorType,
+      /*is_mutable=*/true,
+      /*value=*/DenseIntElementsAttr::get(tensorType, {APInt(64, 0)}),
+      /*sym_visibility=*/b.getStringAttr("private"));
+}
+
+namespace {
+class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(GetNextSeedOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Generate sequence for getting the next seed with LCG step:
+    // nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64.
+    // Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
+    // Get the current seed value.
+    auto tensorType = RankedTensorType::get({}, rewriter.getI64Type());
+    Value globalVar = rewriter.create<ml_program::GlobalLoadOp>(
+        loc, tensorType,
+        SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()));
+    Value currentSeed = rewriter.create<tensor::ExtractOp>(loc, globalVar);
+
+    // The value of multiplier and incrementStep are referenced from
+    // https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64.
+    Value multiplier = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI64IntegerAttr(6364136223846793005));
+    Value incrementStep = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI64IntegerAttr(1442695040888963407));
+    // temp = multiplier * currentSeed + incrementStep
+    Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
+    Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
+    globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar);
+    rewriter.create<ml_program::GlobalStoreOp>(
+        loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
+        globalVar);
+    rewriter.replaceOp(op, seed);
+    return success();
+  }
+};
+} // namespace
+
+// -----------------------------------------------------------------------------
+// The pass
+// -----------------------------------------------------------------------------
+
+namespace {
+class ConvertTorchConversionToMLProgram
+    : public ConvertTorchConversionToMLProgramBase<
+          ConvertTorchConversionToMLProgram> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<tensor::TensorDialect>();
+    registry.insert<arith::ArithDialect>();
+    registry.insert<ml_program::MLProgramDialect>();
+    TorchConversion::getBackendTypeConversionDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ConversionTarget target(*context);
+    target.addLegalDialect<tensor::TensorDialect, arith::ArithDialect,
+                           ml_program::MLProgramDialect>();
+
+    TypeConverter typeConverter;
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversion(target, typeConverter);
+
+    auto module = getOperation()->getParentOfType<ModuleOp>();
+    OpBuilder b(module.getBodyRegion());
+    createGlobalVariableForSeed(b, module);
+
+    RewritePatternSet patterns(context);
+    target.addIllegalOp<GetNextSeedOp>();
+    patterns.add<ConvertGetNextSeedOp>(typeConverter, context);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::createConvertTorchConversionToMLProgramPass() {
+  return std::make_unique<ConvertTorchConversionToMLProgram>();
+}
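A note on the constants above: the update implemented by ConvertGetNextSeedOp is the 64-bit LCG step from the referenced Wikipedia article, seed_next = (6364136223846793005 * seed + 1442695040888963407) mod 2^64. The mod-2^64 reduction is implicit because arith.muli/arith.addi on i64 values wrap around, so no explicit remainder op is needed.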
@@ -8,6 +8,7 @@ set(LinkedLibs MLIRIR
   TorchMLIRTorchToTMTensor
   TorchMLIRTorchToArith
   TorchMLIRTorchToSCF
+  TorchMLIRTorchConversionToMLProgram
   MLIRMemRefTransforms)
 
 if(TORCH_MLIR_ENABLE_MHLO)
@@ -20,6 +20,7 @@
 #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
 #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
+#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
 #ifdef TORCH_MLIR_ENABLE_MHLO
 #include "mhlo/transforms/passes.h"
 #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
@@ -71,6 +72,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
   pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
+  pm.addNestedPass<func::FuncOp>(createConvertTorchConversionToMLProgramPass());
   pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
 
   // Clean up any non-canonical code introduced above..
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -80,6 +81,8 @@ class VerifyLinalgOnTensorsBackendContractPass
     target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(opHasLegalTypes);
     target.addDynamicallyLegalDialect<TMTensorDialect>(opHasLegalTypes);
     target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes);
+    target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
+        opHasLegalTypes);
 
     // ConstantOp is used for tensors and for scalars.
     target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
@@ -16,12 +16,14 @@
 
 #include "PassDetail.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Approximation.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
@@ -67,7 +69,6 @@ static bool isArgMemRefTypeValid(Type type) {
         return true;
       if (integerTy.isSignlessInteger(1))
        return true;
-
     }
   }
   return false;
@@ -219,74 +220,129 @@ mlir::torch::RefBackend::createMungeCallingConventionsPass() {
 }
 
 //===----------------------------------------------------------------------===//
-// InsertRngGlobals
+// MLProgramBufferize
 //===----------------------------------------------------------------------===//
 
-static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }
-
-// Declare a memref<i64> global variable for the seed.
-static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
-  b.setInsertionPointToStart(module.getBody());
-  Type elemTy = b.getI64Type();
-  auto memref0D = MemRefType::get({}, elemTy);
-  auto tensor0D = RankedTensorType::get({}, elemTy);
+static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp,
+                                                OpBuilder &b) {
+  if (!globalOp.getValue().has_value())
+    return globalOp.emitError("global op must have a value");
+
+  RankedTensorType tensorType = globalOp.getType().cast<RankedTensorType>();
+  MemRefType memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+  b.setInsertionPointToStart(globalOp->getParentOfType<ModuleOp>().getBody());
   b.create<memref::GlobalOp>(
-      UnknownLoc::get(b.getContext()), getSeedGobalVarName(),
-      /*sym_visibility=*/b.getStringAttr("private"),
-      /*type=*/memref0D,
-      /*initial_value=*/DenseIntElementsAttr::get(tensor0D, {APInt(64, 0)}),
-      /*constant=*/false,
+      UnknownLoc::get(b.getContext()), globalOp.getSymName(),
+      /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
+      /*type=*/memrefType,
+      /*initial_value=*/globalOp.getValue().value(),
+      /*constant=*/globalOp.getIsMutable() ? false : true,
       /*alignment=*/nullptr);
+  return success();
 }
 
-// Generate sequence for getting the next seed with LCG step:
-// nextSeed = (multiplier * currentSeed + incrementStep) mod 64.
-// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
-static Value lowerGetNextSeed(OpBuilder &b, Location loc) {
-  // Get the current seed value.
-  auto memref1DType = MemRefType::get({}, b.getI64Type());
-  Value globalVar =
-      b.create<memref::GetGlobalOp>(loc, memref1DType, getSeedGobalVarName());
-  Value currentSeed = b.create<memref::LoadOp>(loc, globalVar);
-
-  // The value of multiplier and incrementStep are referenced from
-  // https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64.
-  Value multiplier = b.create<arith::ConstantOp>(
-      loc, b.getI64IntegerAttr(6364136223846793005));
-  Value incrementStep = b.create<arith::ConstantOp>(
-      loc, b.getI64IntegerAttr(1442695040888963407));
-  // temp = multiplier * currentSeed + incrementStep
-  Value mul = b.create<arith::MulIOp>(loc, currentSeed, multiplier);
-  Value nextSeed = b.create<arith::AddIOp>(loc, mul, incrementStep);
-  b.create<memref::StoreOp>(loc, nextSeed, globalVar);
-  return nextSeed;
+static LogicalResult
+bufferizeMLProgramGlobaLoadOp(ml_program::GlobalLoadOp globalLoadOp,
+                              OpBuilder &b, SmallVector<Operation *> &toErase) {
+  RankedTensorType tensorType = globalLoadOp.getType().cast<RankedTensorType>();
+  MemRefType memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+  b.setInsertionPoint(globalLoadOp);
+  Value globalVal = b.create<memref::GetGlobalOp>(
+      globalLoadOp.getLoc(), memrefType,
+      globalLoadOp.getGlobalAttr().getLeafReference());
+  globalVal = b.create<bufferization::ToTensorOp>(globalLoadOp->getLoc(),
+                                                  tensorType, globalVal);
+  globalLoadOp->getResult(0).replaceAllUsesWith(globalVal);
+  return success();
+}
+
+static LogicalResult
+bufferizeMLProgramGlobaStoreOp(ml_program::GlobalStoreOp globalStoreOp,
+                               OpBuilder &b,
+                               SmallVector<Operation *> &toErase) {
+  RankedTensorType tensorType =
+      globalStoreOp.getValue().getType().cast<RankedTensorType>();
+  MemRefType memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+  b.setInsertionPoint(globalStoreOp);
+  Value memref = b.create<memref::GetGlobalOp>(
+      globalStoreOp.getLoc(), memrefType,
+      globalStoreOp.getGlobalAttr().getLeafReference());
+  Value copyValue = b.create<bufferization::ToMemrefOp>(
+      globalStoreOp->getLoc(), memrefType, globalStoreOp.getValue());
+  b.create<memref::CopyOp>(globalStoreOp->getLoc(), copyValue, memref);
+  return success();
 }
 
-// The global seed is stored into a memref<i64> global variable as the only
-// element.
 namespace {
-class InsertRngGlobals : public InsertRngGlobalsBase<InsertRngGlobals> {
+/// Converts MLProgram operations that work on tensor-type operands or results
+/// to work on buffers.
+class MLProgramBufferize : public MLProgramBufferizeBase<MLProgramBufferize> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
+  }
+
   void runOnOperation() override {
     auto module = getOperation();
     OpBuilder b(module.getBodyRegion());
-    createGlobalVariableForSeed(b, module);
     SmallVector<Operation *> toErase;
-    module.walk([&](TorchConversion::GetNextSeedOp op) {
-      b.setInsertionPoint(op);
-      Value seed = lowerGetNextSeed(b, op.getLoc());
-      op.replaceAllUsesWith(seed);
+
+    auto walkResult = module.walk([&](ml_program::GlobalOp op) {
+      if (auto type = op.getType().dyn_cast<RankedTensorType>()) {
+        if (!type.hasStaticShape()) {
+          // If the ml_program.global has dynamically shaped tensor.
+          op.emitError(
+              "unimplemented: global op bufferization with dynamic shape");
+          return WalkResult::interrupt();
+        }
+      } else {
+        // If the ml_program.global is of non-tensor type.
+        op.emitError("unsupported global op type");
+        return WalkResult::interrupt();
+      }
+
+      if (failed(bufferizeMLProgramGlobalOp(op, b))) {
+        op.emitError("bufferization for this op failed");
+        return WalkResult::interrupt();
+      }
+      toErase.push_back(op);
+      return WalkResult::advance();
+    });
+
+    if (walkResult.wasInterrupted())
+      return signalPassFailure();
+
+    module.walk([&](ml_program::GlobalLoadOp op) {
+      if (failed(bufferizeMLProgramGlobaLoadOp(op, b, toErase))) {
+        op.emitError("bufferization for this op failed");
+        return;
+      }
       toErase.push_back(op);
     });
 
-    for (auto op : toErase)
+    module.walk([&](ml_program::GlobalStoreOp op) {
+      if (failed(bufferizeMLProgramGlobaStoreOp(op, b, toErase))) {
+        op.emitError("bufferization for this op failed");
+        return;
+      }
+      toErase.push_back(op);
+    });
+
+    for (auto op : llvm::reverse(toErase))
       op->erase();
   }
 };
 } // namespace
 
 std::unique_ptr<OperationPass<ModuleOp>>
-mlir::torch::RefBackend::createInsertRngGlobalsPass() {
-  return std::make_unique<InsertRngGlobals>();
+mlir::torch::RefBackend::createMLProgramBufferizePass() {
+  return std::make_unique<MLProgramBufferize>();
 }
 
 //===----------------------------------------------------------------------===//
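For reference, a rough before/after sketch of what refback-mlprogram-bufferize does on the static-shape path (illustrative only; the exact IR is pinned down by the new lit test added below):

// Input (sketch):
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @forward() -> i64 {
    %0 = ml_program.global_load @global_seed : tensor<i64>
    %seed = tensor.extract %0[] : tensor<i64>
    ml_program.global_store @global_seed = %0 : tensor<i64>
    return %seed : i64
  }
}

// After refback-mlprogram-bufferize (sketch): the global becomes a memref.global,
// loads go through memref.get_global + bufferization.to_tensor, and stores become
// bufferization.to_memref + memref.copy back into the global buffer.
module {
  memref.global "private" @global_seed : memref<i64> = dense<0>
  func.func @forward() -> i64 {
    %m = memref.get_global @global_seed : memref<i64>
    %t = bufferization.to_tensor %m : memref<i64>
    %seed = tensor.extract %t[] : tensor<i64>
    %m2 = memref.get_global @global_seed : memref<i64>
    %buf = bufferization.to_memref %t : memref<i64>
    memref.copy %buf, %m2 : memref<i64> to memref<i64>
    return %seed : i64
  }
}

The remaining bufferization.to_tensor/to_memref casts are presumably cleaned up by the finalizing-bufferize step that already runs later in the RefBackend pipeline (see the Python pipeline change below).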
@@ -123,6 +123,7 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
     "func.func(linalg-bufferize)",
     "func-bufferize",
     "arith-bufferize",
+    "refback-mlprogram-bufferize",
     "func.func(tensor-bufferize)",
     "func.func(finalizing-bufferize)",
     # Munge to make it ExecutionEngine compatible.
@@ -134,7 +135,6 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
     "refback-munge-calling-conventions",
     # Insert global variable and instruction sequence for getting the next
     # global seed used in stateful rng.
-    "refback-insert-rng-globals",
     # Lower to LLVM
     "func.func(tm-tensor-to-loops)",
     "func.func(refback-munge-memref-copy)",
@@ -0,0 +1,19 @@
+// RUN: torch-mlir-opt %s -convert-torch-conversion-to-mlprogram -split-input-file | FileCheck %s
+
+// CHECK-LABEL: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
+// CHECK-LABEL: func.func @f() -> i64 {
+// CHECK: %[[GLOBAL:.*]] = ml_program.global_load @global_seed : tensor<i64>
+// CHECK: %[[SEED:.*]] = tensor.extract %[[GLOBAL]][] : tensor<i64>
+// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
+// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64
+// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
+// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
+// CHECK: %[[INSERTED:.*]] = tensor.insert %[[NEXT_SEED]] into %[[GLOBAL]][] : tensor<i64>
+// CHECK: ml_program.global_store @global_seed = %[[INSERTED]] : tensor<i64>
+// CHECK: return %2 : i64
+module {
+  func.func @f() -> i64 {
+    %seed = torch_c.get_next_seed : () -> i64
+    return %seed : i64
+  }
+}
@@ -1,18 +0,0 @@
-// RUN: torch-mlir-opt %s -refback-insert-rng-globals -split-input-file | FileCheck %s
-
-// CHECK-LABEL: memref.global "private" @global_seed : memref<i64> = dense<0>
-// CHECK-LABEL: func.func @f() -> i64 {
-// CHECK: %[[MEMREF:.*]] = memref.get_global @global_seed : memref<i64>
-// CHECK: %[[SEED:.*]] = memref.load %[[MEMREF]][] : memref<i64>
-// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
-// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64
-// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
-// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
-// CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64>
-// CHECK: return %[[NEXT_SEED]] : i64
-module {
-  func.func @f() -> i64 {
-    %seed = torch_c.get_next_seed : () -> i64
-    return %seed : i64
-  }
-}
@@ -0,0 +1,83 @@
+// RUN: torch-mlir-opt %s -refback-mlprogram-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: memref.global "private" @global_seed : memref<i64> = dense<0>
+// CHECK-LABEL: func.func @forward() -> i64 {
+// CHECK: %[[CST127:.*]] = arith.constant 127 : i64
+// CHECK: %[[GLOBAL_SEED:.*]] = memref.get_global @global_seed : memref<i64>
+// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[GLOBAL_SEED]] : memref<i64>
+// CHECK: %[[SEED:.*]] = tensor.extract %[[TENSOR]][] : tensor<i64>
+// CHECK: %[[NEXT_SEED:.*]] = arith.muli %[[SEED]], %[[CST127]] : i64
+// CHECK: %[[INSERTED:.*]] = tensor.insert %[[NEXT_SEED]] into %[[TENSOR]][] : tensor<i64>
+// CHECK: %[[GLOBAL_SEED_1:.*]] = memref.get_global @global_seed : memref<i64>
+// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[INSERTED]] : memref<i64>
+// CHECK: memref.copy %[[MEMREF]], %[[GLOBAL_SEED_1]] : memref<i64> to memref<i64>
+// CHECK: return %[[NEXT_SEED]] : i64
+module {
+  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
+  func.func @forward() -> i64 {
+    %c127_i64 = arith.constant 127 : i64
+    %0 = ml_program.global_load @global_seed : tensor<i64>
+    %extracted = tensor.extract %0[] : tensor<i64>
+    %1 = arith.muli %extracted, %c127_i64 : i64
+    %inserted = tensor.insert %1 into %0[] : tensor<i64>
+    ml_program.global_store @global_seed = %inserted : tensor<i64>
+    return %1 : i64
+  }
+}
+
+// -----
+
+module {
+  // expected-error @below {{unsupported global op type}}
+  ml_program.global private mutable @global_seed(0 : i64) : i64
+  func.func @forward() -> i64 {
+    %c127_i64 = arith.constant 127 : i64
+    %0 = ml_program.global_load @global_seed : i64
+    %1 = arith.muli %0, %c127_i64 : i64
+    ml_program.global_store @global_seed = %1 : i64
+    return %1 : i64
+  }
+}
+
+// -----
+
+module {
+  // expected-error @below {{unsupported global op type}}
+  ml_program.global private mutable @global_seed(dense<0> : memref<i64>) : memref<i64>
+  func.func @forward() -> i64 {
+    %c127_i64 = arith.constant 127 : i64
+    %0 = ml_program.global_load @global_seed : memref<i64>
+    %extracted = memref.load %0[] : memref<i64>
+    %1 = arith.muli %extracted, %c127_i64 : i64
+    memref.store %1, %0[] : memref<i64>
+    ml_program.global_store @global_seed = %0 : memref<i64>
+    return %1 : i64
+  }
+}
+
+// -----
+
+module {
+  // expected-error @below {{invalid tensor element type}}
+  ml_program.global private mutable @global_seed(dense<0> : tensor<memref<i64>>) : tensor<memref<i64>>
+  func.func @forward() -> i64 {
+    %c127_i64 = arith.constant 127 : i64
+    return %c127_i64 : i64
+  }
+}
+
+// -----
+module {
+  // expected-error @below {{unimplemented: global op bufferization with dynamic shape}}
+  ml_program.global private mutable @global_seed(dense<0> : tensor<1xi64>) : tensor<?xi64>
+  func.func @forward() -> i64 {
+    %c127_i64 = arith.constant 127 : i64
+    %c0 = arith.constant 0 : index
+    %0 = ml_program.global_load @global_seed : tensor<?xi64>
+    %extracted = tensor.extract %0[%c0] : tensor<?xi64>
+    %1 = arith.muli %extracted, %c127_i64 : i64
+    %inserted = tensor.insert %1 into %0[%c0] : tensor<?xi64>
+    ml_program.global_store @global_seed = %inserted : tensor<?xi64>
+    return %1 : i64
+  }
+}