mirror of https://github.com/llvm/torch-mlir
Reject dictionary inputs when tracing.
The underlying error message was misleading. See https://github.com/llvm/torch-mlir/issues/1425pull/1441/head snapshot-20221001.613
parent
b3345e69e2
commit
4d47f1671a
|
@ -9,9 +9,8 @@ import torch
|
|||
|
||||
import torch_mlir
|
||||
|
||||
|
||||
class TanhModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.tanh(x)
|
||||
|
||||
|
@ -52,3 +51,22 @@ try:
|
|||
torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
class DictModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x['a'] * 2.0
|
||||
|
||||
|
||||
try:
|
||||
# CHECK: Only Tensors, TensorPlaceholders, or a sequences of Tensors and TensorPlaceholders are supported as inputs.
|
||||
torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
try:
|
||||
# CHECK: Only Tensors, TensorPlaceholders, or a sequences of Tensors and TensorPlaceholders are supported as inputs.
|
||||
torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
|
||||
except Exception as e:
|
||||
print(e)
|
|
@ -175,6 +175,13 @@ def compile(model: torch.nn.Module,
|
|||
# tensor to a list of a single tensor to make the API more ergonomic.
|
||||
if isinstance(example_args, (torch.Tensor, TensorPlaceholder)):
|
||||
example_args = (example_args,)
|
||||
|
||||
# If users passed in anything other than tensors or a list of tensors (e.g.
|
||||
# a dictionary), we can't handle it.
|
||||
if not isinstance(example_args, Sequence):
|
||||
raise Exception(
|
||||
"Only Tensors, TensorPlaceholders, or a sequences of Tensors and "
|
||||
"TensorPlaceholders are supported as inputs.")
|
||||
|
||||
# TODO: Don't hardcode "forward". See `torch.onnx.export` and
|
||||
# `torch.jit.trace_module` for API inspiration.
|
||||
|
@ -197,8 +204,12 @@ def compile(model: torch.nn.Module,
|
|||
shape = [s if s != -1 else 7 for s in arg.shape]
|
||||
example_args_for_trace.append(
|
||||
torch.ones(*shape, dtype=arg.dtype))
|
||||
else:
|
||||
elif isinstance(arg, torch.Tensor):
|
||||
example_args_for_trace.append(arg)
|
||||
else:
|
||||
raise Exception(
|
||||
"Only Tensors, TensorPlaceholders, or a sequences of "
|
||||
"Tensors and TensorPlaceholders are supported as inputs.")
|
||||
scripted = torch.jit.trace(model, tuple(example_args_for_trace))
|
||||
else:
|
||||
scripted = torch.jit.script(model)
|
||||
|
|
Loading…
Reference in New Issue