add i64 support to refbackend

pull/351/head
dan 2021-10-05 02:06:59 +00:00 committed by Yi Zhang
parent fadd76e9b8
commit 2e1498ad11
3 changed files with 50 additions and 27 deletions

View File

@ -53,15 +53,6 @@ static bool isArgMemRefTypeValid(Type type) {
return false;
}
static bool isReturnMemRefTypeValid(Type type) {
if (auto memRefType = type.dyn_cast<MemRefType>()) {
if (memRefType.getElementType().isa<Float32Type>()) {
return true;
}
}
return false;
}
static void addEmitCInterfaceAttr(FuncOp func) {
func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext()));
}
@ -70,7 +61,9 @@ static Type getAbiTypeForMemRef(Type type) {
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
}
static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
static LogicalResult mungeFunction(
FuncOp func,
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs) {
// Add `llvm.emit_c_interface`.
// This allows ExecutionEngine to resolve the symbol properly.
addEmitCInterfaceAttr(func);
@ -98,17 +91,20 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
SmallVector<Operation *> toErase;
bool hadError = false;
func.walk([&](ReturnOp op) {
if (op.getNumOperands() != 1 ||
!isReturnMemRefTypeValid(op.getOperandTypes()[0])) {
auto returnType =
op.getOperandTypes()[0].dyn_cast<MemRefType>().getElementType();
auto it = consumeFuncReturnFuncs.find(returnType);
if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
hadError = true;
op.emitError("must have one return value and it must be a memref of f32");
op.emitError("must have one return value: a memref of f32 or i64");
return;
}
b.setInsertionPoint(op);
auto cast =
b.create<memref::CastOp>(op.getLoc(), op.getOperand(0),
getAbiTypeForMemRef(op.getOperandTypes()[0]));
b.create<mlir::CallOp>(op.getLoc(), consumeFuncReturnFunc,
b.create<mlir::CallOp>(op.getLoc(), consumeFuncReturnFuncs[returnType],
cast.getResult());
b.create<mlir::ReturnOp>(op.getLoc());
toErase.push_back(op);
@ -127,22 +123,31 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
namespace {
class MungeCallingConventions
: public MungeCallingConventionsBase<MungeCallingConventions> {
void runOnOperation() override {
auto module = getOperation();
OpBuilder b(module.getBodyRegion());
auto consumeFuncReturnFunc = b.create<FuncOp>(
module.getLoc(), "refbackend_consume_func_return",
auto consumeFuncReturnInt64Func = b.create<FuncOp>(
module.getLoc(), "refbackend_consume_int64_func_return",
FunctionType::get(
module.getContext(),
UnrankedMemRefType::get(b.getI64Type(), /*memorySpace=*/0), {}),
b.getStringAttr("private"));
auto consumeFuncReturnFloat32Func = b.create<FuncOp>(
module.getLoc(), "refbackend_consume_float32_func_return",
FunctionType::get(
module.getContext(),
UnrankedMemRefType::get(b.getF32Type(), /*memorySpace=*/0), {}),
b.getStringAttr("private"));
addEmitCInterfaceAttr(consumeFuncReturnFunc);
addEmitCInterfaceAttr(consumeFuncReturnInt64Func);
addEmitCInterfaceAttr(consumeFuncReturnFloat32Func);
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
consumeFuncReturnFuncs[b.getF32Type()] = consumeFuncReturnFloat32Func;
consumeFuncReturnFuncs[b.getI64Type()] = consumeFuncReturnInt64Func;
for (auto func : module.getOps<FuncOp>()) {
if (func == consumeFuncReturnFunc)
if (func == consumeFuncReturnInt64Func ||
func == consumeFuncReturnFloat32Func)
continue;
if (failed(mungeFunction(func, consumeFuncReturnFunc)))
if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
return signalPassFailure();
}
}
@ -160,7 +165,6 @@ mlir::torch::RefBackend::createMungeCallingConventionsPass() {
namespace {
class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
void runOnOperation() override {
auto func = getOperation();
auto *context = &getContext();

View File

@ -36,11 +36,18 @@ class RefBackendInvoker:
self.result = None
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_return(a):
def consume_i64_return(a):
self.result = unranked_memref_to_numpy(a, np.int64)
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_f32_return(a):
self.result = unranked_memref_to_numpy(a, np.float32)
self.ee.register_runtime("refbackend_consume_func_return",
consume_return)
self.ee.register_runtime("refbackend_consume_int64_func_return",
consume_i64_return)
self.ee.register_runtime("refbackend_consume_float32_func_return",
consume_f32_return)
def __getattr__(self, function_name: str):
def invoke(*args):

View File

@ -1,11 +1,23 @@
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions | FileCheck %s
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file | FileCheck %s
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
// CHECK: call @refbackend_consume_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: call @refbackend_consume_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: return
func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
return %arg0 : memref<?xf32>
}
// -----
// CHECK-LABEL: func @i(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
// CHECK: call @refbackend_consume_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> ()
// CHECK: return
func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
return %arg0 : memref<?xi64>
}