mirror of https://github.com/llvm/torch-mlir
[fx] Fix type inference for scalar/int types. (#3099)
This was discovered in a downstream test suite and was due to a control flow nesting merge issue. In-tree test added and fixed.pull/3100/head
parent
40e762ca42
commit
ffaaf08c31
|
@ -927,13 +927,13 @@ class ContextCache:
|
||||||
tensor_meta = node.meta.get("tensor_meta")
|
tensor_meta = node.meta.get("tensor_meta")
|
||||||
val = node.meta.get("val")
|
val = node.meta.get("val")
|
||||||
sparsity = node.meta.get("sparsity", None)
|
sparsity = node.meta.get("sparsity", None)
|
||||||
return self.value_info_to_type(
|
|
||||||
val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
|
|
||||||
)
|
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
|
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
|
||||||
)
|
)
|
||||||
|
return self.value_info_to_type(
|
||||||
|
val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
|
||||||
|
)
|
||||||
|
|
||||||
def value_info_to_type(
|
def value_info_to_type(
|
||||||
self,
|
self,
|
||||||
|
@ -962,13 +962,16 @@ class ContextCache:
|
||||||
return self.get_vtensor_type(
|
return self.get_vtensor_type(
|
||||||
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
|
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
|
# Note that None is a valid scalar here, so it is important that this
|
||||||
if t is not None:
|
# is always checked as the last fallback.
|
||||||
return IrType.parse(t, self._c)
|
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
|
||||||
|
if t is not None:
|
||||||
|
return IrType.parse(t, self._c)
|
||||||
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Could not deduce type from value info: "
|
f"Could not deduce type from value info: "
|
||||||
f"tensor_meta={tensor_meta}, val={val}, sparsity={sparsity}"
|
f"tensor_meta={tensor_meta}, val={val} {type(val)}, sparsity={sparsity}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def tensor_metadata_to_type(
|
def tensor_metadata_to_type(
|
||||||
|
@ -1631,7 +1634,9 @@ class GraphNodeImporter:
|
||||||
with loc:
|
with loc:
|
||||||
return cvt(arg, self, self._cc)
|
return cvt(arg, self, self._cc)
|
||||||
|
|
||||||
def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema) -> List[IrType]:
|
def _unpack_node_result_types(
|
||||||
|
self, node: torch.fx.Node, schema: FunctionSchema
|
||||||
|
) -> List[IrType]:
|
||||||
return_count = len(schema.returns)
|
return_count = len(schema.returns)
|
||||||
if return_count == 1:
|
if return_count == 1:
|
||||||
# Unary return directly maps a single meta["val"] and cannot be subscripted.
|
# Unary return directly maps a single meta["val"] and cannot be subscripted.
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
# This file contains tests of various op special forms that the fx_importer
|
||||||
|
# handles.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.export
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from torch_mlir import fx
|
||||||
|
|
||||||
|
|
||||||
|
def run(f):
|
||||||
|
print(f"{f.__name__}")
|
||||||
|
print("-" * len(f.__name__))
|
||||||
|
f()
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
# CHECK-LABEL: test_scalar_typed_node
|
||||||
|
# Getting the shape of a dynamic dimension has the side effect of producing
|
||||||
|
# a node like:
|
||||||
|
# sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 0)
|
||||||
|
# This tests the fx_importer code paths around resolving scalar/symbolic
|
||||||
|
# types for operands and results.
|
||||||
|
def test_scalar_typed_node():
|
||||||
|
class Basic(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
x = x + 1.0
|
||||||
|
return x.shape[0]
|
||||||
|
|
||||||
|
# CHECK: torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int
|
||||||
|
m = fx.export_and_import(
|
||||||
|
Basic(), torch.randn(3, 4), dynamic_shapes={"x": {0: torch.export.Dim("b")}}
|
||||||
|
)
|
||||||
|
print(m)
|
Loading…
Reference in New Issue