mirror of https://github.com/llvm/torch-mlir
[fx] Fix type determination for multi-return ops and static `None` returns. (#3081)
In practice, this was caught by the way that AOT autograd traces `convolution_backward`. For the unit test, we just repro it with a custom op.pull/3098/head
parent
129a79417a
commit
282e9b0e64
|
@ -354,14 +354,6 @@ def is_builtin_function_or_method(obj: Any) -> bool:
|
|||
class InputInfo:
|
||||
"""Provides additional metadata when resolving inputs."""
|
||||
|
||||
__slots__ = [
|
||||
"program",
|
||||
"input_spec",
|
||||
"node",
|
||||
"ir_type",
|
||||
"mutable_producer_node_name",
|
||||
]
|
||||
|
||||
program: torch.export.ExportedProgram
|
||||
input_spec: TypingInputSpec
|
||||
node: Node
|
||||
|
@ -915,6 +907,22 @@ class ContextCache:
|
|||
tensor_meta = node.meta.get("tensor_meta")
|
||||
val = node.meta.get("val")
|
||||
sparsity = node.meta.get("sparsity", None)
|
||||
return self.value_info_to_type(
|
||||
val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
|
||||
)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
|
||||
)
|
||||
|
||||
def value_info_to_type(
|
||||
self,
|
||||
val,
|
||||
*,
|
||||
tensor_meta: Optional[TensorMetadata] = None,
|
||||
sparsity=None,
|
||||
mutable: bool = False,
|
||||
):
|
||||
if tensor_meta is not None:
|
||||
assert isinstance(tensor_meta, TensorMetadata)
|
||||
# Quantized tensor meta data is not preserved in our lowering,
|
||||
|
@ -934,18 +942,13 @@ class ContextCache:
|
|||
return self.get_vtensor_type(
|
||||
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
|
||||
)
|
||||
|
||||
else:
|
||||
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
|
||||
if t is not None:
|
||||
return IrType.parse(t, self._c)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"FIXME: Unsupported placeholder node (this often indicates that a necessary) "
|
||||
f"fx preprocessing pass was not run): {node.meta}"
|
||||
)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
|
||||
f"Could not deduce type from value info: "
|
||||
f"tensor_meta={tensor_meta}, val={val}, sparsity={sparsity}"
|
||||
)
|
||||
|
||||
def tensor_metadata_to_type(
|
||||
|
@ -1624,10 +1627,9 @@ class GraphNodeImporter:
|
|||
# short-circuit above. Note that if we ever choose to also fully reify Python
|
||||
# level result tuples, we will need to create a tuple-boxed version of this and
|
||||
# redirect to it for generic object access.
|
||||
|
||||
result_types = []
|
||||
for v in node.meta["val"]:
|
||||
result_types.append(self._cc.tensor_metadata_to_type(v))
|
||||
result_types.append(self._cc.value_info_to_type(v))
|
||||
result_types = tuple(result_types)
|
||||
return result_types
|
||||
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
# This file contains tests of various op special forms that the fx_importer
|
||||
# handles.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
import torch.nn as nn
|
||||
|
||||
from torch_mlir import fx
|
||||
|
||||
|
||||
def run(f):
|
||||
print(f"{f.__name__}")
|
||||
print("-" * len(f.__name__))
|
||||
f()
|
||||
print()
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_lift_fresh_copy
|
||||
def test_lift_fresh_copy():
|
||||
#
|
||||
class Basic(nn.Module):
|
||||
def forward(self, x):
|
||||
# CHECK: torch.aten.clone %arg0, %none
|
||||
return torch.ops.aten.lift_fresh_copy.default(x)
|
||||
|
||||
m = fx.export_and_import(Basic(), torch.randn(3, 4))
|
||||
print(m)
|
|
@ -0,0 +1,74 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
# This file contains tests of various op special forms that the fx_importer
|
||||
# handles.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
import torch.nn as nn
|
||||
|
||||
from torch_mlir import fx
|
||||
|
||||
################################################################################
|
||||
# Custom ops to test various things that are hard to reach.
|
||||
################################################################################
|
||||
LIBRARY = torch.library.Library("torch_mlir_test", "DEF")
|
||||
|
||||
LIBRARY.define("multi_return(Tensor x) -> (Tensor, Tensor, Tensor)")
|
||||
|
||||
|
||||
def multi_return_meta(x):
|
||||
return None, torch.empty_like(x), torch.empty_like(x)
|
||||
|
||||
|
||||
LIBRARY.impl("multi_return", multi_return_meta, "Meta")
|
||||
|
||||
|
||||
def run(f):
|
||||
print(f"{f.__name__}")
|
||||
print("-" * len(f.__name__))
|
||||
f()
|
||||
print()
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_lift_fresh_copy
|
||||
def test_lift_fresh_copy():
|
||||
class Basic(nn.Module):
|
||||
def forward(self, x):
|
||||
# CHECK: torch.aten.clone %arg0, %none
|
||||
return torch.ops.aten.lift_fresh_copy.default(x)
|
||||
|
||||
m = fx.export_and_import(Basic(), torch.randn(3, 4))
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_multi_return
|
||||
def test_multi_return():
|
||||
class Basic(nn.Module):
|
||||
def forward(self, x):
|
||||
# Note that optional return tensors that are statically traced to
|
||||
# None show up as a !torch.none type. This happens in the case of
|
||||
# certain convolution backwards ops (possibly among others).
|
||||
# The FxImporter does not perform special tracking of static None
|
||||
# values, instead just materializing a torch.constant.none when
|
||||
# needed. This is an implementation detail: it would be valid to
|
||||
# use the RES:0 result instead of this materialization below.
|
||||
# In practice, this doesn't arise in nature and is a by-product
|
||||
# of tracing.
|
||||
# CHECK: %[[RES:.*]]:3 = torch.operator "torch.torch_mlir_test.multi_return"(%arg0) :
|
||||
# CHECK-SAME: (!torch.vtensor<[3,4],f32>)
|
||||
# CHECK-SAME: -> (!torch.none, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>)
|
||||
# CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
# CHECK: return %[[NONE]], %[[RES]]#1, %[[RES]]#2
|
||||
return torch.ops.torch_mlir_test.multi_return(x)
|
||||
|
||||
m = fx.export_and_import(Basic(), torch.randn(3, 4))
|
||||
print(m)
|
Loading…
Reference in New Issue