//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
//
// The torch-mlir "reference backend" requires a few passes to glue things
// together so that the final IR will work with ExecutionEngine.
//
// There is no actual "backend".
//
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "PassDetail.h"
|
2022-04-27 03:27:51 +08:00
|
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
2022-03-16 18:44:23 +08:00
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
2022-02-13 02:47:12 +08:00
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
2022-03-16 18:44:23 +08:00
|
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
2021-09-23 00:55:09 +08:00
|
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
2021-10-26 07:16:01 +08:00
|
|
|
#include "mlir/Dialect/Math/Transforms/Approximation.h"
|
2021-09-23 00:55:09 +08:00
|
|
|
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
2022-02-13 02:47:12 +08:00
|
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
2022-01-12 09:59:42 +08:00
|
|
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
2022-02-13 02:47:12 +08:00
|
|
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
2021-09-23 00:55:09 +08:00
|
|
|
#include "torch-mlir/RefBackend/Passes.h"
|
2021-11-08 23:56:40 +08:00
|
|
|
#include <numeric>
|
2022-01-12 09:59:42 +08:00
|
|
|
#include <set>
|
2021-09-23 00:55:09 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::RefBackend;
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Pass registration
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Expand the tablegen-generated pass registration code inside an anonymous
// namespace so that the generated `registerPasses` hook stays local to this
// translation unit.
namespace {
#define GEN_PASS_REGISTRATION
#include "torch-mlir/RefBackend/Passes.h.inc"
} // namespace

// Public entry point: registers every RefBackend pass by delegating to the
// file-local, tablegen-generated `registerPasses` hook.
void mlir::torch::RefBackend::registerRefBackendPasses() {
  ::registerPasses();
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// MungeCallingConventions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2021-09-24 03:22:28 +08:00
|
|
|
static bool isArgMemRefTypeValid(Type type) {
|
|
|
|
if (auto memRefType = type.dyn_cast<MemRefType>()) {
|
|
|
|
Type elemTy = memRefType.getElementType();
|
|
|
|
if (elemTy.isa<Float32Type>()) {
|
|
|
|
return true;
|
2021-10-16 06:23:59 +08:00
|
|
|
} else if (elemTy.isa<Float64Type>()) {
|
|
|
|
return true;
|
2021-09-24 03:22:28 +08:00
|
|
|
} else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) {
|
|
|
|
if (integerTy.isSignlessInteger(64))
|
|
|
|
return true;
|
2021-10-16 06:23:40 +08:00
|
|
|
if (integerTy.isSignlessInteger(32))
|
|
|
|
return true;
|
2021-12-08 22:05:02 +08:00
|
|
|
if (integerTy.isSignlessInteger(1))
|
|
|
|
return true;
|
2021-09-24 03:22:28 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2022-04-27 03:27:51 +08:00
|
|
|
static void addEmitCInterfaceAttr(func::FuncOp func) {
|
2021-09-23 00:55:09 +08:00
|
|
|
func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext()));
|
|
|
|
}
|
|
|
|
|
|
|
|
// Maps a ranked memref type to the unranked memref type (same element type,
// memory space 0) used at the ABI boundary. `type` must be a MemRefType.
static Type getAbiTypeForMemRef(Type type) {
  auto rankedType = type.cast<MemRefType>();
  Type elementType = rankedType.getElementType();
  return UnrankedMemRefType::get(elementType, /*memorySpace=*/0);
}
|
|
|
|
|
2021-11-08 23:56:40 +08:00
|
|
|
// Helper function to get the type string for one return value like i32, f64,
|
|
|
|
// mri32 etc. The strings from multiple return values are concatenated to get
|
|
|
|
// the consumeFuncReturnFunc name.
|
|
|
|
static std::string getTypeToken(Type type) {
|
|
|
|
if (type.isSignlessInteger())
|
|
|
|
return ("i" + Twine(type.getIntOrFloatBitWidth())).str();
|
|
|
|
else if (type.isa<mlir::FloatType>())
|
|
|
|
return ("f" + Twine(type.getIntOrFloatBitWidth())).str();
|
|
|
|
else if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
|
|
|
|
return "mr" + getTypeToken(memRefType.getElementType());
|
|
|
|
|
|
|
|
llvm_unreachable(
|
|
|
|
"Type token should handle all types: memref, float and int type");
|
|
|
|
}
|
|
|
|
|
|
|
|
// Systematically derive the consumeFuncReturnFunc name from return value types.
|
|
|
|
static std::string getConsumeReturnFunctionNameForReturnTypes(TypeRange types) {
|
|
|
|
SmallVector<std::string> tokens = {"refbackend_consume_func_return"};
|
|
|
|
for (auto type : types)
|
|
|
|
tokens.push_back(getTypeToken(type));
|
|
|
|
|
|
|
|
return std::accumulate(tokens.begin(), tokens.end(), std::string(),
|
|
|
|
[](std::string &a, std::string &b) {
|
|
|
|
return a.empty() ? b : (a + "_" + b);
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
|
|
|
// Replace the original returnOp with a call to consumeFuncReturnFunc and add
|
|
|
|
// the op to the `toErase` vector.
|
2022-03-16 18:44:23 +08:00
|
|
|
static void replaceReturnWithCall(OpBuilder b, func::ReturnOp op,
|
|
|
|
StringRef funcName, TypeRange retTypes,
|
2021-11-08 23:56:40 +08:00
|
|
|
SmallVectorImpl<Value> &vals,
|
2021-11-07 03:25:06 +08:00
|
|
|
SmallVectorImpl<Operation *> &toErase) {
|
2022-03-16 18:44:23 +08:00
|
|
|
b.create<mlir::func::CallOp>(op.getLoc(), funcName, TypeRange({}), vals);
|
|
|
|
b.create<mlir::func::ReturnOp>(op.getLoc());
|
2021-11-07 03:25:06 +08:00
|
|
|
toErase.push_back(op);
|
|
|
|
}
|
|
|
|
|
2021-10-05 10:06:59 +08:00
|
|
|
static LogicalResult mungeFunction(
|
2022-04-27 03:27:51 +08:00
|
|
|
func::FuncOp func,
|
2021-11-08 23:56:40 +08:00
|
|
|
std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
|
2022-01-12 09:59:42 +08:00
|
|
|
// Only need to call mungeFunction for functions callable from outside of the
|
|
|
|
// module.
|
|
|
|
if (func.isPrivate())
|
|
|
|
return success();
|
2021-09-23 00:55:09 +08:00
|
|
|
// Add `llvm.emit_c_interface`.
|
|
|
|
// This allows ExecutionEngine to resolve the symbol properly.
|
|
|
|
addEmitCInterfaceAttr(func);
|
|
|
|
|
|
|
|
// Rewrite the function as follows:
|
|
|
|
// - replace all memref arguments with unranked memref
|
|
|
|
// - replace all returns with a call to a function, which is going to be
|
|
|
|
// supplied by the code setting up the ExecutionEngine to process the
|
|
|
|
// result. Additionally, ensure that all results are passed as unranked
|
|
|
|
// memrefs.
|
|
|
|
// - replace the function signature accordingly (unranked inputs, no returns).
|
|
|
|
OpBuilder b(func.getBody());
|
|
|
|
|
|
|
|
SmallVector<Type> newArgTypes;
|
|
|
|
for (auto arg : func.getArguments()) {
|
|
|
|
auto type = arg.getType();
|
2021-09-24 03:22:28 +08:00
|
|
|
if (!isArgMemRefTypeValid(type))
|
2021-10-16 06:23:59 +08:00
|
|
|
return emitError(arg.getLoc(),
|
2021-12-08 22:05:02 +08:00
|
|
|
"argument must be a memref of f32, f64, i32, i64, i1");
|
2022-02-13 02:47:12 +08:00
|
|
|
auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg);
|
2021-09-23 00:55:09 +08:00
|
|
|
arg.replaceAllUsesExcept(cast, cast);
|
|
|
|
arg.setType(getAbiTypeForMemRef(type));
|
|
|
|
newArgTypes.push_back(arg.getType());
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<Operation *> toErase;
|
2022-03-16 18:44:23 +08:00
|
|
|
func.walk([&](func::ReturnOp op) {
|
2021-11-08 23:56:40 +08:00
|
|
|
auto types = op.getOperandTypes();
|
2021-09-23 00:55:09 +08:00
|
|
|
b.setInsertionPoint(op);
|
2021-11-07 03:25:06 +08:00
|
|
|
// Memref Types.
|
2021-11-08 23:56:40 +08:00
|
|
|
std::vector<Type> retTypes;
|
|
|
|
SmallVector<Value> retVals;
|
|
|
|
for (auto en : llvm::enumerate(types)) {
|
|
|
|
Type retType = en.value();
|
|
|
|
Value retVal = op.getOperand(en.index());
|
|
|
|
if (auto memrefReturnType = retType.dyn_cast<MemRefType>()) {
|
|
|
|
auto elemType = memrefReturnType.getElementType();
|
|
|
|
retType = UnrankedMemRefType::get(elemType, 0);
|
|
|
|
// Cast to unranked memref type before sending it as a function
|
|
|
|
// argument.
|
|
|
|
retVal = b.create<memref::CastOp>(
|
2022-02-13 02:47:12 +08:00
|
|
|
op.getLoc(), getAbiTypeForMemRef(types[en.index()]), retVal);
|
2021-11-08 23:56:40 +08:00
|
|
|
}
|
|
|
|
retTypes.push_back(retType);
|
|
|
|
retVals.push_back(retVal);
|
2021-11-07 03:25:06 +08:00
|
|
|
}
|
2021-11-08 23:56:40 +08:00
|
|
|
|
|
|
|
std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
|
|
|
|
|
|
|
|
auto invokedFuncsEnd = invokedConsumeFuncReturnFuncs.end();
|
|
|
|
if (invokedConsumeFuncReturnFuncs.find(funcName) == invokedFuncsEnd)
|
|
|
|
invokedConsumeFuncReturnFuncs.insert({funcName, retTypes});
|
|
|
|
replaceReturnWithCall(b, op, funcName, retTypes, retVals, toErase);
|
2021-09-23 00:55:09 +08:00
|
|
|
});
|
|
|
|
func.setType(FunctionType::get(func.getContext(), newArgTypes, {}));
|
|
|
|
for (Operation *op : toErase)
|
|
|
|
op->erase();
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class MungeCallingConventions
|
|
|
|
: public MungeCallingConventionsBase<MungeCallingConventions> {
|
|
|
|
void runOnOperation() override {
|
|
|
|
auto module = getOperation();
|
|
|
|
OpBuilder b(module.getBodyRegion());
|
2021-11-08 23:56:40 +08:00
|
|
|
std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs;
|
2022-04-27 03:27:51 +08:00
|
|
|
for (auto func : module.getOps<func::FuncOp>()) {
|
2022-04-19 21:47:47 +08:00
|
|
|
if (failed(mungeFunction(func, invokedConsumeFuncReturnFuncs)))
|
2021-09-23 00:55:09 +08:00
|
|
|
return signalPassFailure();
|
|
|
|
}
|
2021-11-08 23:56:40 +08:00
|
|
|
|
|
|
|
// Create FuncOp for consumeFuncReturnFuncs that are used.
|
|
|
|
for (auto &p : invokedConsumeFuncReturnFuncs) {
|
2022-04-27 03:27:51 +08:00
|
|
|
auto consumeFuncReturnFunc = b.create<func::FuncOp>(
|
|
|
|
module.getLoc(), p.first,
|
|
|
|
FunctionType::get(module.getContext(), p.second, {}),
|
|
|
|
b.getStringAttr("private"));
|
2021-11-08 23:56:40 +08:00
|
|
|
addEmitCInterfaceAttr(consumeFuncReturnFunc);
|
|
|
|
}
|
2021-09-23 00:55:09 +08:00
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
std::unique_ptr<OperationPass<ModuleOp>>
|
|
|
|
mlir::torch::RefBackend::createMungeCallingConventionsPass() {
|
|
|
|
return std::make_unique<MungeCallingConventions>();
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
// InsertRngGlobals
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Symbol name of the module-level memref global that stores the RNG seed.
static constexpr StringRef getSeedGobalVarName() {
  return StringRef("global_seed");
}
|
|
|
|
|
|
|
|
// Declare a memref<i64> global variable for the seed.
|
|
|
|
static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
|
|
|
|
b.setInsertionPointToStart(module.getBody());
|
|
|
|
Type elemTy = b.getI64Type();
|
|
|
|
auto memref0D = MemRefType::get({}, elemTy);
|
|
|
|
auto tensor0D = RankedTensorType::get({}, elemTy);
|
|
|
|
b.create<memref::GlobalOp>(
|
|
|
|
UnknownLoc::get(b.getContext()), getSeedGobalVarName(),
|
|
|
|
/*sym_visibility=*/b.getStringAttr("private"),
|
|
|
|
/*type=*/memref0D,
|
|
|
|
/*initial_value=*/DenseIntElementsAttr::get(tensor0D, {APInt(64, 0)}),
|
|
|
|
/*constant=*/false,
|
|
|
|
/*alignment=*/nullptr);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Generate sequence for getting the next seed with LCG step:
|
|
|
|
// nextSeed = (multiplier * currentSeed + incrementStep) mod 64.
|
|
|
|
// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
|
|
|
|
static Value lowerGetNextSeed(OpBuilder &b, Location loc) {
|
|
|
|
// Get the current seed value.
|
|
|
|
auto memref1DType = MemRefType::get({}, b.getI64Type());
|
|
|
|
Value globalVar =
|
|
|
|
b.create<memref::GetGlobalOp>(loc, memref1DType, getSeedGobalVarName());
|
|
|
|
Value currentSeed = b.create<memref::LoadOp>(loc, globalVar);
|
|
|
|
|
|
|
|
// The value of multiplier and incrementStep are referenced from
|
|
|
|
// https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64.
|
|
|
|
Value multiplier = b.create<arith::ConstantOp>(
|
|
|
|
loc, b.getI64IntegerAttr(6364136223846793005));
|
|
|
|
Value incrementStep = b.create<arith::ConstantOp>(
|
|
|
|
loc, b.getI64IntegerAttr(1442695040888963407));
|
|
|
|
// temp = multiplier * currentSeed + incrementStep
|
|
|
|
Value mul = b.create<arith::MulIOp>(loc, currentSeed, multiplier);
|
|
|
|
Value temp = b.create<arith::AddIOp>(loc, mul, incrementStep);
|
|
|
|
// temp mod 64 = temp & 63
|
|
|
|
Value cst127 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(127));
|
|
|
|
Value nextSeed = b.create<arith::AndIOp>(loc, temp, cst127);
|
|
|
|
b.create<memref::StoreOp>(loc, nextSeed, globalVar);
|
|
|
|
return nextSeed;
|
|
|
|
}
|
|
|
|
|
|
|
|
// The global seed is stored into a memref<i64> global variable as the only
|
|
|
|
// element.
|
|
|
|
namespace {
|
|
|
|
class InsertRngGlobals : public InsertRngGlobalsBase<InsertRngGlobals> {
|
|
|
|
void runOnOperation() override {
|
|
|
|
auto module = getOperation();
|
|
|
|
OpBuilder b(module.getBodyRegion());
|
|
|
|
createGlobalVariableForSeed(b, module);
|
|
|
|
SmallVector<Operation *> toErase;
|
|
|
|
module.walk([&](TorchConversion::GetNextSeedOp op) {
|
|
|
|
b.setInsertionPoint(op);
|
|
|
|
Value seed = lowerGetNextSeed(b, op.getLoc());
|
|
|
|
op.replaceAllUsesWith(seed);
|
|
|
|
toErase.push_back(op);
|
|
|
|
});
|
|
|
|
|
|
|
|
for (auto op : toErase)
|
|
|
|
op->erase();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
std::unique_ptr<OperationPass<ModuleOp>>
|
|
|
|
mlir::torch::RefBackend::createInsertRngGlobalsPass() {
|
|
|
|
return std::make_unique<InsertRngGlobals>();
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
// ExpandOpsForLLVM
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
|
|
|
|
void runOnOperation() override {
|
|
|
|
auto func = getOperation();
|
|
|
|
auto *context = &getContext();
|
|
|
|
RewritePatternSet patterns(context);
|
|
|
|
populateExpandTanhPattern(patterns);
|
2021-10-26 07:16:01 +08:00
|
|
|
patterns.add<math::ErfPolynomialApproximation>(patterns.getContext());
|
2021-09-23 00:55:09 +08:00
|
|
|
ConversionTarget target(*context);
|
2022-03-16 18:44:23 +08:00
|
|
|
target.addLegalDialect<func::FuncDialect>();
|
2021-09-23 00:55:09 +08:00
|
|
|
target.addLegalDialect<math::MathDialect>();
|
2021-10-16 02:34:29 +08:00
|
|
|
target.addLegalDialect<arith::ArithmeticDialect>();
|
2021-09-23 00:55:09 +08:00
|
|
|
target.addIllegalOp<math::TanhOp>();
|
2021-10-26 07:16:01 +08:00
|
|
|
target.addIllegalOp<math::ErfOp>();
|
2021-09-23 00:55:09 +08:00
|
|
|
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
|
|
|
return signalPassFailure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-04-27 03:27:51 +08:00
|
|
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
2021-09-23 00:55:09 +08:00
|
|
|
mlir::torch::RefBackend::createExpandOpsForLLVMPass() {
|
|
|
|
return std::make_unique<ExpandOpsForLLVM>();
|
|
|
|
}
|
//===----------------------------------------------------------------------===//
// MungeMemrefCopy
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Builds a linalg.generic that copies `from` into `to` element by element.
// Both operands use the identity indexing map, every loop is parallel, and the
// body yields the input element unchanged. Both values must be ranked memrefs
// of equal rank (checked by cast<> / assert). Returns the created op.
Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
                              Value to) {
  auto fromType = from.getType().cast<MemRefType>();
  auto toType = to.getType().cast<MemRefType>();
  (void)fromType; // Only referenced by the assertion below.
  assert(fromType && toType && fromType.getRank() == toType.getRank());

  int64_t rank = toType.getRank();
  AffineMap identity = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
  SmallVector<AffineMap, 2> indexingMaps = {identity, identity};
  SmallVector<StringRef> iteratorTypes(rank, getParallelIteratorTypeName());

  auto copyBody = [](OpBuilder &nested, Location nestedLoc, ValueRange args) {
    nested.create<linalg::YieldOp>(nestedLoc, args.front());
  };
  return b.create<linalg::GenericOp>(loc, /*inputs=*/from, /*outputs=*/to,
                                     indexingMaps, iteratorTypes, copyBody);
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class MemrefCopyOpToLinalg : public OpRewritePattern<memref::CopyOp> {
|
|
|
|
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(memref::CopyOp copyOp,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Operation *linalgCopy = createLinalgCopyOp(
|
|
|
|
rewriter, copyOp.getLoc(), copyOp.source(), copyOp.target());
|
|
|
|
rewriter.replaceOp(copyOp, linalgCopy->getResults());
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
class MungeMemrefCopy : public MungeMemrefCopyBase<MungeMemrefCopy> {
|
2022-03-16 18:44:23 +08:00
|
|
|
void runOnOperation() override {
|
|
|
|
MLIRContext *context = &getContext();
|
|
|
|
RewritePatternSet patterns(&getContext());
|
|
|
|
patterns.insert<MemrefCopyOpToLinalg>(context);
|
|
|
|
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
|
|
|
std::move(patterns)))) {
|
|
|
|
return signalPassFailure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-04-27 03:27:51 +08:00
|
|
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
2022-03-16 18:44:23 +08:00
|
|
|
mlir::torch::RefBackend::createMungeMemrefCopyPass() {
|
|
|
|
return std::make_unique<MungeMemrefCopy>();
|
|
|
|
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class GeneralizeTensorPad
|
|
|
|
: public GeneralizeTensorPadBase<GeneralizeTensorPad> {
|
2022-02-13 02:47:12 +08:00
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
|
|
registry.insert<linalg::LinalgDialect>();
|
|
|
|
}
|
|
|
|
|
|
|
|
void runOnOperation() override {
|
|
|
|
MLIRContext *context = &getContext();
|
|
|
|
RewritePatternSet patterns(&getContext());
|
2022-03-16 18:44:23 +08:00
|
|
|
patterns.insert<linalg::GeneralizePadOpPattern>(context);
|
2022-02-13 02:47:12 +08:00
|
|
|
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
|
|
|
std::move(patterns)))) {
|
|
|
|
return signalPassFailure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-04-27 03:27:51 +08:00
|
|
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
2022-03-16 18:44:23 +08:00
|
|
|
mlir::torch::RefBackend::createGeneralizeTensorPadPass() {
|
|
|
|
return std::make_unique<GeneralizeTensorPad>();
|
2022-02-13 02:47:12 +08:00
|
|
|
}
|