Adding support for returning elemental types.

Support for returning elemental types. Previously, only
memref types as returning types was supported. All the hacky ways
to write tests which return elemental types should be taken care of.

Signed-off-by: Prashant Kumar <prashant@nod-labs.com>
pull/392/head
Prashant Kumar 2021-11-06 19:25:06 +00:00
parent b33543af85
commit fd505db2c6
4 changed files with 109 additions and 41 deletions

View File

@ -518,16 +518,13 @@ class TensorToInt(torch.nn.Module):
@annotate_args([
None,
([], torch.int64, True),
([], torch.float32, True),
])
def forward(self, x, y):
# This is a workaround for not returning scalar value.
a = int(x)
return y.add(y, alpha=a)
def forward(self, x):
return int(x)
@register_test_case(module_factory=lambda: TensorToInt())
def TensorToInt_basic(module, tu: TestUtils):
module.forward(torch.randint(10,[]), tu.rand())
module.forward(torch.randint(10,[]))
class LogSoftmaxIntModule(torch.nn.Module):
def __init__(self):

View File

@ -64,6 +64,30 @@ static Type getAbiTypeForMemRef(Type type) {
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
}
// Passes the return op operands `val` to `funOp`. Also, adds the op to the
// `toErase` vector.
static void replaceCallToFunction(OpBuilder b, ReturnOp op, FuncOp funcOp,
Value val,
SmallVectorImpl<Operation *> &toErase) {
b.create<mlir::CallOp>(op.getLoc(), funcOp, val);
b.create<mlir::ReturnOp>(op.getLoc());
toErase.push_back(op);
}
// Checks whether the return op is munge-compatible and the respective calling
// function is defined.
static bool isReturnOpCompatible(ReturnOp op,
DenseMap<Type, FuncOp> &consumeFuncReturnFuncs,
Type returnType) {
auto it = consumeFuncReturnFuncs.find(returnType);
if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
op.emitError("must have one return value of Memref type or Elemental types "
"of i64, f64, f32");
return false;
}
return true;
}
static LogicalResult mungeFunction(
FuncOp func,
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs) {
@ -93,39 +117,41 @@ static LogicalResult mungeFunction(
}
SmallVector<Operation *> toErase;
bool hadError = false;
bool isCompatible = false;
func.walk([&](ReturnOp op) {
auto memRefType = op.getOperandTypes()[0].dyn_cast<MemRefType>();
if (!memRefType) {
hadError = true;
op.emitError("return value must be memref type");
return;
}
auto returnType = memRefType.getElementType();
auto it = consumeFuncReturnFuncs.find(returnType);
if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
hadError = true;
op.emitError("must have one return value: a memref of f32, i64 or f64");
return;
}
auto returnType = op.getOperandTypes()[0];
b.setInsertionPoint(op);
auto cast =
b.create<memref::CastOp>(op.getLoc(), op.getOperand(0),
// Memref Types.
if (auto memrefReturnType = returnType.dyn_cast<MemRefType>()) {
auto elemType = memrefReturnType.getElementType();
auto unRankedType = UnrankedMemRefType::get(elemType, 0);
isCompatible =
isReturnOpCompatible(op, consumeFuncReturnFuncs, unRankedType);
if (!isCompatible)
return;
// Cast to unranked memref type before sending it as a function argument.
auto cast = b.create<memref::CastOp>(
op.getLoc(), op.getOperand(0),
getAbiTypeForMemRef(op.getOperandTypes()[0]));
b.create<mlir::CallOp>(op.getLoc(), consumeFuncReturnFuncs[returnType],
cast.getResult());
b.create<mlir::ReturnOp>(op.getLoc());
toErase.push_back(op);
replaceCallToFunction(b, op, consumeFuncReturnFuncs[unRankedType],
cast.getResult(), toErase);
// Elemental types.
} else if (returnType.isa<IntegerType>() || returnType.isa<FloatType>()) {
isCompatible =
isReturnOpCompatible(op, consumeFuncReturnFuncs, returnType);
if (!isCompatible)
return;
replaceCallToFunction(b, op, consumeFuncReturnFuncs[returnType],
op->getOperand(0), toErase);
}
});
if (hadError)
if (!isCompatible)
return failure();
func.setType(FunctionType::get(func.getContext(), newArgTypes, {}));
for (Operation *op : toErase)
op->erase();
return success();
}
@ -137,17 +163,28 @@ class MungeCallingConventions
OpBuilder b(module.getBodyRegion());
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
DenseSet<FuncOp> consumeFuncReturnFuncsSet;
auto createConsumeFuncReturnFunc = [&](Type elemTy, std::string funcName) {
auto createConsumeFuncReturnFunc = [&](Type returnType,
std::string funcName) {
auto consumeFuncReturnFunc = b.create<FuncOp>(
module.getLoc(), funcName,
FunctionType::get(module.getContext(),
UnrankedMemRefType::get(elemTy, /*memorySpace=*/0),
{}),
FunctionType::get(module.getContext(), returnType, {}),
b.getStringAttr("private"));
addEmitCInterfaceAttr(consumeFuncReturnFunc);
consumeFuncReturnFuncs[elemTy] = consumeFuncReturnFunc;
consumeFuncReturnFuncs[returnType] = consumeFuncReturnFunc;
consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc);
};
// Memref return types.
createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0),
"refbackend_consume_memref_int64_func_return");
createConsumeFuncReturnFunc(
UnrankedMemRefType::get(b.getF32Type(), 0),
"refbackend_consume_memref_float32_func_return");
createConsumeFuncReturnFunc(
UnrankedMemRefType::get(b.getF64Type(), 0),
"refbackend_consume_memref_float64_func_return");
// Elemental return types.
createConsumeFuncReturnFunc(b.getI64Type(),
"refbackend_consume_int64_func_return");
createConsumeFuncReturnFunc(b.getF32Type(),

View File

@ -33,17 +33,38 @@ class RefBackendInvoker:
self.result = None
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_i64_return(a):
def consume_memref_i64_return(a):
self.result = unranked_memref_to_numpy(a, np.int64)
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_f32_return(a):
def consume_memref_f32_return(a):
self.result = unranked_memref_to_numpy(a, np.float32)
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_f64_return(a):
def consume_memref_f64_return(a):
self.result = unranked_memref_to_numpy(a, np.float64)
@ctypes.CFUNCTYPE(None, ctypes.c_int)
def consume_i64_return(a):
self.result = a
@ctypes.CFUNCTYPE(None, ctypes.c_float)
def consume_f32_return(a):
self.result = a
@ctypes.CFUNCTYPE(None, ctypes.c_double)
def consume_f64_return(a):
self.result = a
self.ee.register_runtime("refbackend_consume_memref_int64_func_return",
consume_memref_i64_return)
self.ee.register_runtime("refbackend_consume_memref_float32_func_return",
consume_memref_f32_return)
self.ee.register_runtime("refbackend_consume_memref_float64_func_return",
consume_memref_f64_return)
self.ee.register_runtime("refbackend_consume_int64_func_return",
consume_i64_return)

View File

@ -4,7 +4,7 @@
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
// CHECK: call @refbackend_consume_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: call @refbackend_consume_memref_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: return
func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
return %arg0 : memref<?xf32>
@ -16,8 +16,21 @@ func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
// CHECK: call @refbackend_consume_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> ()
// CHECK: call @refbackend_consume_memref_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> ()
// CHECK: return
func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
return %arg0 : memref<?xi64>
}
// -----
// CHECK-LABEL: func @elemental_type(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<i64>
// CHECK: %[[RESULT:.*]] = memref.load %[[VAL]][] : memref<i64>
// CHECK: call @refbackend_consume_int64_func_return(%[[RESULT]]) : (i64) -> ()
// CHECK: return
func @elemental_type(%arg0: memref<i64>) -> i64 {
%0 = memref.load %arg0[] : memref<i64>
return %0 : i64
}