Reject dictionary inputs when tracing.

The underlying error message was misleading.  See https://github.com/llvm/torch-mlir/issues/1425
pull/1441/head snapshot-20221001.613
Daniel Ellis 2022-09-28 18:42:34 +00:00
parent b3345e69e2
commit 4d47f1671a
2 changed files with 32 additions and 3 deletions

View File

@ -9,9 +9,8 @@ import torch
import torch_mlir import torch_mlir
class TanhModule(torch.nn.Module): class TanhModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x): def forward(self, x):
return torch.ops.aten.tanh(x) return torch.ops.aten.tanh(x)
@ -52,3 +51,22 @@ try:
torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True) torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True)
except Exception as e: except Exception as e:
print(e) print(e)
class DictModule(torch.nn.Module):
def forward(self, x):
return x['a'] * 2.0
try:
# CHECK: Only Tensors, TensorPlaceholders, or a sequences of Tensors and TensorPlaceholders are supported as inputs.
torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
except Exception as e:
print(e)
try:
# CHECK: Only Tensors, TensorPlaceholders, or a sequences of Tensors and TensorPlaceholders are supported as inputs.
torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
except Exception as e:
print(e)

View File

@ -176,6 +176,13 @@ def compile(model: torch.nn.Module,
if isinstance(example_args, (torch.Tensor, TensorPlaceholder)): if isinstance(example_args, (torch.Tensor, TensorPlaceholder)):
example_args = (example_args,) example_args = (example_args,)
# If users passed in anything other than tensors or a list of tensors (e.g.
# a dictionary), we can't handle it.
if not isinstance(example_args, Sequence):
raise Exception(
"Only Tensors, TensorPlaceholders, or a sequences of Tensors and "
"TensorPlaceholders are supported as inputs.")
# TODO: Don't hardcode "forward". See `torch.onnx.export` and # TODO: Don't hardcode "forward". See `torch.onnx.export` and
# `torch.jit.trace_module` for API inspiration. # `torch.jit.trace_module` for API inspiration.
if use_tracing: if use_tracing:
@ -197,8 +204,12 @@ def compile(model: torch.nn.Module,
shape = [s if s != -1 else 7 for s in arg.shape] shape = [s if s != -1 else 7 for s in arg.shape]
example_args_for_trace.append( example_args_for_trace.append(
torch.ones(*shape, dtype=arg.dtype)) torch.ones(*shape, dtype=arg.dtype))
else: elif isinstance(arg, torch.Tensor):
example_args_for_trace.append(arg) example_args_for_trace.append(arg)
else:
raise Exception(
"Only Tensors, TensorPlaceholders, or a sequences of "
"Tensors and TensorPlaceholders are supported as inputs.")
scripted = torch.jit.trace(model, tuple(example_args_for_trace)) scripted = torch.jit.trace(model, tuple(example_args_for_trace))
else: else:
scripted = torch.jit.script(model) scripted = torch.jit.script(model)