mirror of https://github.com/llvm/torch-mlir
add i64 support to refbackend
parent
fadd76e9b8
commit
2e1498ad11
|
@ -53,15 +53,6 @@ static bool isArgMemRefTypeValid(Type type) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isReturnMemRefTypeValid(Type type) {
|
|
||||||
if (auto memRefType = type.dyn_cast<MemRefType>()) {
|
|
||||||
if (memRefType.getElementType().isa<Float32Type>()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void addEmitCInterfaceAttr(FuncOp func) {
|
static void addEmitCInterfaceAttr(FuncOp func) {
|
||||||
func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext()));
|
func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext()));
|
||||||
}
|
}
|
||||||
|
@ -70,7 +61,9 @@ static Type getAbiTypeForMemRef(Type type) {
|
||||||
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
|
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
|
static LogicalResult mungeFunction(
|
||||||
|
FuncOp func,
|
||||||
|
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs) {
|
||||||
// Add `llvm.emit_c_interface`.
|
// Add `llvm.emit_c_interface`.
|
||||||
// This allows ExecutionEngine to resolve the symbol properly.
|
// This allows ExecutionEngine to resolve the symbol properly.
|
||||||
addEmitCInterfaceAttr(func);
|
addEmitCInterfaceAttr(func);
|
||||||
|
@ -98,17 +91,20 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
|
||||||
SmallVector<Operation *> toErase;
|
SmallVector<Operation *> toErase;
|
||||||
bool hadError = false;
|
bool hadError = false;
|
||||||
func.walk([&](ReturnOp op) {
|
func.walk([&](ReturnOp op) {
|
||||||
if (op.getNumOperands() != 1 ||
|
auto returnType =
|
||||||
!isReturnMemRefTypeValid(op.getOperandTypes()[0])) {
|
op.getOperandTypes()[0].dyn_cast<MemRefType>().getElementType();
|
||||||
|
auto it = consumeFuncReturnFuncs.find(returnType);
|
||||||
|
if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
|
||||||
hadError = true;
|
hadError = true;
|
||||||
op.emitError("must have one return value and it must be a memref of f32");
|
op.emitError("must have one return value: a memref of f32 or i64");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
b.setInsertionPoint(op);
|
b.setInsertionPoint(op);
|
||||||
auto cast =
|
auto cast =
|
||||||
b.create<memref::CastOp>(op.getLoc(), op.getOperand(0),
|
b.create<memref::CastOp>(op.getLoc(), op.getOperand(0),
|
||||||
getAbiTypeForMemRef(op.getOperandTypes()[0]));
|
getAbiTypeForMemRef(op.getOperandTypes()[0]));
|
||||||
b.create<mlir::CallOp>(op.getLoc(), consumeFuncReturnFunc,
|
b.create<mlir::CallOp>(op.getLoc(), consumeFuncReturnFuncs[returnType],
|
||||||
cast.getResult());
|
cast.getResult());
|
||||||
b.create<mlir::ReturnOp>(op.getLoc());
|
b.create<mlir::ReturnOp>(op.getLoc());
|
||||||
toErase.push_back(op);
|
toErase.push_back(op);
|
||||||
|
@ -127,22 +123,31 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
|
||||||
namespace {
|
namespace {
|
||||||
class MungeCallingConventions
|
class MungeCallingConventions
|
||||||
: public MungeCallingConventionsBase<MungeCallingConventions> {
|
: public MungeCallingConventionsBase<MungeCallingConventions> {
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
auto module = getOperation();
|
auto module = getOperation();
|
||||||
OpBuilder b(module.getBodyRegion());
|
OpBuilder b(module.getBodyRegion());
|
||||||
|
auto consumeFuncReturnInt64Func = b.create<FuncOp>(
|
||||||
auto consumeFuncReturnFunc = b.create<FuncOp>(
|
module.getLoc(), "refbackend_consume_int64_func_return",
|
||||||
module.getLoc(), "refbackend_consume_func_return",
|
FunctionType::get(
|
||||||
|
module.getContext(),
|
||||||
|
UnrankedMemRefType::get(b.getI64Type(), /*memorySpace=*/0), {}),
|
||||||
|
b.getStringAttr("private"));
|
||||||
|
auto consumeFuncReturnFloat32Func = b.create<FuncOp>(
|
||||||
|
module.getLoc(), "refbackend_consume_float32_func_return",
|
||||||
FunctionType::get(
|
FunctionType::get(
|
||||||
module.getContext(),
|
module.getContext(),
|
||||||
UnrankedMemRefType::get(b.getF32Type(), /*memorySpace=*/0), {}),
|
UnrankedMemRefType::get(b.getF32Type(), /*memorySpace=*/0), {}),
|
||||||
b.getStringAttr("private"));
|
b.getStringAttr("private"));
|
||||||
addEmitCInterfaceAttr(consumeFuncReturnFunc);
|
addEmitCInterfaceAttr(consumeFuncReturnInt64Func);
|
||||||
|
addEmitCInterfaceAttr(consumeFuncReturnFloat32Func);
|
||||||
|
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
|
||||||
|
consumeFuncReturnFuncs[b.getF32Type()] = consumeFuncReturnFloat32Func;
|
||||||
|
consumeFuncReturnFuncs[b.getI64Type()] = consumeFuncReturnInt64Func;
|
||||||
for (auto func : module.getOps<FuncOp>()) {
|
for (auto func : module.getOps<FuncOp>()) {
|
||||||
if (func == consumeFuncReturnFunc)
|
if (func == consumeFuncReturnInt64Func ||
|
||||||
|
func == consumeFuncReturnFloat32Func)
|
||||||
continue;
|
continue;
|
||||||
if (failed(mungeFunction(func, consumeFuncReturnFunc)))
|
if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -160,7 +165,6 @@ mlir::torch::RefBackend::createMungeCallingConventionsPass() {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
|
class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
auto func = getOperation();
|
auto func = getOperation();
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
|
|
|
@ -36,11 +36,18 @@ class RefBackendInvoker:
|
||||||
self.result = None
|
self.result = None
|
||||||
|
|
||||||
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
|
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
|
||||||
def consume_return(a):
|
def consume_i64_return(a):
|
||||||
|
self.result = unranked_memref_to_numpy(a, np.int64)
|
||||||
|
|
||||||
|
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
|
||||||
|
def consume_f32_return(a):
|
||||||
self.result = unranked_memref_to_numpy(a, np.float32)
|
self.result = unranked_memref_to_numpy(a, np.float32)
|
||||||
|
|
||||||
self.ee.register_runtime("refbackend_consume_func_return",
|
self.ee.register_runtime("refbackend_consume_int64_func_return",
|
||||||
consume_return)
|
consume_i64_return)
|
||||||
|
|
||||||
|
self.ee.register_runtime("refbackend_consume_float32_func_return",
|
||||||
|
consume_f32_return)
|
||||||
|
|
||||||
def __getattr__(self, function_name: str):
|
def __getattr__(self, function_name: str):
|
||||||
def invoke(*args):
|
def invoke(*args):
|
||||||
|
|
|
@ -1,11 +1,23 @@
|
||||||
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions | FileCheck %s
|
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @f(
|
// CHECK-LABEL: func @f(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
|
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
|
||||||
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
|
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
|
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
|
||||||
// CHECK: call @refbackend_consume_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
|
// CHECK: call @refbackend_consume_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
|
func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
return %arg0 : memref<?xf32>
|
return %arg0 : memref<?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @i(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
|
||||||
|
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
|
||||||
|
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
|
||||||
|
// CHECK: call @refbackend_consume_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> ()
|
||||||
|
// CHECK: return
|
||||||
|
func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
|
||||||
|
return %arg0 : memref<?xi64>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue