mirror of https://github.com/llvm/torch-mlir
[FxImporter] Type conversion to resolve the mismatch between Py type and schema type (#3163)
parent
10b6062d41
commit
af5509c5d9
|
@ -1495,17 +1495,41 @@ class GraphNodeImporter:
|
|||
with loc:
|
||||
self.bind_node_value(arg, self._import_literal(obj))
|
||||
|
||||
return self.resolve_node_value(arg)
|
||||
argument_value = self.resolve_node_value(arg)
|
||||
elif isinstance(arg, torch_fx.immutable_collections.immutable_list):
|
||||
return self._import_list_argument(loc, arg, expected_jit_type)
|
||||
argument_value = self._import_list_argument(loc, arg, expected_jit_type)
|
||||
elif isinstance(expected_jit_type, torch.TensorType) and not isinstance(
|
||||
arg, torch.Tensor
|
||||
):
|
||||
# promote scalars to tensor types as appropriate
|
||||
return self._import_scalar_as_tensor(loc, arg)
|
||||
argument_value = self._import_scalar_as_tensor(loc, arg)
|
||||
else:
|
||||
with loc:
|
||||
return self._import_literal(arg)
|
||||
argument_value = self._import_literal(arg)
|
||||
return self._convert_type(loc, argument_value, expected_jit_type)
|
||||
|
||||
def _convert_type(self, loc: Location, val: Value, expected_jit_type):
|
||||
"""
|
||||
When the type of 'value' and the type in the schema do not match,
|
||||
attempt to perform automatic type conversion.
|
||||
|
||||
example: test/python/fx_importer/basic_test.py::test_full
|
||||
"""
|
||||
op_name = None
|
||||
result_type = None
|
||||
# TODO: If additional types require conversion in the future,
|
||||
# consider implementing a table-driven approach.
|
||||
if val.type == self._cc.torch_bool_type:
|
||||
if isinstance(expected_jit_type, torch.FloatType):
|
||||
op_name = "torch.aten.Float.bool"
|
||||
result_type = self._cc.torch_float_type
|
||||
elif isinstance(expected_jit_type, (torch.IntType, torch.NumberType)):
|
||||
op_name = "torch.aten.Int.bool"
|
||||
result_type = self._cc.torch_int_type
|
||||
if op_name is None:
|
||||
return val
|
||||
with loc:
|
||||
return Operation.create(name=op_name, results=[result_type], operands=[val]).result
|
||||
|
||||
def _import_literal(self, py_value: Any) -> Value:
|
||||
# Apply the conversion callback.
|
||||
|
|
|
@ -26,6 +26,7 @@ def export_and_import(
|
|||
hooks: Optional[FxImporterHooks] = None,
|
||||
decomposition_table: Optional[list] = None,
|
||||
func_name: str = "main",
|
||||
enable_graph_printing: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
context = ir.Context()
|
||||
|
@ -38,6 +39,8 @@ def export_and_import(
|
|||
decomposition_table = get_decomposition_table()
|
||||
if decomposition_table:
|
||||
prog = prog.run_decompositions(decomposition_table)
|
||||
if enable_graph_printing:
|
||||
prog.graph_module.print_readable()
|
||||
if experimental_support_mutation:
|
||||
if torch.__version__ < "2.3.0.dev20240207":
|
||||
warnings.warn("Mutable program import only supported on PyTorch 2.3+")
|
||||
|
@ -53,7 +56,10 @@ def stateless_fx_import(
|
|||
fx_importer: Optional[FxImporter] = None,
|
||||
hooks: Optional[FxImporterHooks] = None,
|
||||
model_name: str = "main",
|
||||
enable_graph_printing: bool = False,
|
||||
):
|
||||
if enable_graph_printing:
|
||||
gm.print_readable()
|
||||
context = ir.Context()
|
||||
torch_d.register_dialect(context)
|
||||
if fx_importer is None:
|
||||
|
|
|
@ -14,6 +14,7 @@ from torch._dynamo.backends.common import aot_autograd
|
|||
from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_graph_name, set_model_name
|
||||
|
||||
from torch_mlir import fx
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
|
||||
|
||||
def run(f):
|
||||
|
@ -118,3 +119,24 @@ def test_stateless_fx_import():
|
|||
return torch.tanh(x)
|
||||
|
||||
basic_forward(torch.randn(3, 4))
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_full
|
||||
# CHECK: %2 = torch.aten.fill.Scalar %1, %int0 : !torch.vtensor<[],i1>, !torch.int -> !torch.vtensor<[],i1>
|
||||
def test_full():
|
||||
class Basic(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self):
|
||||
return torch.full([], False, dtype=torch.bool, layout=torch.strided, device='cpu',
|
||||
pin_memory=False)
|
||||
|
||||
m = fx.export_and_import(Basic(), func_name="test_full", enable_graph_printing=True)
|
||||
run_pipeline_with_repro_report(
|
||||
m,
|
||||
f"builtin.module(torch-simplification-pipeline)",
|
||||
"torch-simplification-pipeline",
|
||||
)
|
||||
print(m)
|
||||
|
|
Loading…
Reference in New Issue