add i64 support to refbackend

pull/351/head
dan 2021-10-05 02:06:59 +00:00 committed by Yi Zhang
parent fadd76e9b8
commit 2e1498ad11
3 changed files with 50 additions and 27 deletions

View File

@ -53,15 +53,6 @@ static bool isArgMemRefTypeValid(Type type) {
return false; return false;
} }
static bool isReturnMemRefTypeValid(Type type) {
if (auto memRefType = type.dyn_cast<MemRefType>()) {
if (memRefType.getElementType().isa<Float32Type>()) {
return true;
}
}
return false;
}
static void addEmitCInterfaceAttr(FuncOp func) { static void addEmitCInterfaceAttr(FuncOp func) {
func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext())); func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext()));
} }
@ -70,7 +61,9 @@ static Type getAbiTypeForMemRef(Type type) {
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0); return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
} }
static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) { static LogicalResult mungeFunction(
FuncOp func,
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs) {
// Add `llvm.emit_c_interface`. // Add `llvm.emit_c_interface`.
// This allows ExecutionEngine to resolve the symbol properly. // This allows ExecutionEngine to resolve the symbol properly.
addEmitCInterfaceAttr(func); addEmitCInterfaceAttr(func);
@ -98,17 +91,20 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
SmallVector<Operation *> toErase; SmallVector<Operation *> toErase;
bool hadError = false; bool hadError = false;
func.walk([&](ReturnOp op) { func.walk([&](ReturnOp op) {
if (op.getNumOperands() != 1 || auto returnType =
!isReturnMemRefTypeValid(op.getOperandTypes()[0])) { op.getOperandTypes()[0].dyn_cast<MemRefType>().getElementType();
auto it = consumeFuncReturnFuncs.find(returnType);
if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
hadError = true; hadError = true;
op.emitError("must have one return value and it must be a memref of f32"); op.emitError("must have one return value: a memref of f32 or i64");
return; return;
} }
b.setInsertionPoint(op); b.setInsertionPoint(op);
auto cast = auto cast =
b.create<memref::CastOp>(op.getLoc(), op.getOperand(0), b.create<memref::CastOp>(op.getLoc(), op.getOperand(0),
getAbiTypeForMemRef(op.getOperandTypes()[0])); getAbiTypeForMemRef(op.getOperandTypes()[0]));
b.create<mlir::CallOp>(op.getLoc(), consumeFuncReturnFunc, b.create<mlir::CallOp>(op.getLoc(), consumeFuncReturnFuncs[returnType],
cast.getResult()); cast.getResult());
b.create<mlir::ReturnOp>(op.getLoc()); b.create<mlir::ReturnOp>(op.getLoc());
toErase.push_back(op); toErase.push_back(op);
@ -127,22 +123,31 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
namespace { namespace {
class MungeCallingConventions class MungeCallingConventions
: public MungeCallingConventionsBase<MungeCallingConventions> { : public MungeCallingConventionsBase<MungeCallingConventions> {
void runOnOperation() override { void runOnOperation() override {
auto module = getOperation(); auto module = getOperation();
OpBuilder b(module.getBodyRegion()); OpBuilder b(module.getBodyRegion());
auto consumeFuncReturnInt64Func = b.create<FuncOp>(
auto consumeFuncReturnFunc = b.create<FuncOp>( module.getLoc(), "refbackend_consume_int64_func_return",
module.getLoc(), "refbackend_consume_func_return", FunctionType::get(
module.getContext(),
UnrankedMemRefType::get(b.getI64Type(), /*memorySpace=*/0), {}),
b.getStringAttr("private"));
auto consumeFuncReturnFloat32Func = b.create<FuncOp>(
module.getLoc(), "refbackend_consume_float32_func_return",
FunctionType::get( FunctionType::get(
module.getContext(), module.getContext(),
UnrankedMemRefType::get(b.getF32Type(), /*memorySpace=*/0), {}), UnrankedMemRefType::get(b.getF32Type(), /*memorySpace=*/0), {}),
b.getStringAttr("private")); b.getStringAttr("private"));
addEmitCInterfaceAttr(consumeFuncReturnFunc); addEmitCInterfaceAttr(consumeFuncReturnInt64Func);
addEmitCInterfaceAttr(consumeFuncReturnFloat32Func);
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
consumeFuncReturnFuncs[b.getF32Type()] = consumeFuncReturnFloat32Func;
consumeFuncReturnFuncs[b.getI64Type()] = consumeFuncReturnInt64Func;
for (auto func : module.getOps<FuncOp>()) { for (auto func : module.getOps<FuncOp>()) {
if (func == consumeFuncReturnFunc) if (func == consumeFuncReturnInt64Func ||
func == consumeFuncReturnFloat32Func)
continue; continue;
if (failed(mungeFunction(func, consumeFuncReturnFunc))) if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
return signalPassFailure(); return signalPassFailure();
} }
} }
@ -160,7 +165,6 @@ mlir::torch::RefBackend::createMungeCallingConventionsPass() {
namespace { namespace {
class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> { class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
void runOnOperation() override { void runOnOperation() override {
auto func = getOperation(); auto func = getOperation();
auto *context = &getContext(); auto *context = &getContext();

View File

@ -36,11 +36,18 @@ class RefBackendInvoker:
self.result = None self.result = None
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_return(a): def consume_i64_return(a):
self.result = unranked_memref_to_numpy(a, np.int64)
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_f32_return(a):
self.result = unranked_memref_to_numpy(a, np.float32) self.result = unranked_memref_to_numpy(a, np.float32)
self.ee.register_runtime("refbackend_consume_func_return", self.ee.register_runtime("refbackend_consume_int64_func_return",
consume_return) consume_i64_return)
self.ee.register_runtime("refbackend_consume_float32_func_return",
consume_f32_return)
def __getattr__(self, function_name: str): def __getattr__(self, function_name: str):
def invoke(*args): def invoke(*args):

View File

@ -1,11 +1,23 @@
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions | FileCheck %s // RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file | FileCheck %s
// CHECK-LABEL: func @f( // CHECK-LABEL: func @f(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} { // CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32> // CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32> // CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
// CHECK: call @refbackend_consume_func_return(%[[RESULT]]) : (memref<*xf32>) -> () // CHECK: call @refbackend_consume_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: return // CHECK: return
func @f(%arg0: memref<?xf32>) -> memref<?xf32> { func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
return %arg0 : memref<?xf32> return %arg0 : memref<?xf32>
} }
// -----
// CHECK-LABEL: func @i(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
// CHECK: call @refbackend_consume_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> ()
// CHECK: return
func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
return %arg0 : memref<?xi64>
}