mirror of https://github.com/llvm/torch-mlir
add i64 support to refbackend
parent
fadd76e9b8
commit
2e1498ad11
|
@ -53,15 +53,6 @@ static bool isArgMemRefTypeValid(Type type) {
|
|||
return false;
|
||||
}
|
||||
|
||||
static bool isReturnMemRefTypeValid(Type type) {
|
||||
if (auto memRefType = type.dyn_cast<MemRefType>()) {
|
||||
if (memRefType.getElementType().isa<Float32Type>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static void addEmitCInterfaceAttr(FuncOp func) {
|
||||
func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext()));
|
||||
}
|
||||
|
@ -70,7 +61,9 @@ static Type getAbiTypeForMemRef(Type type) {
|
|||
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
|
||||
}
|
||||
|
||||
static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
|
||||
static LogicalResult mungeFunction(
|
||||
FuncOp func,
|
||||
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs) {
|
||||
// Add `llvm.emit_c_interface`.
|
||||
// This allows ExecutionEngine to resolve the symbol properly.
|
||||
addEmitCInterfaceAttr(func);
|
||||
|
@ -98,17 +91,20 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
|
|||
SmallVector<Operation *> toErase;
|
||||
bool hadError = false;
|
||||
func.walk([&](ReturnOp op) {
|
||||
if (op.getNumOperands() != 1 ||
|
||||
!isReturnMemRefTypeValid(op.getOperandTypes()[0])) {
|
||||
auto returnType =
|
||||
op.getOperandTypes()[0].dyn_cast<MemRefType>().getElementType();
|
||||
auto it = consumeFuncReturnFuncs.find(returnType);
|
||||
if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
|
||||
hadError = true;
|
||||
op.emitError("must have one return value and it must be a memref of f32");
|
||||
op.emitError("must have one return value: a memref of f32 or i64");
|
||||
return;
|
||||
}
|
||||
|
||||
b.setInsertionPoint(op);
|
||||
auto cast =
|
||||
b.create<memref::CastOp>(op.getLoc(), op.getOperand(0),
|
||||
getAbiTypeForMemRef(op.getOperandTypes()[0]));
|
||||
b.create<mlir::CallOp>(op.getLoc(), consumeFuncReturnFunc,
|
||||
b.create<mlir::CallOp>(op.getLoc(), consumeFuncReturnFuncs[returnType],
|
||||
cast.getResult());
|
||||
b.create<mlir::ReturnOp>(op.getLoc());
|
||||
toErase.push_back(op);
|
||||
|
@ -127,22 +123,31 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
|
|||
namespace {
|
||||
class MungeCallingConventions
|
||||
: public MungeCallingConventionsBase<MungeCallingConventions> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
OpBuilder b(module.getBodyRegion());
|
||||
|
||||
auto consumeFuncReturnFunc = b.create<FuncOp>(
|
||||
module.getLoc(), "refbackend_consume_func_return",
|
||||
auto consumeFuncReturnInt64Func = b.create<FuncOp>(
|
||||
module.getLoc(), "refbackend_consume_int64_func_return",
|
||||
FunctionType::get(
|
||||
module.getContext(),
|
||||
UnrankedMemRefType::get(b.getI64Type(), /*memorySpace=*/0), {}),
|
||||
b.getStringAttr("private"));
|
||||
auto consumeFuncReturnFloat32Func = b.create<FuncOp>(
|
||||
module.getLoc(), "refbackend_consume_float32_func_return",
|
||||
FunctionType::get(
|
||||
module.getContext(),
|
||||
UnrankedMemRefType::get(b.getF32Type(), /*memorySpace=*/0), {}),
|
||||
b.getStringAttr("private"));
|
||||
addEmitCInterfaceAttr(consumeFuncReturnFunc);
|
||||
addEmitCInterfaceAttr(consumeFuncReturnInt64Func);
|
||||
addEmitCInterfaceAttr(consumeFuncReturnFloat32Func);
|
||||
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
|
||||
consumeFuncReturnFuncs[b.getF32Type()] = consumeFuncReturnFloat32Func;
|
||||
consumeFuncReturnFuncs[b.getI64Type()] = consumeFuncReturnInt64Func;
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
if (func == consumeFuncReturnFunc)
|
||||
if (func == consumeFuncReturnInt64Func ||
|
||||
func == consumeFuncReturnFloat32Func)
|
||||
continue;
|
||||
if (failed(mungeFunction(func, consumeFuncReturnFunc)))
|
||||
if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
@ -160,7 +165,6 @@ mlir::torch::RefBackend::createMungeCallingConventionsPass() {
|
|||
|
||||
namespace {
|
||||
class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
|
|
@ -36,11 +36,18 @@ class RefBackendInvoker:
|
|||
self.result = None
|
||||
|
||||
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
|
||||
def consume_return(a):
|
||||
def consume_i64_return(a):
|
||||
self.result = unranked_memref_to_numpy(a, np.int64)
|
||||
|
||||
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
|
||||
def consume_f32_return(a):
|
||||
self.result = unranked_memref_to_numpy(a, np.float32)
|
||||
|
||||
self.ee.register_runtime("refbackend_consume_func_return",
|
||||
consume_return)
|
||||
self.ee.register_runtime("refbackend_consume_int64_func_return",
|
||||
consume_i64_return)
|
||||
|
||||
self.ee.register_runtime("refbackend_consume_float32_func_return",
|
||||
consume_f32_return)
|
||||
|
||||
def __getattr__(self, function_name: str):
|
||||
def invoke(*args):
|
||||
|
|
|
@ -1,11 +1,23 @@
|
|||
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions | FileCheck %s
|
||||
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @f(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
|
||||
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
|
||||
// CHECK: call @refbackend_consume_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
|
||||
// CHECK: call @refbackend_consume_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
|
||||
// CHECK: return
|
||||
func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||
return %arg0 : memref<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @i(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
|
||||
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
|
||||
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
|
||||
// CHECK: call @refbackend_consume_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> ()
|
||||
// CHECK: return
|
||||
func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
|
||||
return %arg0 : memref<?xi64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue