Add i64 tensor argument support and bring back GatherModule_basic

pull/328/head
Yi Zhang 2021-09-23 15:22:28 -04:00
parent 12d0fe7c85
commit c9cc4cb2e9
3 changed files with 41 additions and 14 deletions

View File

@ -63,7 +63,6 @@ def BmmModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
# ==============================================================================
@ -209,6 +208,7 @@ class MaxPool2dModule(torch.nn.Module):
def MaxPool2dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, 20) - 0.5)
class TransposeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -246,6 +246,7 @@ class TensorsConcatModule(torch.nn.Module):
def TensorsConcatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2, 4), tu.rand(2, 1, 4), tu.rand(2, 3, 4))
class GatherModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -260,6 +261,6 @@ class GatherModule(torch.nn.Module):
return torch.gather(tensor, 2, indices)
#@register_test_case(module_factory=lambda: GatherModule())
#def GatherModule_basic(module, tu: TestUtils):
# module.forward(tu.rand(2, 3, 4), torch.tensor([[[1,2,3],[1,2,3]]]))
@register_test_case(module_factory=lambda: GatherModule())
def GatherModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]]))

View File

@ -39,7 +39,20 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
// MungeCallingConventions
//===----------------------------------------------------------------------===//
static bool isF32MemRef(Type type) {
static bool isArgMemRefTypeValid(Type type) {
if (auto memRefType = type.dyn_cast<MemRefType>()) {
Type elemTy = memRefType.getElementType();
if (elemTy.isa<Float32Type>()) {
return true;
} else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) {
if (integerTy.isSignlessInteger(64))
return true;
}
}
return false;
}
static bool isReturnMemRefTypeValid(Type type) {
if (auto memRefType = type.dyn_cast<MemRefType>()) {
if (memRefType.getElementType().isa<Float32Type>()) {
return true;
@ -73,8 +86,8 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
SmallVector<Type> newArgTypes;
for (auto arg : func.getArguments()) {
auto type = arg.getType();
if (!isF32MemRef(type))
return emitError(arg.getLoc(), "argument must be a memref of f32");
if (!isArgMemRefTypeValid(type))
return emitError(arg.getLoc(), "argument must be a memref of f32 or i64");
auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
arg.replaceAllUsesExcept(cast, cast);
arg.setType(getAbiTypeForMemRef(type));
@ -84,7 +97,8 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
SmallVector<Operation *> toErase;
bool hadError = false;
func.walk([&](ReturnOp op) {
if (op.getNumOperands() != 1 || !isF32MemRef(op.getOperandTypes()[0])) {
if (op.getNumOperands() != 1 ||
!isReturnMemRefTypeValid(op.getOperandTypes()[0])) {
hadError = true;
op.emitError("must have one return value and it must be a memref of f32");
return;

View File

@ -20,6 +20,15 @@ __all__ = [
]
def checkArgTypeIsSupported(ty):
if ty == np.float32:
return
elif ty == np.int64:
return
assert False, "Only tensor argument of float32 and int64 are supported but got " + str(
ty)
class RefBackendInvoker:
def __init__(self, module):
self.ee = ExecutionEngine(module)
@ -28,15 +37,19 @@ class RefBackendInvoker:
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_return(a):
self.result = unranked_memref_to_numpy(a, np.float32)
self.ee.register_runtime("refbackend_consume_func_return", consume_return)
self.ee.register_runtime("refbackend_consume_func_return",
consume_return)
def __getattr__(self, function_name: str):
def invoke(*args):
ffi_args = [
ctypes.pointer(
ffi_args = []
for arg in args:
checkArgTypeIsSupported(arg.dtype)
ffi_args.append(
ctypes.pointer(
get_unranked_memref_descriptor(arg)))
for arg in args]
ctypes.pointer(get_unranked_memref_descriptor(arg))))
self.ee.invoke(function_name, *ffi_args)
result = self.result
assert result is not None, "Invocation didn't produce a result"
@ -76,7 +89,6 @@ LOWERING_PIPELINE = ",".join([
class RefBackendNpcompBackend(NpcompBackend):
"""Main entry-point for the backend."""
def __init__(self):
super().__init__()