mirror of https://github.com/llvm/torch-mlir
[fx] Make the lift_fresh_copy -> clone special form use kwargs. (#3045)
At some point, this op became kwarg-only instead of arg/kwarg. Discovered when upgrading to PyTorch 2.3. Also adds a test as this was untested in-tree (was caught out of tree).pull/3047/head
parent
7616d637fd
commit
6ea857c644
|
@ -1276,10 +1276,16 @@ class GraphNodeImporter:
|
|||
# replace lift_fresh_copy with clone op
|
||||
if target == torch.ops.aten.lift_fresh_copy.default:
|
||||
node.target = target = torch.ops.aten.clone.default
|
||||
node.args = (node.args[0], None)
|
||||
node.args = (node.args[0],)
|
||||
node.kwargs = {"memory_format": None}
|
||||
elif target == torch.ops.aten.lift_fresh_copy.out:
|
||||
# TODO: It seems not possible to hit this case from user code.
|
||||
# Retaining in case if it is triggered internally somehow, but
|
||||
# it can most likely be removed once assuming full
|
||||
# functionalization in all cases.
|
||||
node.target = target = torch.ops.aten.clone.out
|
||||
node.args = (node.args[0], None, node.args[1])
|
||||
node.args = (node.args[0],)
|
||||
node.kwargs = {"memory_format": None, "out": node.args[1]}
|
||||
# TODO: generalize empty.memory_format in the future
|
||||
# Currently, the aten.baddbmm.default op for Unet includes multiplying an
|
||||
# empty.memory_format input with a constant, which creates NaN values
|
||||
|
@ -1664,7 +1670,8 @@ class TypeSubclassMap:
|
|||
|
||||
# Opaque value to indicate something is empty. Used in cases where 'None'
|
||||
# may have a different meaning.
|
||||
class EmptyType: ...
|
||||
class EmptyType:
|
||||
...
|
||||
|
||||
|
||||
Empty = EmptyType()
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
# This file contains tests of various op special forms that the fx_importer
|
||||
# handles.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
import torch.nn as nn
|
||||
|
||||
from torch_mlir import fx
|
||||
|
||||
|
||||
def run(f):
|
||||
print(f"{f.__name__}")
|
||||
print("-" * len(f.__name__))
|
||||
f()
|
||||
print()
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_lift_fresh_copy
|
||||
def test_lift_fresh_copy():
|
||||
#
|
||||
class Basic(nn.Module):
|
||||
def forward(self, x):
|
||||
# CHECK: torch.aten.clone %arg0, %none
|
||||
return torch.ops.aten.lift_fresh_copy.default(x)
|
||||
|
||||
m = fx.export_and_import(Basic(), torch.randn(3, 4))
|
||||
print(m)
|
Loading…
Reference in New Issue