[fx] Make the lift_fresh_copy -> clone special form use kwargs. (#3045)

At some point, this op became kwarg-only instead of arg/kwarg.
Discovered when upgrading to PyTorch 2.3.

Also adds a test as this was untested in-tree (was caught out of tree).
pull/3047/head
Stella Laurenzo 2024-03-21 15:34:40 -07:00 committed by GitHub
parent 7616d637fd
commit 6ea857c644
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 3 deletions

View File

@ -1276,10 +1276,16 @@ class GraphNodeImporter:
# replace lift_fresh_copy with clone op
if target == torch.ops.aten.lift_fresh_copy.default:
node.target = target = torch.ops.aten.clone.default
node.args = (node.args[0], None)
node.args = (node.args[0],)
node.kwargs = {"memory_format": None}
elif target == torch.ops.aten.lift_fresh_copy.out:
# TODO: It seems not possible to hit this case from user code.
# Retaining in case if it is triggered internally somehow, but
# it can most likely be removed once assuming full
# functionalization in all cases.
node.target = target = torch.ops.aten.clone.out
node.args = (node.args[0], None, node.args[1])
node.args = (node.args[0],)
node.kwargs = {"memory_format": None, "out": node.args[1]}
# TODO: generalize empty.memory_format in the future
# Currently, the aten.baddbmm.default op for Unet includes multiplying an
# empty.memory_format input with a constant, which creates NaN values
@ -1664,7 +1670,8 @@ class TypeSubclassMap:
# Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning.
class EmptyType: ...
class EmptyType:
...
Empty = EmptyType()

View File

@ -0,0 +1,36 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
# This file contains tests of various op special forms that the fx_importer
# handles.
from typing import Optional
import torch
import torch.export
import torch.nn as nn
from torch_mlir import fx
def run(f):
print(f"{f.__name__}")
print("-" * len(f.__name__))
f()
print()
@run
# CHECK-LABEL: test_lift_fresh_copy
def test_lift_fresh_copy():
#
class Basic(nn.Module):
def forward(self, x):
# CHECK: torch.aten.clone %arg0, %none
return torch.ops.aten.lift_fresh_copy.default(x)
m = fx.export_and_import(Basic(), torch.randn(3, 4))
print(m)