mirror of https://github.com/llvm/torch-mlir
[fx] Fix type inference for scalar/int types. (#3099)
This was discovered in a downstream test suite and was due to a control flow nesting merge issue. In-tree test added and fixed.pull/3100/head
parent
40e762ca42
commit
ffaaf08c31
|
@ -927,13 +927,13 @@ class ContextCache:
|
|||
tensor_meta = node.meta.get("tensor_meta")
|
||||
val = node.meta.get("val")
|
||||
sparsity = node.meta.get("sparsity", None)
|
||||
return self.value_info_to_type(
|
||||
val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
|
||||
)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
|
||||
)
|
||||
return self.value_info_to_type(
|
||||
val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
|
||||
)
|
||||
|
||||
def value_info_to_type(
|
||||
self,
|
||||
|
@ -962,13 +962,16 @@ class ContextCache:
|
|||
return self.get_vtensor_type(
|
||||
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
|
||||
)
|
||||
else:
|
||||
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
|
||||
if t is not None:
|
||||
return IrType.parse(t, self._c)
|
||||
|
||||
# Note that None is a valid scalar here, so it is important that this
|
||||
# is always checked as the last fallback.
|
||||
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
|
||||
if t is not None:
|
||||
return IrType.parse(t, self._c)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"Could not deduce type from value info: "
|
||||
f"tensor_meta={tensor_meta}, val={val}, sparsity={sparsity}"
|
||||
f"tensor_meta={tensor_meta}, val={val} {type(val)}, sparsity={sparsity}"
|
||||
)
|
||||
|
||||
def tensor_metadata_to_type(
|
||||
|
@ -1631,7 +1634,9 @@ class GraphNodeImporter:
|
|||
with loc:
|
||||
return cvt(arg, self, self._cc)
|
||||
|
||||
def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema) -> List[IrType]:
|
||||
def _unpack_node_result_types(
|
||||
self, node: torch.fx.Node, schema: FunctionSchema
|
||||
) -> List[IrType]:
|
||||
return_count = len(schema.returns)
|
||||
if return_count == 1:
|
||||
# Unary return directly maps a single meta["val"] and cannot be subscripted.
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
# This file contains tests of various op special forms that the fx_importer
|
||||
# handles.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
import torch.nn as nn
|
||||
|
||||
from torch_mlir import fx
|
||||
|
||||
|
||||
def run(f):
|
||||
print(f"{f.__name__}")
|
||||
print("-" * len(f.__name__))
|
||||
f()
|
||||
print()
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_scalar_typed_node
|
||||
# Getting the shape of a dynamic dimension has the side effect of producing
|
||||
# a node like:
|
||||
# sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 0)
|
||||
# This tests the fx_importer code paths around resolving scalar/symbolic
|
||||
# types for operands and results.
|
||||
def test_scalar_typed_node():
|
||||
class Basic(nn.Module):
|
||||
def forward(self, x):
|
||||
x = x + 1.0
|
||||
return x.shape[0]
|
||||
|
||||
# CHECK: torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int
|
||||
m = fx.export_and_import(
|
||||
Basic(), torch.randn(3, 4), dynamic_shapes={"x": {0: torch.export.Dim("b")}}
|
||||
)
|
||||
print(m)
|
Loading…
Reference in New Issue