[fx] Fix type determination for multi-return ops and static `None` returns. (#3081)

In practice, this was caught by the way that AOT autograd traces
`convolution_backward`. For the unit test, we just repro it with a
custom op.
Stella Laurenzo 2024-04-01 09:39:38 -07:00 committed by GitHub
parent 129a79417a
commit 282e9b0e64
3 changed files with 113 additions and 73 deletions
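
Before the diffs, a minimal standalone repro, condensed from the new test file added below (it assumes a torch-mlir build that includes the fx importer; the names mirror the test and are not an official example). Prior to this change, the import failed while determining the type of the statically-None first result; with it, that result is typed as !torch.none:

import torch
import torch.nn as nn
from torch_mlir import fx

# Custom op whose first result is statically None under tracing, mimicking
# how AOT autograd traces convolution_backward with an unneeded gradient.
LIBRARY = torch.library.Library("torch_mlir_test", "DEF")
LIBRARY.define("multi_return(Tensor x) -> (Tensor, Tensor, Tensor)")

def multi_return_meta(x):
    return None, torch.empty_like(x), torch.empty_like(x)

LIBRARY.impl("multi_return", multi_return_meta, "Meta")

class Basic(nn.Module):
    def forward(self, x):
        return torch.ops.torch_mlir_test.multi_return(x)

# Before this commit the importer could not type the None result;
# afterwards it is imported as !torch.none.
print(fx.export_and_import(Basic(), torch.randn(3, 4)))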


@@ -354,14 +354,6 @@ def is_builtin_function_or_method(obj: Any) -> bool:
 class InputInfo:
     """Provides additional metadata when resolving inputs."""
 
-    __slots__ = [
-        "program",
-        "input_spec",
-        "node",
-        "ir_type",
-        "mutable_producer_node_name",
-    ]
-
     program: torch.export.ExportedProgram
     input_spec: TypingInputSpec
     node: Node
@@ -915,6 +907,22 @@ class ContextCache:
             tensor_meta = node.meta.get("tensor_meta")
             val = node.meta.get("val")
             sparsity = node.meta.get("sparsity", None)
+            return self.value_info_to_type(
+                val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
+            )
+        except KeyError as e:
+            raise RuntimeError(
+                f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
+            )
+
+    def value_info_to_type(
+        self,
+        val,
+        *,
+        tensor_meta: Optional[TensorMetadata] = None,
+        sparsity=None,
+        mutable: bool = False,
+    ):
         if tensor_meta is not None:
             assert isinstance(tensor_meta, TensorMetadata)
             # Quantized tensor meta data is not preserved in our lowering,
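
With this hunk, node_val_to_type is reduced to unpacking tensor_meta, val, and sparsity out of Node.meta and delegating to the new value_info_to_type, which can also be called with loose meta values that have no backing Node — exactly what the multi-return result import further down needs.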
@@ -934,18 +942,13 @@
             return self.get_vtensor_type(
                 val.size(), val.dtype, sparsity=sparsity, mutable=mutable
             )
         else:
             t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
             if t is not None:
                 return IrType.parse(t, self._c)
-
-            raise NotImplementedError(
-                f"FIXME: Unsupported placeholder node (this often indicates that a necessary "
-                f"fx preprocessing pass was not run): {node.meta}"
-            )
-        except KeyError as e:
-            raise RuntimeError(
-                f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
-            )
+        raise NotImplementedError(
+            f"Could not deduce type from value info: "
+            f"tensor_meta={tensor_meta}, val={val}, sparsity={sparsity}"
+        )
 
     def tensor_metadata_to_type(
@@ -1624,10 +1627,9 @@
         # short-circuit above. Note that if we ever choose to also fully reify Python
         # level result tuples, we will need to create a tuple-boxed version of this and
         # redirect to it for generic object access.
         result_types = []
         for v in node.meta["val"]:
-            result_types.append(self._cc.tensor_metadata_to_type(v))
+            result_types.append(self._cc.value_info_to_type(v))
         result_types = tuple(result_types)
         return result_types
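
This one-line change is the crux of the fix: tensor_metadata_to_type requires a TensorMetadata and so cannot type a multi-return result that statically traces to None, while value_info_to_type accepts the bare meta value and falls back to fake-tensor and scalar handling, typing such a result as !torch.none (as the new test's CHECK lines verify).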


@@ -1,36 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

# This file contains tests of various op special forms that the fx_importer
# handles.

from typing import Optional

import torch
import torch.export
import torch.nn as nn

from torch_mlir import fx


def run(f):
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()


@run
# CHECK-LABEL: test_lift_fresh_copy
def test_lift_fresh_copy():
    #
    class Basic(nn.Module):
        def forward(self, x):
            # CHECK: torch.aten.clone %arg0, %none
            return torch.ops.aten.lift_fresh_copy.default(x)

    m = fx.export_and_import(Basic(), torch.randn(3, 4))
    print(m)


@@ -0,0 +1,74 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

# This file contains tests of various op special forms that the fx_importer
# handles.

from typing import Optional

import torch
import torch.export
import torch.nn as nn

from torch_mlir import fx

################################################################################
# Custom ops to test various things that are hard to reach.
################################################################################

LIBRARY = torch.library.Library("torch_mlir_test", "DEF")

LIBRARY.define("multi_return(Tensor x) -> (Tensor, Tensor, Tensor)")


def multi_return_meta(x):
    return None, torch.empty_like(x), torch.empty_like(x)


LIBRARY.impl("multi_return", multi_return_meta, "Meta")
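
Only a Meta-dispatch implementation is registered: torch.export traces with fake tensors, so shape/dtype propagation through the meta kernel is all the importer needs to observe the (None, Tensor, Tensor) result structure that mimics how convolution_backward traces when an output is statically unneeded.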

def run(f):
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()


@run
# CHECK-LABEL: test_lift_fresh_copy
def test_lift_fresh_copy():
    class Basic(nn.Module):
        def forward(self, x):
            # CHECK: torch.aten.clone %arg0, %none
            return torch.ops.aten.lift_fresh_copy.default(x)

    m = fx.export_and_import(Basic(), torch.randn(3, 4))
    print(m)

@run
# CHECK-LABEL: test_multi_return
def test_multi_return():
    class Basic(nn.Module):
        def forward(self, x):
            # Note that optional return tensors that are statically traced to
            # None show up as a !torch.none type. This happens in the case of
            # certain convolution backwards ops (possibly among others).
            # The FxImporter does not perform special tracking of static None
            # values, instead just materializing a torch.constant.none when
            # needed. This is an implementation detail: it would be valid to
            # use the RES:0 result instead of this materialization below.
            # In practice, this doesn't arise in nature and is a by-product
            # of tracing.
            # CHECK: %[[RES:.*]]:3 = torch.operator "torch.torch_mlir_test.multi_return"(%arg0) :
            # CHECK-SAME: (!torch.vtensor<[3,4],f32>)
            # CHECK-SAME: -> (!torch.none, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>)
            # CHECK: %[[NONE:.*]] = torch.constant.none
            # CHECK: return %[[NONE]], %[[RES]]#1, %[[RES]]#2
            return torch.ops.torch_mlir_test.multi_return(x)

    m = fx.export_and_import(Basic(), torch.randn(3, 4))
    print(m)
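
Run standalone, each test prints the imported MLIR module; under the project's lit setup, the RUN: line at the top pipes that output through FileCheck, which matches the CHECK: lines above against the printed IR.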