From 2e1498ad118d7c38ffb28f94e3f1f2c13f619ce1 Mon Sep 17 00:00:00 2001 From: dan Date: Tue, 5 Oct 2021 02:06:59 +0000 Subject: [PATCH] add i64 support to refbackend --- lib/RefBackend/RefBackend.cpp | 48 ++++++++++--------- .../linalg_on_tensors_backends/refbackend.py | 13 +++-- .../RefBackend/munge-calling-conventions.mlir | 16 ++++++- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 406acb078..1a1fb5384 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -53,15 +53,6 @@ static bool isArgMemRefTypeValid(Type type) { return false; } -static bool isReturnMemRefTypeValid(Type type) { - if (auto memRefType = type.dyn_cast()) { - if (memRefType.getElementType().isa()) { - return true; - } - } - return false; -} - static void addEmitCInterfaceAttr(FuncOp func) { func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext())); } @@ -70,7 +61,9 @@ static Type getAbiTypeForMemRef(Type type) { return UnrankedMemRefType::get(type.cast().getElementType(), 0); } -static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) { +static LogicalResult mungeFunction( + FuncOp func, + DenseMap consumeFuncReturnFuncs) { // Add `llvm.emit_c_interface`. // This allows ExecutionEngine to resolve the symbol properly. addEmitCInterfaceAttr(func); @@ -98,17 +91,20 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) { SmallVector toErase; bool hadError = false; func.walk([&](ReturnOp op) { - if (op.getNumOperands() != 1 || - !isReturnMemRefTypeValid(op.getOperandTypes()[0])) { + auto returnType = + op.getOperandTypes()[0].dyn_cast().getElementType(); + auto it = consumeFuncReturnFuncs.find(returnType); + if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) { hadError = true; - op.emitError("must have one return value and it must be a memref of f32"); + op.emitError("must have one return value: a memref of f32 or i64"); return; } + b.setInsertionPoint(op); auto cast = b.create(op.getLoc(), op.getOperand(0), getAbiTypeForMemRef(op.getOperandTypes()[0])); - b.create(op.getLoc(), consumeFuncReturnFunc, + b.create(op.getLoc(), consumeFuncReturnFuncs[returnType], cast.getResult()); b.create(op.getLoc()); toErase.push_back(op); @@ -127,22 +123,31 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) { namespace { class MungeCallingConventions : public MungeCallingConventionsBase { - void runOnOperation() override { auto module = getOperation(); OpBuilder b(module.getBodyRegion()); - - auto consumeFuncReturnFunc = b.create( - module.getLoc(), "refbackend_consume_func_return", + auto consumeFuncReturnInt64Func = b.create( + module.getLoc(), "refbackend_consume_int64_func_return", + FunctionType::get( + module.getContext(), + UnrankedMemRefType::get(b.getI64Type(), /*memorySpace=*/0), {}), + b.getStringAttr("private")); + auto consumeFuncReturnFloat32Func = b.create( + module.getLoc(), "refbackend_consume_float32_func_return", FunctionType::get( module.getContext(), UnrankedMemRefType::get(b.getF32Type(), /*memorySpace=*/0), {}), b.getStringAttr("private")); - addEmitCInterfaceAttr(consumeFuncReturnFunc); + addEmitCInterfaceAttr(consumeFuncReturnInt64Func); + addEmitCInterfaceAttr(consumeFuncReturnFloat32Func); + DenseMap consumeFuncReturnFuncs; + consumeFuncReturnFuncs[b.getF32Type()] = consumeFuncReturnFloat32Func; + consumeFuncReturnFuncs[b.getI64Type()] = consumeFuncReturnInt64Func; for (auto func : module.getOps()) { - if (func == consumeFuncReturnFunc) + if (func == consumeFuncReturnInt64Func || + func == consumeFuncReturnFloat32Func) continue; - if (failed(mungeFunction(func, consumeFuncReturnFunc))) + if (failed(mungeFunction(func, consumeFuncReturnFuncs))) return signalPassFailure(); } } @@ -160,7 +165,6 @@ mlir::torch::RefBackend::createMungeCallingConventionsPass() { namespace { class ExpandOpsForLLVM : public ExpandOpsForLLVMBase { - void runOnOperation() override { auto func = getOperation(); auto *context = &getContext(); diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 3e3f7030c..a1e31a971 100644 --- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -36,11 +36,18 @@ class RefBackendInvoker: self.result = None @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) - def consume_return(a): + def consume_i64_return(a): + self.result = unranked_memref_to_numpy(a, np.int64) + + @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) + def consume_f32_return(a): self.result = unranked_memref_to_numpy(a, np.float32) - self.ee.register_runtime("refbackend_consume_func_return", - consume_return) + self.ee.register_runtime("refbackend_consume_int64_func_return", + consume_i64_return) + + self.ee.register_runtime("refbackend_consume_float32_func_return", + consume_f32_return) def __getattr__(self, function_name: str): def invoke(*args): diff --git a/test/RefBackend/munge-calling-conventions.mlir b/test/RefBackend/munge-calling-conventions.mlir index 07a569650..ccb0e434e 100644 --- a/test/RefBackend/munge-calling-conventions.mlir +++ b/test/RefBackend/munge-calling-conventions.mlir @@ -1,11 +1,23 @@ -// RUN: torch-mlir-opt %s -refback-munge-calling-conventions | FileCheck %s +// RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file | FileCheck %s // CHECK-LABEL: func @f( // CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} { // CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref // CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref to memref<*xf32> -// CHECK: call @refbackend_consume_func_return(%[[RESULT]]) : (memref<*xf32>) -> () +// CHECK: call @refbackend_consume_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> () // CHECK: return func @f(%arg0: memref) -> memref { return %arg0 : memref } + +// ----- + +// CHECK-LABEL: func @i( +// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} { +// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref +// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref to memref<*xi64> +// CHECK: call @refbackend_consume_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> () +// CHECK: return +func @i(%arg0: memref) -> memref { + return %arg0 : memref +}