mirror of https://github.com/llvm/torch-mlir
Add i64 tensor argument support and bring back GatherModule_basic
parent 12d0fe7c85
commit c9cc4cb2e9
@@ -63,7 +63,6 @@ def BmmModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
 
 
-
 # ==============================================================================
 
 
@@ -209,6 +208,7 @@ class MaxPool2dModule(torch.nn.Module):
 def MaxPool2dModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 20, 20) - 0.5)
 
+
 class TransposeIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -246,6 +246,7 @@ class TensorsConcatModule(torch.nn.Module):
 def TensorsConcatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 2, 4), tu.rand(2, 1, 4), tu.rand(2, 3, 4))
 
+
 class GatherModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -260,6 +261,6 @@ class GatherModule(torch.nn.Module):
         return torch.gather(tensor, 2, indices)
 
 
-#@register_test_case(module_factory=lambda: GatherModule())
-#def GatherModule_basic(module, tu: TestUtils):
-# module.forward(tu.rand(2, 3, 4), torch.tensor([[[1,2,3],[1,2,3]]]))
+@register_test_case(module_factory=lambda: GatherModule())
+def GatherModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]]))
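Why this test needs i64 argument support (my note, not part of the commit): torch.tensor built from Python integer literals defaults to torch.int64, so the indices argument of the re-enabled GatherModule_basic reaches the reference backend as an i64 tensor. A small standalone sketch:

import torch

indices = torch.tensor([[[1, 2, 3], [1, 2, 3]]])
print(indices.dtype)                            # torch.int64
source = torch.rand(2, 3, 4)
print(torch.gather(source, 2, indices).shape)   # torch.Size([1, 2, 3])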
@@ -39,7 +39,20 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
 // MungeCallingConventions
 //===----------------------------------------------------------------------===//
 
-static bool isF32MemRef(Type type) {
+static bool isArgMemRefTypeValid(Type type) {
+  if (auto memRefType = type.dyn_cast<MemRefType>()) {
+    Type elemTy = memRefType.getElementType();
+    if (elemTy.isa<Float32Type>()) {
+      return true;
+    } else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) {
+      if (integerTy.isSignlessInteger(64))
+        return true;
+    }
+  }
+  return false;
+}
+
+static bool isReturnMemRefTypeValid(Type type) {
   if (auto memRefType = type.dyn_cast<MemRefType>()) {
     if (memRefType.getElementType().isa<Float32Type>()) {
       return true;
@@ -73,8 +86,8 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
   SmallVector<Type> newArgTypes;
   for (auto arg : func.getArguments()) {
     auto type = arg.getType();
-    if (!isF32MemRef(type))
-      return emitError(arg.getLoc(), "argument must be a memref of f32");
+    if (!isArgMemRefTypeValid(type))
+      return emitError(arg.getLoc(), "argument must be a memref of f32 or i64");
     auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
     arg.replaceAllUsesExcept(cast, cast);
     arg.setType(getAbiTypeForMemRef(type));
@@ -84,7 +97,8 @@ static LogicalResult mungeFunction(FuncOp func, FuncOp consumeFuncReturnFunc) {
   SmallVector<Operation *> toErase;
   bool hadError = false;
   func.walk([&](ReturnOp op) {
-    if (op.getNumOperands() != 1 || !isF32MemRef(op.getOperandTypes()[0])) {
+    if (op.getNumOperands() != 1 ||
+        !isReturnMemRefTypeValid(op.getOperandTypes()[0])) {
       hadError = true;
       op.emitError("must have one return value and it must be a memref of f32");
       return;
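As a reading aid (my own sketch, not from the patch), the rule the two helpers above encode is asymmetric: argument memrefs may now have f32 or signless i64 elements, while return memrefs are still f32-only. The hypothetical Python functions below restate that rule, with numpy dtypes standing in for MLIR element types:

import numpy as np

def arg_elem_type_ok(dtype):
    # Mirrors isArgMemRefTypeValid: f32 or i64 elements are accepted.
    return dtype in (np.float32, np.int64)

def return_elem_type_ok(dtype):
    # Mirrors isReturnMemRefTypeValid: returns remain f32-only.
    return dtype == np.float32

assert arg_elem_type_ok(np.int64) and not return_elem_type_ok(np.int64)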
@@ -20,6 +20,15 @@ __all__ = [
 ]
 
 
+def checkArgTypeIsSupported(ty):
+    if ty == np.float32:
+        return
+    elif ty == np.int64:
+        return
+    assert False, "Only tensor argument of float32 and int64 are supported but got " + str(
+        ty)
+
+
 class RefBackendInvoker:
     def __init__(self, module):
         self.ee = ExecutionEngine(module)
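Usage note (my sketch, not from the patch): checkArgTypeIsSupported takes a numpy dtype, returns silently for float32 and int64, and asserts for anything else. With the definition above in scope:

import numpy as np

checkArgTypeIsSupported(np.float32)                        # ok
checkArgTypeIsSupported(np.ones(3, dtype=np.int64).dtype)  # ok
try:
    checkArgTypeIsSupported(np.int32)                      # unsupported dtype
except AssertionError as e:
    print(e)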
@@ -28,15 +37,19 @@ class RefBackendInvoker:
         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
         def consume_return(a):
             self.result = unranked_memref_to_numpy(a, np.float32)
-        self.ee.register_runtime("refbackend_consume_func_return", consume_return)
+
+        self.ee.register_runtime("refbackend_consume_func_return",
+                                 consume_return)
 
     def __getattr__(self, function_name: str):
         def invoke(*args):
-            ffi_args = [
-                ctypes.pointer(
-                    ctypes.pointer(
-                        get_unranked_memref_descriptor(arg)))
-                for arg in args]
+            ffi_args = []
+            for arg in args:
+                checkArgTypeIsSupported(arg.dtype)
+                ffi_args.append(
+                    ctypes.pointer(
+                        ctypes.pointer(get_unranked_memref_descriptor(arg))))
 
             self.ee.invoke(function_name, *ffi_args)
             result = self.result
             assert result is not None, "Invocation didn't produce a result"
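For context (a hedged sketch, not from the patch): the new loop packs each numpy argument the way upstream MLIR examples do. The inner ctypes.pointer passes the unranked memref descriptor by pointer, and ExecutionEngine.invoke expects a pointer to each packed argument, hence the double wrapping. Assuming the runtime helpers the file imports come from mlir.runtime (the module path is my assumption), a standalone version looks like:

import ctypes
import numpy as np
from mlir.runtime import get_unranked_memref_descriptor

args = [np.ones((2, 3, 4), dtype=np.float32),
        np.array([[[1, 2, 3], [1, 2, 3]]], dtype=np.int64)]

ffi_args = []
for arg in args:
    assert arg.dtype in (np.float32, np.int64)  # stand-in for checkArgTypeIsSupported
    ffi_args.append(
        ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(arg))))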
@@ -76,7 +89,6 @@ LOWERING_PIPELINE = ",".join([
 
 class RefBackendNpcompBackend(NpcompBackend):
     """Main entry-point for the backend."""
-
     def __init__(self):
         super().__init__()
 