From 6ea857c644ca9a65a2a117a4b633c3d5d9a818cf Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 21 Mar 2024 15:34:40 -0700 Subject: [PATCH] [fx] Make the lift_fresh_copy -> clone special form use kwargs. (#3045) At some point, this op became kwarg-only instead of arg/kwarg. Discovered when upgrading to PyTorch 2.3. Also adds a test as this was untested in-tree (was caught out of tree). --- python/torch_mlir/extras/fx_importer.py | 13 +++++-- test/python/fx_importer/special_forms_test.py | 36 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 test/python/fx_importer/special_forms_test.py diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index ac4a04cfa..18f2572b8 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1276,10 +1276,16 @@ class GraphNodeImporter: # replace lift_fresh_copy with clone op if target == torch.ops.aten.lift_fresh_copy.default: node.target = target = torch.ops.aten.clone.default - node.args = (node.args[0], None) + node.args = (node.args[0],) + node.kwargs = {"memory_format": None} elif target == torch.ops.aten.lift_fresh_copy.out: + # TODO: It seems not possible to hit this case from user code. + # Retaining in case if it is triggered internally somehow, but + # it can most likely be removed once assuming full + # functionalization in all cases. node.target = target = torch.ops.aten.clone.out - node.args = (node.args[0], None, node.args[1]) + node.args = (node.args[0],) + node.kwargs = {"memory_format": None, "out": node.args[1]} # TODO: generalize empty.memory_format in the future # Currently, the aten.baddbmm.default op for Unet includes multiplying an # empty.memory_format input with a constant, which creates NaN values @@ -1664,7 +1670,8 @@ class TypeSubclassMap: # Opaque value to indicate something is empty. Used in cases where 'None' # may have a different meaning. -class EmptyType: ... +class EmptyType: + ... Empty = EmptyType() diff --git a/test/python/fx_importer/special_forms_test.py b/test/python/fx_importer/special_forms_test.py new file mode 100644 index 000000000..2f6cf4912 --- /dev/null +++ b/test/python/fx_importer/special_forms_test.py @@ -0,0 +1,36 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s +# This file contains tests of various op special forms that the fx_importer +# handles. + +from typing import Optional + +import torch +import torch.export +import torch.nn as nn + +from torch_mlir import fx + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_lift_fresh_copy +def test_lift_fresh_copy(): + # + class Basic(nn.Module): + def forward(self, x): + # CHECK: torch.aten.clone %arg0, %none + return torch.ops.aten.lift_fresh_copy.default(x) + + m = fx.export_and_import(Basic(), torch.randn(3, 4)) + print(m)