[FxImporter] Type conversion to resolve the mismatch between Py type and schema type (#3163)

pull/3170/head
penguin_wwy 2024-04-16 14:14:19 +08:00 committed by GitHub
parent 10b6062d41
commit af5509c5d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 56 additions and 4 deletions

View File

@ -1495,17 +1495,41 @@ class GraphNodeImporter:
with loc: with loc:
self.bind_node_value(arg, self._import_literal(obj)) self.bind_node_value(arg, self._import_literal(obj))
return self.resolve_node_value(arg) argument_value = self.resolve_node_value(arg)
elif isinstance(arg, torch_fx.immutable_collections.immutable_list): elif isinstance(arg, torch_fx.immutable_collections.immutable_list):
return self._import_list_argument(loc, arg, expected_jit_type) argument_value = self._import_list_argument(loc, arg, expected_jit_type)
elif isinstance(expected_jit_type, torch.TensorType) and not isinstance( elif isinstance(expected_jit_type, torch.TensorType) and not isinstance(
arg, torch.Tensor arg, torch.Tensor
): ):
# promote scalars to tensor types as appropriate # promote scalars to tensor types as appropriate
return self._import_scalar_as_tensor(loc, arg) argument_value = self._import_scalar_as_tensor(loc, arg)
else: else:
with loc: with loc:
return self._import_literal(arg) argument_value = self._import_literal(arg)
return self._convert_type(loc, argument_value, expected_jit_type)
def _convert_type(self, loc: Location, val: Value, expected_jit_type):
"""
When the type of 'value' and the type in the schema do not match,
attempt to perform automatic type conversion.
example: test/python/fx_importer/basic_test.py::test_full
"""
op_name = None
result_type = None
# TODO: If additional types require conversion in the future,
# consider implementing a table-driven approach.
if val.type == self._cc.torch_bool_type:
if isinstance(expected_jit_type, torch.FloatType):
op_name = "torch.aten.Float.bool"
result_type = self._cc.torch_float_type
elif isinstance(expected_jit_type, (torch.IntType, torch.NumberType)):
op_name = "torch.aten.Int.bool"
result_type = self._cc.torch_int_type
if op_name is None:
return val
with loc:
return Operation.create(name=op_name, results=[result_type], operands=[val]).result
def _import_literal(self, py_value: Any) -> Value: def _import_literal(self, py_value: Any) -> Value:
# Apply the conversion callback. # Apply the conversion callback.

View File

@ -26,6 +26,7 @@ def export_and_import(
hooks: Optional[FxImporterHooks] = None, hooks: Optional[FxImporterHooks] = None,
decomposition_table: Optional[list] = None, decomposition_table: Optional[list] = None,
func_name: str = "main", func_name: str = "main",
enable_graph_printing: bool = False,
**kwargs, **kwargs,
): ):
context = ir.Context() context = ir.Context()
@ -38,6 +39,8 @@ def export_and_import(
decomposition_table = get_decomposition_table() decomposition_table = get_decomposition_table()
if decomposition_table: if decomposition_table:
prog = prog.run_decompositions(decomposition_table) prog = prog.run_decompositions(decomposition_table)
if enable_graph_printing:
prog.graph_module.print_readable()
if experimental_support_mutation: if experimental_support_mutation:
if torch.__version__ < "2.3.0.dev20240207": if torch.__version__ < "2.3.0.dev20240207":
warnings.warn("Mutable program import only supported on PyTorch 2.3+") warnings.warn("Mutable program import only supported on PyTorch 2.3+")
@ -53,7 +56,10 @@ def stateless_fx_import(
fx_importer: Optional[FxImporter] = None, fx_importer: Optional[FxImporter] = None,
hooks: Optional[FxImporterHooks] = None, hooks: Optional[FxImporterHooks] = None,
model_name: str = "main", model_name: str = "main",
enable_graph_printing: bool = False,
): ):
if enable_graph_printing:
gm.print_readable()
context = ir.Context() context = ir.Context()
torch_d.register_dialect(context) torch_d.register_dialect(context)
if fx_importer is None: if fx_importer is None:

View File

@ -14,6 +14,7 @@ from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_graph_name, set_model_name from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_graph_name, set_model_name
from torch_mlir import fx from torch_mlir import fx
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
def run(f): def run(f):
@ -118,3 +119,24 @@ def test_stateless_fx_import():
return torch.tanh(x) return torch.tanh(x)
basic_forward(torch.randn(3, 4)) basic_forward(torch.randn(3, 4))
@run
# CHECK-LABEL: test_full
# CHECK: %2 = torch.aten.fill.Scalar %1, %int0 : !torch.vtensor<[],i1>, !torch.int -> !torch.vtensor<[],i1>
def test_full():
class Basic(nn.Module):
def __init__(self):
super().__init__()
def forward(self):
return torch.full([], False, dtype=torch.bool, layout=torch.strided, device='cpu',
pin_memory=False)
m = fx.export_and_import(Basic(), func_name="test_full", enable_graph_printing=True)
run_pipeline_with_repro_report(
m,
f"builtin.module(torch-simplification-pipeline)",
"torch-simplification-pipeline",
)
print(m)