mirror of https://github.com/llvm/torch-mlir
[FxImporter] Fix constant bool tensor (#3375)
parent
52be4bdc18
commit
972d47b586
|
@ -367,7 +367,6 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"BoolIntTrueModule_basic",
|
||||
"BroadcastDynamicDimModule_basic",
|
||||
"CeilFloatModule_basic",
|
||||
"ConstantBoolParameterModule_basic",
|
||||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv2dQInt8Module_basic",
|
||||
|
|
|
@ -844,6 +844,10 @@ class FxImporter:
|
|||
result_types.append(
|
||||
IrType.parse("!torch.none", context=self._c)
|
||||
)
|
||||
elif isinstance(result_node, torch.Tensor):
|
||||
result_types.append(
|
||||
self._cc.tensor_to_vtensor_type(result_node)
|
||||
)
|
||||
else:
|
||||
result_types.append(self._cc.node_val_to_type(result_node))
|
||||
return (
|
||||
|
@ -1002,9 +1006,14 @@ class ContextCache:
|
|||
self._dtype_to_type[dtype] = t
|
||||
return t
|
||||
|
||||
def create_vtensor_type(self, dtype: torch.dtype, size: torch.Size) -> IrType:
|
||||
dtype_asm = str(self.dtype_to_type(dtype))
|
||||
return IrType.parse(
|
||||
f"!torch.vtensor<{list(size)},{dtype_asm}>", context=self._c
|
||||
)
|
||||
|
||||
def tensor_to_vtensor_type(self, tensor: torch.Tensor) -> IrType:
|
||||
dtype_asm = str(self.dtype_to_type(tensor.dtype))
|
||||
return IrType.parse(f"!torch.vtensor<{list(tensor.size())},{dtype_asm}>")
|
||||
return self.create_vtensor_type(tensor.dtype, tensor.size())
|
||||
|
||||
def get_node_location(self, node: torch_fx.Node) -> Optional[Location]:
|
||||
stack_trace = node.meta.get("stack_trace")
|
||||
|
@ -1513,37 +1522,58 @@ class GraphNodeImporter:
|
|||
):
|
||||
# promote scalars to tensor types as appropriate
|
||||
argument_value = self._import_scalar_as_tensor(loc, arg)
|
||||
else:
|
||||
elif LITERAL_CONVERTER_MAP.lookup(type(arg)) is not None:
|
||||
with loc:
|
||||
argument_value = self._import_literal(arg)
|
||||
return self._convert_type(loc, argument_value, expected_jit_type)
|
||||
else:
|
||||
raise TypeError(f"Unsupported argument type {arg.__class__}")
|
||||
with loc:
|
||||
return self._convert_type(argument_value, expected_jit_type)
|
||||
|
||||
def _convert_type(self, loc: Location, val: Value, expected_jit_type):
|
||||
def _convert_type(
|
||||
self,
|
||||
val: Value,
|
||||
expected_type,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
size: Optional[torch.Size] = None,
|
||||
):
|
||||
"""
|
||||
When the type of 'value' and the type in the schema do not match,
|
||||
attempt to perform automatic type conversion.
|
||||
|
||||
example: test/python/fx_importer/basic_test.py::test_full
|
||||
"""
|
||||
if not expected_type:
|
||||
return val
|
||||
op_name = None
|
||||
result_type = None
|
||||
# TODO: If additional types require conversion in the future,
|
||||
# consider implementing a table-driven approach.
|
||||
operands = [val]
|
||||
if val.type == self._cc.torch_bool_type:
|
||||
if isinstance(expected_jit_type, torch.FloatType):
|
||||
if isinstance(expected_type, torch.FloatType):
|
||||
op_name = "torch.aten.Float.bool"
|
||||
result_type = self._cc.torch_float_type
|
||||
elif isinstance(expected_jit_type, (torch.IntType, torch.NumberType)):
|
||||
elif isinstance(expected_type, (torch.IntType, torch.NumberType)):
|
||||
op_name = "torch.aten.Int.bool"
|
||||
result_type = self._cc.torch_int_type
|
||||
elif expected_type is torch.Tensor:
|
||||
op_name = "torch.prims.convert_element_type"
|
||||
result_type = self._cc.create_vtensor_type(dtype, size)
|
||||
operands.append(
|
||||
LITERAL_CONVERTER_MAP.lookup(torch.dtype)(dtype, self, self._cc)
|
||||
)
|
||||
if op_name is None:
|
||||
return val
|
||||
with loc:
|
||||
return Operation.create(
|
||||
name=op_name, results=[result_type], operands=[val]
|
||||
).result
|
||||
return Operation.create(
|
||||
name=op_name, results=[result_type], operands=operands
|
||||
).result
|
||||
|
||||
def _import_literal(self, py_value: Any) -> Value:
|
||||
orig_value = None
|
||||
if isinstance(py_value, torch.Tensor) and py_value.dtype == torch.bool:
|
||||
orig_value = py_value
|
||||
py_value = py_value.to(torch.uint8)
|
||||
# Apply the conversion callback.
|
||||
user_value = self.fx_importer._hooks.resolve_literal(self, py_value)
|
||||
if user_value is not None:
|
||||
|
@ -1556,7 +1586,12 @@ class GraphNodeImporter:
|
|||
raise TypeError(
|
||||
f"Unsupported argument -> literal conversion for {py_value.__class__}"
|
||||
)
|
||||
return converter(py_value, self, self._cc)
|
||||
result = converter(py_value, self, self._cc)
|
||||
if orig_value is not None:
|
||||
result = self._convert_type(
|
||||
result, torch.Tensor, orig_value.dtype, orig_value.size()
|
||||
)
|
||||
return result
|
||||
|
||||
def _import_input(self, py_value: Any, info: InputInfo) -> Value:
|
||||
# Try the hook.
|
||||
|
@ -1704,16 +1739,19 @@ def _make_constant_op(
|
|||
)
|
||||
|
||||
|
||||
def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
|
||||
def _create_mlir_tensor_type(dtype: torch.dtype, size: torch.Size) -> IrType:
|
||||
try:
|
||||
dtype = tensor.dtype
|
||||
element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
|
||||
tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
|
||||
tensor_type = RankedTensorType.get(size, element_type)
|
||||
return tensor_type
|
||||
except KeyError:
|
||||
raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
|
||||
|
||||
|
||||
def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
|
||||
return _create_mlir_tensor_type(tensor.dtype, tensor.size())
|
||||
|
||||
|
||||
def _make_vtensor_literal_op(
|
||||
tensor: torch.Tensor, vtensor_type: IrType, py_attr_tracker: "RefTracker"
|
||||
) -> Operation:
|
||||
|
|
Loading…
Reference in New Issue