diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2ae767dde..ab162ab94 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -367,7 +367,6 @@ FX_IMPORTER_XFAIL_SET = { "BoolIntTrueModule_basic", "BroadcastDynamicDimModule_basic", "CeilFloatModule_basic", - "ConstantBoolParameterModule_basic", "ContainsIntList_False", "ContainsIntList_True", "Conv2dQInt8Module_basic", diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index c931a3b93..34d570b55 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -844,6 +844,10 @@ class FxImporter: result_types.append( IrType.parse("!torch.none", context=self._c) ) + elif isinstance(result_node, torch.Tensor): + result_types.append( + self._cc.tensor_to_vtensor_type(result_node) + ) else: result_types.append(self._cc.node_val_to_type(result_node)) return ( @@ -1002,9 +1006,14 @@ class ContextCache: self._dtype_to_type[dtype] = t return t + def create_vtensor_type(self, dtype: torch.dtype, size: torch.Size) -> IrType: + dtype_asm = str(self.dtype_to_type(dtype)) + return IrType.parse( + f"!torch.vtensor<{list(size)},{dtype_asm}>", context=self._c + ) + def tensor_to_vtensor_type(self, tensor: torch.Tensor) -> IrType: - dtype_asm = str(self.dtype_to_type(tensor.dtype)) - return IrType.parse(f"!torch.vtensor<{list(tensor.size())},{dtype_asm}>") + return self.create_vtensor_type(tensor.dtype, tensor.size()) def get_node_location(self, node: torch_fx.Node) -> Optional[Location]: stack_trace = node.meta.get("stack_trace") @@ -1513,37 +1522,58 @@ class GraphNodeImporter: ): # promote scalars to tensor types as appropriate argument_value = self._import_scalar_as_tensor(loc, arg) - else: + elif LITERAL_CONVERTER_MAP.lookup(type(arg)) is not None: with loc: argument_value = self._import_literal(arg) - return self._convert_type(loc, argument_value, expected_jit_type) + else: + raise TypeError(f"Unsupported argument type {arg.__class__}") + with loc: + return self._convert_type(argument_value, expected_jit_type) - def _convert_type(self, loc: Location, val: Value, expected_jit_type): + def _convert_type( + self, + val: Value, + expected_type, + dtype: Optional[torch.dtype] = None, + size: Optional[torch.Size] = None, + ): """ When the type of 'value' and the type in the schema do not match, attempt to perform automatic type conversion. example: test/python/fx_importer/basic_test.py::test_full """ + if not expected_type: + return val op_name = None result_type = None # TODO: If additional types require conversion in the future, # consider implementing a table-driven approach. + operands = [val] if val.type == self._cc.torch_bool_type: - if isinstance(expected_jit_type, torch.FloatType): + if isinstance(expected_type, torch.FloatType): op_name = "torch.aten.Float.bool" result_type = self._cc.torch_float_type - elif isinstance(expected_jit_type, (torch.IntType, torch.NumberType)): + elif isinstance(expected_type, (torch.IntType, torch.NumberType)): op_name = "torch.aten.Int.bool" result_type = self._cc.torch_int_type + elif expected_type is torch.Tensor: + op_name = "torch.prims.convert_element_type" + result_type = self._cc.create_vtensor_type(dtype, size) + operands.append( + LITERAL_CONVERTER_MAP.lookup(torch.dtype)(dtype, self, self._cc) + ) if op_name is None: return val - with loc: - return Operation.create( - name=op_name, results=[result_type], operands=[val] - ).result + return Operation.create( + name=op_name, results=[result_type], operands=operands + ).result def _import_literal(self, py_value: Any) -> Value: + orig_value = None + if isinstance(py_value, torch.Tensor) and py_value.dtype == torch.bool: + orig_value = py_value + py_value = py_value.to(torch.uint8) # Apply the conversion callback. user_value = self.fx_importer._hooks.resolve_literal(self, py_value) if user_value is not None: @@ -1556,7 +1586,12 @@ class GraphNodeImporter: raise TypeError( f"Unsupported argument -> literal conversion for {py_value.__class__}" ) - return converter(py_value, self, self._cc) + result = converter(py_value, self, self._cc) + if orig_value is not None: + result = self._convert_type( + result, torch.Tensor, orig_value.dtype, orig_value.size() + ) + return result def _import_input(self, py_value: Any, info: InputInfo) -> Value: # Try the hook. @@ -1704,16 +1739,19 @@ def _make_constant_op( ) -def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType: +def _create_mlir_tensor_type(dtype: torch.dtype, size: torch.Size) -> IrType: try: - dtype = tensor.dtype element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]() - tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type) + tensor_type = RankedTensorType.get(size, element_type) return tensor_type except KeyError: raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type") +def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType: + return _create_mlir_tensor_type(tensor.dtype, tensor.size()) + + def _make_vtensor_literal_op( tensor: torch.Tensor, vtensor_type: IrType, py_attr_tracker: "RefTracker" ) -> Operation: