2021-09-23 00:55:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
2021-09-30 00:03:40 +08:00
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
2021-09-23 00:55:09 +08:00
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// The torch-mlir "reference backend" requires a few passes to glue things
|
|
|
|
// together so that the final IR will work with ExecutionEngine.
|
|
|
|
//
|
|
|
|
// There is no actual "backend".
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "PassDetail.h"
|
2021-11-16 07:00:53 +08:00
|
|
|
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
|
2021-09-23 00:55:09 +08:00
|
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
2021-10-26 07:16:01 +08:00
|
|
|
#include "mlir/Dialect/Math/Transforms/Approximation.h"
|
2021-09-23 00:55:09 +08:00
|
|
|
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "torch-mlir/RefBackend/Passes.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::RefBackend;
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Pass registration
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
// Pulls in the TableGen-generated pass registration hooks (including
// `::registerPasses()`) for all RefBackend passes declared in Passes.td.
#define GEN_PASS_REGISTRATION
#include "torch-mlir/RefBackend/Passes.h.inc"
} // end namespace

// Registers every RefBackend pass with the global MLIR pass registry.
void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// MungeCallingConventions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2021-09-24 03:22:28 +08:00
|
|
|
// Returns true if `type` is a memref whose element type is one the reference
// backend ABI can pass across the ExecutionEngine boundary: f32, f64, or a
// signless i32/i64.
static bool isArgMemRefTypeValid(Type type) {
  auto memRefType = type.dyn_cast<MemRefType>();
  if (!memRefType)
    return false;
  Type elemTy = memRefType.getElementType();
  if (elemTy.isa<Float32Type, Float64Type>())
    return true;
  if (auto integerTy = elemTy.dyn_cast<IntegerType>())
    return integerTy.isSignlessInteger(64) || integerTy.isSignlessInteger(32);
  return false;
}
|
|
|
|
|
2021-09-23 00:55:09 +08:00
|
|
|
// Marks `func` with the `llvm.emit_c_interface` attribute so that the LLVM
// lowering emits a C-compatible wrapper the ExecutionEngine can resolve.
static void addEmitCInterfaceAttr(FuncOp func) {
  MLIRContext *ctx = func.getContext();
  func->setAttr("llvm.emit_c_interface", UnitAttr::get(ctx));
}
|
|
|
|
|
|
|
|
// Returns the ABI-level type used to pass `type` (must be a ranked memref)
// across the ExecutionEngine boundary: an unranked memref with the same
// element type in the default memory space.
static Type getAbiTypeForMemRef(Type type) {
  Type elemTy = type.cast<MemRefType>().getElementType();
  return UnrankedMemRefType::get(elemTy, /*memorySpace=*/0);
}
|
|
|
|
|
2021-11-07 03:25:06 +08:00
|
|
|
// Passes the return op operands `val` to `funOp`. Also, adds the op to the
|
|
|
|
// `toErase` vector.
|
|
|
|
static void replaceCallToFunction(OpBuilder b, ReturnOp op, FuncOp funcOp,
|
|
|
|
Value val,
|
|
|
|
SmallVectorImpl<Operation *> &toErase) {
|
|
|
|
b.create<mlir::CallOp>(op.getLoc(), funcOp, val);
|
|
|
|
b.create<mlir::ReturnOp>(op.getLoc());
|
|
|
|
toErase.push_back(op);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Checks whether the return op is munge-compatible and the respective calling
|
|
|
|
// function is defined.
|
|
|
|
static bool isReturnOpCompatible(ReturnOp op,
|
|
|
|
DenseMap<Type, FuncOp> &consumeFuncReturnFuncs,
|
|
|
|
Type returnType) {
|
|
|
|
auto it = consumeFuncReturnFuncs.find(returnType);
|
|
|
|
if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
|
|
|
|
op.emitError("must have one return value of Memref type or Elemental types "
|
|
|
|
"of i64, f64, f32");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2021-10-05 10:06:59 +08:00
|
|
|
// Rewrites `func` to the reference backend's "munged" calling convention:
// memref arguments become unranked memrefs, and each return is replaced by a
// call to the matching consume callback from `consumeFuncReturnFuncs`.
// Returns failure (with a diagnostic already emitted) if an argument or
// return is not representable in the munged ABI.
//
// NOTE: the map is taken by reference — the old by-value parameter copied the
// entire DenseMap on every call.
static LogicalResult mungeFunction(
    FuncOp func,
    DenseMap</*returnElementType*/ Type, FuncOp> &consumeFuncReturnFuncs) {
  // Add `llvm.emit_c_interface`.
  // This allows ExecutionEngine to resolve the symbol properly.
  addEmitCInterfaceAttr(func);

  // Rewrite the function as follows:
  // - replace all memref arguments with unranked memref
  // - replace all returns with a call to a function, which is going to be
  //   supplied by the code setting up the ExecutionEngine to process the
  //   result. Additionally, ensure that all results are passed as unranked
  //   memrefs.
  // - replace the function signature accordingly (unranked inputs, no returns).
  OpBuilder b(func.getBody());

  SmallVector<Type> newArgTypes;
  for (auto arg : func.getArguments()) {
    auto type = arg.getType();
    if (!isArgMemRefTypeValid(type))
      return emitError(arg.getLoc(),
                       "argument must be a memref of f32, f64, i32, i64");
    // Cast the (soon-to-be unranked) argument back to its original ranked
    // type so existing uses keep type-checking, then retype the block arg.
    auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
    arg.replaceAllUsesExcept(cast, cast);
    arg.setType(getAbiTypeForMemRef(type));
    newArgTypes.push_back(arg.getType());
  }

  SmallVector<Operation *> toErase;
  bool isCompatible = false;
  func.walk([&](ReturnOp op) {
    // Guard the operand count before indexing getOperandTypes(): a bare
    // `return` (zero operands) would otherwise index an empty range.
    if (op.getNumOperands() != 1) {
      op.emitError("must have one return value of Memref type or Elemental "
                   "types of i64, f64, f32");
      isCompatible = false;
      return;
    }
    auto returnType = op.getOperandTypes()[0];
    b.setInsertionPoint(op);
    // Memref Types.
    if (auto memrefReturnType = returnType.dyn_cast<MemRefType>()) {
      auto elemType = memrefReturnType.getElementType();
      auto unRankedType = UnrankedMemRefType::get(elemType, 0);
      isCompatible =
          isReturnOpCompatible(op, consumeFuncReturnFuncs, unRankedType);
      if (!isCompatible)
        return;

      // Cast to unranked memref type before sending it as a function argument.
      auto cast = b.create<memref::CastOp>(
          op.getLoc(), op.getOperand(0),
          getAbiTypeForMemRef(op.getOperandTypes()[0]));
      replaceCallToFunction(b, op, consumeFuncReturnFuncs[unRankedType],
                            cast.getResult(), toErase);
      // Elemental types.
    } else if (returnType.isa<IntegerType>() || returnType.isa<FloatType>()) {
      isCompatible =
          isReturnOpCompatible(op, consumeFuncReturnFuncs, returnType);
      if (!isCompatible)
        return;
      replaceCallToFunction(b, op, consumeFuncReturnFuncs[returnType],
                            op->getOperand(0), toErase);
    } else {
      // Previously an unsupported return type fell through both branches
      // silently and the pass failed without any diagnostic.
      op.emitError("must have one return value of Memref type or Elemental "
                   "types of i64, f64, f32");
      isCompatible = false;
    }
  });
  if (!isCompatible)
    return failure();
  // All returns were rewritten; the function now takes unranked memrefs and
  // returns nothing.
  func.setType(FunctionType::get(func.getContext(), newArgTypes, {}));
  for (Operation *op : toErase)
    op->erase();
  return success();
}
|
|
|
|
|
|
|
|
namespace {
// Module pass that rewrites every function to the reference backend's ABI:
// it declares private "consume" callback functions (one per supported return
// type), then munges each function so its returns call the matching callback
// instead of returning a value.
class MungeCallingConventions
    : public MungeCallingConventionsBase<MungeCallingConventions> {
  void runOnOperation() override {
    auto module = getOperation();
    OpBuilder b(module.getBodyRegion());
    // Maps a munge-compatible return type to the callback that consumes it.
    DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
    // The callback declarations themselves, so they can be skipped when
    // iterating the module's functions below.
    DenseSet<FuncOp> consumeFuncReturnFuncsSet;
    // Declares a private callback `funcName` taking `returnType` and returning
    // nothing, and records it in both containers above.
    auto createConsumeFuncReturnFunc = [&](Type returnType,
                                           std::string funcName) {
      auto consumeFuncReturnFunc = b.create<FuncOp>(
          module.getLoc(), funcName,
          FunctionType::get(module.getContext(), returnType, {}),
          b.getStringAttr("private"));
      addEmitCInterfaceAttr(consumeFuncReturnFunc);
      consumeFuncReturnFuncs[returnType] = consumeFuncReturnFunc;
      consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc);
    };

    // Memref return types.
    createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI32Type(), 0),
                                "refbackend_consume_memref_int32_func_return");
    createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0),
                                "refbackend_consume_memref_int64_func_return");
    createConsumeFuncReturnFunc(
        UnrankedMemRefType::get(b.getF32Type(), 0),
        "refbackend_consume_memref_float32_func_return");
    createConsumeFuncReturnFunc(
        UnrankedMemRefType::get(b.getF64Type(), 0),
        "refbackend_consume_memref_float64_func_return");

    // Elemental return types.
    // NOTE(review): there is no elemental-i32 callback here, only memref-i32
    // above — presumably intentional; confirm against the runtime's consumers.
    createConsumeFuncReturnFunc(b.getI64Type(),
                                "refbackend_consume_int64_func_return");
    createConsumeFuncReturnFunc(b.getF32Type(),
                                "refbackend_consume_float32_func_return");
    createConsumeFuncReturnFunc(b.getF64Type(),
                                "refbackend_consume_float64_func_return");
    for (auto func : module.getOps<FuncOp>()) {
      // Don't munge the callback declarations we just created.
      if (consumeFuncReturnFuncsSet.contains(func))
        continue;
      if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
        return signalPassFailure();
    }
  }
};
} // namespace
|
|
|
|
|
|
|
|
/// Factory for the module-level MungeCallingConventions pass.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::RefBackend::createMungeCallingConventionsPass() {
  return std::make_unique<MungeCallingConventions>();
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ExpandOpsForLLVM
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
|
|
|
|
void runOnOperation() override {
|
|
|
|
auto func = getOperation();
|
|
|
|
auto *context = &getContext();
|
|
|
|
RewritePatternSet patterns(context);
|
|
|
|
populateExpandTanhPattern(patterns);
|
2021-10-26 07:16:01 +08:00
|
|
|
patterns.add<math::ErfPolynomialApproximation>(patterns.getContext());
|
2021-09-23 00:55:09 +08:00
|
|
|
ConversionTarget target(*context);
|
|
|
|
target.addLegalDialect<StandardOpsDialect>();
|
|
|
|
target.addLegalDialect<math::MathDialect>();
|
2021-10-16 02:34:29 +08:00
|
|
|
target.addLegalDialect<arith::ArithmeticDialect>();
|
2021-09-23 00:55:09 +08:00
|
|
|
target.addIllegalOp<math::TanhOp>();
|
2021-10-26 07:16:01 +08:00
|
|
|
target.addIllegalOp<math::ErfOp>();
|
2021-09-23 00:55:09 +08:00
|
|
|
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
|
|
|
return signalPassFailure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/// Factory for the function-level ExpandOpsForLLVM pass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::RefBackend::createExpandOpsForLLVMPass() {
  return std::make_unique<ExpandOpsForLLVM>();
}
|