From ffaaf08c317ec845ba08398268281ee602b0e1ba Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Tue, 2 Apr 2024 13:56:43 -0700
Subject: [PATCH] [fx] Fix type inference for scalar/int types. (#3099)

This was discovered in a downstream test suite and was due to a control
flow nesting merge issue. The issue is fixed and an in-tree test added.
---
 python/torch_mlir/extras/fx_importer.py    | 23 +++++++-----
 test/python/fx_importer/v2.3/types_test.py | 43 ++++++++++++++++++++++
 2 files changed, 57 insertions(+), 9 deletions(-)
 create mode 100644 test/python/fx_importer/v2.3/types_test.py

diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index edcf62c69..aee8251b0 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -927,13 +927,13 @@ class ContextCache:
             tensor_meta = node.meta.get("tensor_meta")
             val = node.meta.get("val")
             sparsity = node.meta.get("sparsity", None)
-            return self.value_info_to_type(
-                val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
-            )
         except KeyError as e:
             raise RuntimeError(
                 f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
             )
+        return self.value_info_to_type(
+            val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
+        )

     def value_info_to_type(
         self,
@@ -962,13 +962,16 @@ class ContextCache:
             return self.get_vtensor_type(
                 val.size(), val.dtype, sparsity=sparsity, mutable=mutable
             )
-        else:
-            t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
-            if t is not None:
-                return IrType.parse(t, self._c)
+
+        # Note that None is a valid scalar here, so it is important that this
+        # is always checked as the last fallback.
+        t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
+        if t is not None:
+            return IrType.parse(t, self._c)
+
         raise NotImplementedError(
             f"Could not deduce type from value info: "
-            f"tensor_meta={tensor_meta}, val={val}, sparsity={sparsity}"
+            f"tensor_meta={tensor_meta}, val={val} {type(val)}, sparsity={sparsity}"
         )

     def tensor_metadata_to_type(
         self,
@@ -1631,7 +1634,9 @@ class GraphNodeImporter:
         with loc:
             return cvt(arg, self, self._cc)

-    def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema) -> List[IrType]:
+    def _unpack_node_result_types(
+        self, node: torch.fx.Node, schema: FunctionSchema
+    ) -> List[IrType]:
         return_count = len(schema.returns)
         if return_count == 1:
             # Unary return directly maps a single meta["val"] and cannot be subscripted.
diff --git a/test/python/fx_importer/v2.3/types_test.py b/test/python/fx_importer/v2.3/types_test.py
new file mode 100644
index 000000000..19dee8b7b
--- /dev/null
+++ b/test/python/fx_importer/v2.3/types_test.py
@@ -0,0 +1,43 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+# RUN: %PYTHON %s | FileCheck %s
+# This file contains tests of various op special forms that the fx_importer
+# handles.
+
+from typing import Optional
+
+import torch
+import torch.export
+import torch.nn as nn
+
+from torch_mlir import fx
+
+
+def run(f):
+    print(f"{f.__name__}")
+    print("-" * len(f.__name__))
+    f()
+    print()
+
+
+@run
+# CHECK-LABEL: test_scalar_typed_node
+# Getting the shape of a dynamic dimension has the side effect of producing
+# a node like:
+#   sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 0)
+# This tests the fx_importer code paths around resolving scalar/symbolic
+# types for operands and results.
+def test_scalar_typed_node():
+    class Basic(nn.Module):
+        def forward(self, x):
+            x = x + 1.0
+            return x.shape[0]
+
+    # CHECK: torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int
+    m = fx.export_and_import(
+        Basic(), torch.randn(3, 4), dynamic_shapes={"x": {0: torch.export.Dim("b")}}
+    )
+    print(m)
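
--
Note (illustrative, not part of the patch): the "control flow nesting" bug
reduces to an exception-scoping pitfall. In the sketch below, convert() is
a hypothetical stand-in for value_info_to_type(), not the torch-mlir API.
With the return inside the try, a KeyError escaping the conversion is
caught and re-raised as a bogus "illegal node.meta access" error; hoisting
the return past the except, as the first hunk above does, narrows the guard
to the meta lookup alone.

    def convert(val):
        # Hypothetical stand-in for value_info_to_type(); a failed dict
        # lookup here raises KeyError, unrelated to node.meta access.
        return {int: "!torch.int", float: "!torch.float"}[type(val)]

    def node_val_to_type_buggy(meta):
        try:
            val = meta["val"]  # only this lookup should be guarded
            return convert(val)  # a KeyError from convert() is caught below
        except KeyError as e:
            raise RuntimeError(f"Illegal access to node.meta: {e}")

    def node_val_to_type_fixed(meta):
        try:
            val = meta["val"]
        except KeyError as e:
            raise RuntimeError(f"Illegal access to node.meta: {e}")
        # Conversion failures now propagate with their true cause intact.
        return convert(val)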
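Note (illustrative): the reordering in value_info_to_type() matters because
the scalar table is keyed by Python type and, per the comment added in the
second hunk, None is a valid scalar. The sketch below uses a made-up
SCALAR_MAP and helper, not the real SCALAR_TYPE_TO_TORCH_MLIR_TYPE, to show
why that lookup must be the final fallback after the tensor-shaped cases.

    SCALAR_MAP = {  # illustrative subset, keyed by Python type
        int: "!torch.int",
        float: "!torch.float",
        bool: "!torch.bool",
        type(None): "!torch.none",  # None is a valid scalar here
    }

    def value_to_type_name(val):
        # Tensor-like values are claimed first by the richer handlers.
        if hasattr(val, "size") and hasattr(val, "dtype"):
            return "!torch.vtensor"
        # Last fallback: because type(None) has an entry, a None val is a
        # legitimate scalar rather than a signal of missing metadata, so
        # this runs only after every tensor case has been rejected.
        t = SCALAR_MAP.get(type(val))
        if t is not None:
            return t
        raise NotImplementedError(f"Could not deduce type: {val} {type(val)}")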
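Note (illustrative): the new test's dynamic dimension is what forces a
non-tensor, SymInt-typed node into the exported graph. The repro below
needs only torch (assuming torch >= 2.3, no torch-mlir required) and
surfaces the aten.sym_size.int node whose scalar result type the importer
previously failed to infer.

    import torch
    import torch.export


    class Basic(torch.nn.Module):
        def forward(self, x):
            x = x + 1.0
            return x.shape[0]  # a SymInt once dim 0 is dynamic


    ep = torch.export.export(
        Basic(),
        (torch.randn(3, 4),),
        dynamic_shapes={"x": {0: torch.export.Dim("b")}},
    )
    # The printed graph contains a node like
    #   sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 0)
    # whose non-tensor result exercises the scalar paths fixed above.
    print(ep.graph)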