[FxImporter] Fix constant bool tensor (#3375)

pull/3346/merge
penguin_wwy 2024-05-22 22:59:01 +08:00 committed by GitHub
parent 52be4bdc18
commit 972d47b586
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 16 deletions

View File

@ -367,7 +367,6 @@ FX_IMPORTER_XFAIL_SET = {
"BoolIntTrueModule_basic",
"BroadcastDynamicDimModule_basic",
"CeilFloatModule_basic",
"ConstantBoolParameterModule_basic",
"ContainsIntList_False",
"ContainsIntList_True",
"Conv2dQInt8Module_basic",

View File

@ -844,6 +844,10 @@ class FxImporter:
result_types.append(
IrType.parse("!torch.none", context=self._c)
)
elif isinstance(result_node, torch.Tensor):
result_types.append(
self._cc.tensor_to_vtensor_type(result_node)
)
else:
result_types.append(self._cc.node_val_to_type(result_node))
return (
@ -1002,9 +1006,14 @@ class ContextCache:
self._dtype_to_type[dtype] = t
return t
def create_vtensor_type(self, dtype: torch.dtype, size: torch.Size) -> IrType:
dtype_asm = str(self.dtype_to_type(dtype))
return IrType.parse(
f"!torch.vtensor<{list(size)},{dtype_asm}>", context=self._c
)
def tensor_to_vtensor_type(self, tensor: torch.Tensor) -> IrType:
dtype_asm = str(self.dtype_to_type(tensor.dtype))
return IrType.parse(f"!torch.vtensor<{list(tensor.size())},{dtype_asm}>")
return self.create_vtensor_type(tensor.dtype, tensor.size())
def get_node_location(self, node: torch_fx.Node) -> Optional[Location]:
stack_trace = node.meta.get("stack_trace")
@ -1513,37 +1522,58 @@ class GraphNodeImporter:
):
# promote scalars to tensor types as appropriate
argument_value = self._import_scalar_as_tensor(loc, arg)
else:
elif LITERAL_CONVERTER_MAP.lookup(type(arg)) is not None:
with loc:
argument_value = self._import_literal(arg)
return self._convert_type(loc, argument_value, expected_jit_type)
else:
raise TypeError(f"Unsupported argument type {arg.__class__}")
with loc:
return self._convert_type(argument_value, expected_jit_type)
def _convert_type(self, loc: Location, val: Value, expected_jit_type):
def _convert_type(
self,
val: Value,
expected_type,
dtype: Optional[torch.dtype] = None,
size: Optional[torch.Size] = None,
):
"""
When the type of 'value' and the type in the schema do not match,
attempt to perform automatic type conversion.
example: test/python/fx_importer/basic_test.py::test_full
"""
if not expected_type:
return val
op_name = None
result_type = None
# TODO: If additional types require conversion in the future,
# consider implementing a table-driven approach.
operands = [val]
if val.type == self._cc.torch_bool_type:
if isinstance(expected_jit_type, torch.FloatType):
if isinstance(expected_type, torch.FloatType):
op_name = "torch.aten.Float.bool"
result_type = self._cc.torch_float_type
elif isinstance(expected_jit_type, (torch.IntType, torch.NumberType)):
elif isinstance(expected_type, (torch.IntType, torch.NumberType)):
op_name = "torch.aten.Int.bool"
result_type = self._cc.torch_int_type
elif expected_type is torch.Tensor:
op_name = "torch.prims.convert_element_type"
result_type = self._cc.create_vtensor_type(dtype, size)
operands.append(
LITERAL_CONVERTER_MAP.lookup(torch.dtype)(dtype, self, self._cc)
)
if op_name is None:
return val
with loc:
return Operation.create(
name=op_name, results=[result_type], operands=[val]
).result
return Operation.create(
name=op_name, results=[result_type], operands=operands
).result
def _import_literal(self, py_value: Any) -> Value:
orig_value = None
if isinstance(py_value, torch.Tensor) and py_value.dtype == torch.bool:
orig_value = py_value
py_value = py_value.to(torch.uint8)
# Apply the conversion callback.
user_value = self.fx_importer._hooks.resolve_literal(self, py_value)
if user_value is not None:
@ -1556,7 +1586,12 @@ class GraphNodeImporter:
raise TypeError(
f"Unsupported argument -> literal conversion for {py_value.__class__}"
)
return converter(py_value, self, self._cc)
result = converter(py_value, self, self._cc)
if orig_value is not None:
result = self._convert_type(
result, torch.Tensor, orig_value.dtype, orig_value.size()
)
return result
def _import_input(self, py_value: Any, info: InputInfo) -> Value:
# Try the hook.
@ -1704,16 +1739,19 @@ def _make_constant_op(
)
def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
def _create_mlir_tensor_type(dtype: torch.dtype, size: torch.Size) -> IrType:
try:
dtype = tensor.dtype
element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
tensor_type = RankedTensorType.get(size, element_type)
return tensor_type
except KeyError:
raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
return _create_mlir_tensor_type(tensor.dtype, tensor.size())
def _make_vtensor_literal_op(
tensor: torch.Tensor, vtensor_type: IrType, py_attr_tracker: "RefTracker"
) -> Operation: