[fx] Fix type inference for scalar/int types. (#3099)

This was discovered in a downstream test suite and was due to a control
flow nesting merge issue. In-tree test added and fixed.
pull/3100/head
Stella Laurenzo 2024-04-02 13:56:43 -07:00 committed by GitHub
parent 40e762ca42
commit ffaaf08c31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 57 additions and 9 deletions

View File

@ -927,13 +927,13 @@ class ContextCache:
tensor_meta = node.meta.get("tensor_meta")
val = node.meta.get("val")
sparsity = node.meta.get("sparsity", None)
return self.value_info_to_type(
val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
)
except KeyError as e:
raise RuntimeError(
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
)
return self.value_info_to_type(
val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
)
def value_info_to_type(
self,
@ -962,13 +962,16 @@ class ContextCache:
return self.get_vtensor_type(
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
)
else:
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
if t is not None:
return IrType.parse(t, self._c)
# Note that None is a valid scalar here, so it is important that this
# is always checked as the last fallback.
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
if t is not None:
return IrType.parse(t, self._c)
raise NotImplementedError(
f"Could not deduce type from value info: "
f"tensor_meta={tensor_meta}, val={val}, sparsity={sparsity}"
f"tensor_meta={tensor_meta}, val={val} {type(val)}, sparsity={sparsity}"
)
def tensor_metadata_to_type(
@ -1631,7 +1634,9 @@ class GraphNodeImporter:
with loc:
return cvt(arg, self, self._cc)
def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema) -> List[IrType]:
def _unpack_node_result_types(
self, node: torch.fx.Node, schema: FunctionSchema
) -> List[IrType]:
return_count = len(schema.returns)
if return_count == 1:
# Unary return directly maps a single meta["val"] and cannot be subscripted.

View File

@ -0,0 +1,43 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
# This file contains tests of various op special forms that the fx_importer
# handles.
from typing import Optional
import torch
import torch.export
import torch.nn as nn
from torch_mlir import fx
def run(f):
print(f"{f.__name__}")
print("-" * len(f.__name__))
f()
print()
@run
# CHECK-LABEL: test_scalar_typed_node
# Getting the shape of a dynamic dimension has the side effect of producing
# a node like:
# sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 0)
# This tests the fx_importer code paths around resolving scalar/symbolic
# types for operands and results.
def test_scalar_typed_node():
class Basic(nn.Module):
def forward(self, x):
x = x + 1.0
return x.shape[0]
# CHECK: torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int
m = fx.export_and_import(
Basic(), torch.randn(3, 4), dynamic_shapes={"x": {0: torch.export.Dim("b")}}
)
print(m)