mirror of https://github.com/llvm/torch-mlir
[fx] Fix type determination for multi-return ops and static `None` returns. (#3081)
In practice, this was caught by the way that AOT autograd traces `convolution_backward`. For the unit test, we just repro it with a custom op.pull/3098/head
parent
129a79417a
commit
282e9b0e64
|
@ -354,14 +354,6 @@ def is_builtin_function_or_method(obj: Any) -> bool:
|
||||||
class InputInfo:
|
class InputInfo:
|
||||||
"""Provides additional metadata when resolving inputs."""
|
"""Provides additional metadata when resolving inputs."""
|
||||||
|
|
||||||
__slots__ = [
|
|
||||||
"program",
|
|
||||||
"input_spec",
|
|
||||||
"node",
|
|
||||||
"ir_type",
|
|
||||||
"mutable_producer_node_name",
|
|
||||||
]
|
|
||||||
|
|
||||||
program: torch.export.ExportedProgram
|
program: torch.export.ExportedProgram
|
||||||
input_spec: TypingInputSpec
|
input_spec: TypingInputSpec
|
||||||
node: Node
|
node: Node
|
||||||
|
@ -915,39 +907,50 @@ class ContextCache:
|
||||||
tensor_meta = node.meta.get("tensor_meta")
|
tensor_meta = node.meta.get("tensor_meta")
|
||||||
val = node.meta.get("val")
|
val = node.meta.get("val")
|
||||||
sparsity = node.meta.get("sparsity", None)
|
sparsity = node.meta.get("sparsity", None)
|
||||||
if tensor_meta is not None:
|
return self.value_info_to_type(
|
||||||
assert isinstance(tensor_meta, TensorMetadata)
|
val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
|
||||||
# Quantized tensor meta data is not preserved in our lowering,
|
|
||||||
# so throw error instead of silently doing wrong thing.
|
|
||||||
if tensor_meta.is_quantized:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Quantized tensor meta data is not supported."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self.tensor_metadata_to_type(
|
|
||||||
tensor_meta, sparsity=sparsity, mutable=mutable
|
|
||||||
)
|
|
||||||
elif val is not None:
|
|
||||||
# some nodes with symbolic inputs pass a 'val' attribute rather than
|
|
||||||
# tensor_meta
|
|
||||||
if isinstance(val, TorchFakeTensor):
|
|
||||||
return self.get_vtensor_type(
|
|
||||||
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
|
|
||||||
)
|
|
||||||
|
|
||||||
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
|
|
||||||
if t is not None:
|
|
||||||
return IrType.parse(t, self._c)
|
|
||||||
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"FIXME: Unsupported placeholder node (this often indicates that a necessary) "
|
|
||||||
f"fx preprocessing pass was not run): {node.meta}"
|
|
||||||
)
|
)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
|
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def value_info_to_type(
|
||||||
|
self,
|
||||||
|
val,
|
||||||
|
*,
|
||||||
|
tensor_meta: Optional[TensorMetadata] = None,
|
||||||
|
sparsity=None,
|
||||||
|
mutable: bool = False,
|
||||||
|
):
|
||||||
|
if tensor_meta is not None:
|
||||||
|
assert isinstance(tensor_meta, TensorMetadata)
|
||||||
|
# Quantized tensor meta data is not preserved in our lowering,
|
||||||
|
# so throw error instead of silently doing wrong thing.
|
||||||
|
if tensor_meta.is_quantized:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Quantized tensor meta data is not supported."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.tensor_metadata_to_type(
|
||||||
|
tensor_meta, sparsity=sparsity, mutable=mutable
|
||||||
|
)
|
||||||
|
elif val is not None:
|
||||||
|
# some nodes with symbolic inputs pass a 'val' attribute rather than
|
||||||
|
# tensor_meta
|
||||||
|
if isinstance(val, TorchFakeTensor):
|
||||||
|
return self.get_vtensor_type(
|
||||||
|
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
|
||||||
|
if t is not None:
|
||||||
|
return IrType.parse(t, self._c)
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Could not deduce type from value info: "
|
||||||
|
f"tensor_meta={tensor_meta}, val={val}, sparsity={sparsity}"
|
||||||
|
)
|
||||||
|
|
||||||
def tensor_metadata_to_type(
|
def tensor_metadata_to_type(
|
||||||
self,
|
self,
|
||||||
tm: TensorMetadata,
|
tm: TensorMetadata,
|
||||||
|
@ -1624,10 +1627,9 @@ class GraphNodeImporter:
|
||||||
# short-circuit above. Note that if we ever choose to also fully reify Python
|
# short-circuit above. Note that if we ever choose to also fully reify Python
|
||||||
# level result tuples, we will need to create a tuple-boxed version of this and
|
# level result tuples, we will need to create a tuple-boxed version of this and
|
||||||
# redirect to it for generic object access.
|
# redirect to it for generic object access.
|
||||||
|
|
||||||
result_types = []
|
result_types = []
|
||||||
for v in node.meta["val"]:
|
for v in node.meta["val"]:
|
||||||
result_types.append(self._cc.tensor_metadata_to_type(v))
|
result_types.append(self._cc.value_info_to_type(v))
|
||||||
result_types = tuple(result_types)
|
result_types = tuple(result_types)
|
||||||
return result_types
|
return result_types
|
||||||
|
|
||||||
|
|
|
@ -1,36 +0,0 @@
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
|
||||||
|
|
||||||
# RUN: %PYTHON %s | FileCheck %s
|
|
||||||
# This file contains tests of various op special forms that the fx_importer
|
|
||||||
# handles.
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.export
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from torch_mlir import fx
|
|
||||||
|
|
||||||
|
|
||||||
def run(f):
|
|
||||||
print(f"{f.__name__}")
|
|
||||||
print("-" * len(f.__name__))
|
|
||||||
f()
|
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
@run
|
|
||||||
# CHECK-LABEL: test_lift_fresh_copy
|
|
||||||
def test_lift_fresh_copy():
|
|
||||||
#
|
|
||||||
class Basic(nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
# CHECK: torch.aten.clone %arg0, %none
|
|
||||||
return torch.ops.aten.lift_fresh_copy.default(x)
|
|
||||||
|
|
||||||
m = fx.export_and_import(Basic(), torch.randn(3, 4))
|
|
||||||
print(m)
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
# This file contains tests of various op special forms that the fx_importer
|
||||||
|
# handles.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.export
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from torch_mlir import fx
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Custom ops to test various things that are hard to reach.
|
||||||
|
################################################################################
|
||||||
|
LIBRARY = torch.library.Library("torch_mlir_test", "DEF")
|
||||||
|
|
||||||
|
LIBRARY.define("multi_return(Tensor x) -> (Tensor, Tensor, Tensor)")
|
||||||
|
|
||||||
|
|
||||||
|
def multi_return_meta(x):
|
||||||
|
return None, torch.empty_like(x), torch.empty_like(x)
|
||||||
|
|
||||||
|
|
||||||
|
LIBRARY.impl("multi_return", multi_return_meta, "Meta")
|
||||||
|
|
||||||
|
|
||||||
|
def run(f):
|
||||||
|
print(f"{f.__name__}")
|
||||||
|
print("-" * len(f.__name__))
|
||||||
|
f()
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
# CHECK-LABEL: test_lift_fresh_copy
|
||||||
|
def test_lift_fresh_copy():
|
||||||
|
class Basic(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
# CHECK: torch.aten.clone %arg0, %none
|
||||||
|
return torch.ops.aten.lift_fresh_copy.default(x)
|
||||||
|
|
||||||
|
m = fx.export_and_import(Basic(), torch.randn(3, 4))
|
||||||
|
print(m)
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
# CHECK-LABEL: test_multi_return
|
||||||
|
def test_multi_return():
|
||||||
|
class Basic(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
# Note that optional return tensors that are statically traced to
|
||||||
|
# None show up as a !torch.none type. This happens in the case of
|
||||||
|
# certain convolution backwards ops (possibly among others).
|
||||||
|
# The FxImporter does not perform special tracking of static None
|
||||||
|
# values, instead just materializing a torch.constant.none when
|
||||||
|
# needed. This is an implementation detail: it would be valid to
|
||||||
|
# use the RES:0 result instead of this materialization below.
|
||||||
|
# In practice, this doesn't arise in nature and is a by-product
|
||||||
|
# of tracing.
|
||||||
|
# CHECK: %[[RES:.*]]:3 = torch.operator "torch.torch_mlir_test.multi_return"(%arg0) :
|
||||||
|
# CHECK-SAME: (!torch.vtensor<[3,4],f32>)
|
||||||
|
# CHECK-SAME: -> (!torch.none, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>)
|
||||||
|
# CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
# CHECK: return %[[NONE]], %[[RES]]#1, %[[RES]]#2
|
||||||
|
return torch.ops.torch_mlir_test.multi_return(x)
|
||||||
|
|
||||||
|
m = fx.export_and_import(Basic(), torch.randn(3, 4))
|
||||||
|
print(m)
|
Loading…
Reference in New Issue