mirror of https://github.com/llvm/torch-mlir
Reject dictionary inputs when tracing.

The underlying error message was misleading. See https://github.com/llvm/torch-mlir/issues/1425

pull/1441/head snapshot-20221001.613
parent b3345e69e2
commit 4d47f1671a
@@ -9,9 +9,8 @@ import torch
 import torch_mlir
 
 
 class TanhModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+
     def forward(self, x):
         return torch.ops.aten.tanh(x)
 
@@ -52,3 +51,22 @@ try:
     torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True)
 except Exception as e:
     print(e)
+
+
+class DictModule(torch.nn.Module):
+    def forward(self, x):
+        return x['a'] * 2.0
+
+
+try:
+    # CHECK: Only Tensors, TensorPlaceholders, or a sequences of Tensors and TensorPlaceholders are supported as inputs.
+    torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
+except Exception as e:
+    print(e)
+
+
+try:
+    # CHECK: Only Tensors, TensorPlaceholders, or a sequences of Tensors and TensorPlaceholders are supported as inputs.
+    torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
+except Exception as e:
+    print(e)
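Aside on the test style, for readers unfamiliar with it: these compile-API tests are lit tests whose stdout is piped into LLVM's FileCheck, so each "# CHECK:" comment is a pattern that must appear in the printed output, and the except/print blocks exist only to surface the exception text for matching. A rough, hypothetical standalone analogue of the same idiom (not part of this change; the RUN line and file name are assumptions) looks like:

# RUN: %PYTHON %s | FileCheck %s
# Hypothetical sketch of the CHECK idiom used above; roughly equivalent to
# running: python this_file.py | FileCheck this_file.py
try:
    raise Exception("Only Tensors, TensorPlaceholders, or a sequences of "
                    "Tensors and TensorPlaceholders are supported as inputs.")
except Exception as e:
    # CHECK: supported as inputs.
    print(e)
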
@@ -176,6 +176,13 @@ def compile(model: torch.nn.Module,
     if isinstance(example_args, (torch.Tensor, TensorPlaceholder)):
         example_args = (example_args,)
 
+    # If users passed in anything other than tensors or a list of tensors (e.g.
+    # a dictionary), we can't handle it.
+    if not isinstance(example_args, Sequence):
+        raise Exception(
+            "Only Tensors, TensorPlaceholders, or a sequences of Tensors and "
+            "TensorPlaceholders are supported as inputs.")
+
     # TODO: Don't hardcode "forward". See `torch.onnx.export` and
     # `torch.jit.trace_module` for API inspiration.
     if use_tracing:
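The new guard leans on standard Sequence semantics: dicts (and other mappings) are not Sequences, while lists and tuples are, and a bare Tensor or TensorPlaceholder has already been wrapped into a one-element tuple just above. A standalone sketch of that behavior, assuming the Sequence used in this module checks like collections.abc.Sequence under isinstance:

from collections.abc import Sequence

import torch

print(isinstance((torch.ones(2, 3),), Sequence))      # True: tuple of tensors
print(isinstance([torch.ones(2, 3)], Sequence))       # True: list of tensors
print(isinstance({'a': torch.ones(2, 3)}, Sequence))  # False: dicts are rejected here
print(isinstance("hello", Sequence))                  # True: a str passes this check
                                                      # and is only caught later,
                                                      # per element
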
@@ -197,8 +204,12 @@ def compile(model: torch.nn.Module,
                 shape = [s if s != -1 else 7 for s in arg.shape]
                 example_args_for_trace.append(
                     torch.ones(*shape, dtype=arg.dtype))
-            else:
+            elif isinstance(arg, torch.Tensor):
                 example_args_for_trace.append(arg)
+            else:
+                raise Exception(
+                    "Only Tensors, TensorPlaceholders, or a sequences of "
+                    "Tensors and TensorPlaceholders are supported as inputs.")
         scripted = torch.jit.trace(model, tuple(example_args_for_trace))
     else:
         scripted = torch.jit.script(model)
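For context, the tracing path above converts each example argument into something torch.jit.trace can consume: TensorPlaceholders with dynamic dims (-1) are materialized as dummy tensors (dynamic sizes are arbitrarily set to 7), real Tensors pass through unchanged, and anything else now raises the same error as the Sequence check. A minimal standalone sketch of that dispatch, using a hypothetical FakePlaceholder in place of torch_mlir.TensorPlaceholder:

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class FakePlaceholder:  # hypothetical stand-in, not the real torch_mlir API
    shape: List[int]
    dtype: torch.dtype


def args_for_trace(example_args):
    # Mirrors the per-argument dispatch in the hunk above.
    trace_args = []
    for arg in example_args:
        if isinstance(arg, FakePlaceholder):
            # Dynamic dims (-1) need some concrete extent to trace with;
            # 7 is the same arbitrary choice the code above makes.
            shape = [s if s != -1 else 7 for s in arg.shape]
            trace_args.append(torch.ones(*shape, dtype=arg.dtype))
        elif isinstance(arg, torch.Tensor):
            trace_args.append(arg)
        else:
            raise Exception("unsupported example input: " + type(arg).__name__)
    return tuple(trace_args)


print([t.shape for t in args_for_trace(
    [FakePlaceholder([-1, 4], torch.float32), torch.ones(2, 2)])])
# prints [torch.Size([7, 4]), torch.Size([2, 2])]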