mirror of https://github.com/llvm/torch-mlir
[FxImporter] Type conversion to resolve the mismatch between Py type and schema type (#3163)
parent 10b6062d41
commit af5509c5d9
@@ -1495,17 +1495,41 @@ class GraphNodeImporter:
                 with loc:
                     self.bind_node_value(arg, self._import_literal(obj))
 
-            return self.resolve_node_value(arg)
+            argument_value = self.resolve_node_value(arg)
         elif isinstance(arg, torch_fx.immutable_collections.immutable_list):
-            return self._import_list_argument(loc, arg, expected_jit_type)
+            argument_value = self._import_list_argument(loc, arg, expected_jit_type)
         elif isinstance(expected_jit_type, torch.TensorType) and not isinstance(
             arg, torch.Tensor
         ):
             # promote scalars to tensor types as appropriate
-            return self._import_scalar_as_tensor(loc, arg)
+            argument_value = self._import_scalar_as_tensor(loc, arg)
         else:
             with loc:
-                return self._import_literal(arg)
+                argument_value = self._import_literal(arg)
+        return self._convert_type(loc, argument_value, expected_jit_type)
+
+    def _convert_type(self, loc: Location, val: Value, expected_jit_type):
+        """
+        When the type of 'value' and the type in the schema do not match,
+        attempt to perform automatic type conversion.
+
+        example: test/python/fx_importer/basic_test.py::test_full
+        """
+        op_name = None
+        result_type = None
+        # TODO: If additional types require conversion in the future,
+        # consider implementing a table-driven approach.
+        if val.type == self._cc.torch_bool_type:
+            if isinstance(expected_jit_type, torch.FloatType):
+                op_name = "torch.aten.Float.bool"
+                result_type = self._cc.torch_float_type
+            elif isinstance(expected_jit_type, (torch.IntType, torch.NumberType)):
+                op_name = "torch.aten.Int.bool"
+                result_type = self._cc.torch_int_type
+        if op_name is None:
+            return val
+        with loc:
+            return Operation.create(name=op_name, results=[result_type], operands=[val]).result
 
     def _import_literal(self, py_value: Any) -> Value:
         # Apply the conversion callback.
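For context, a minimal sketch (not part of the commit) of the mismatch that _convert_type resolves: in the op schema, aten.full declares fill_value as a Scalar (torch.NumberType), but tracing torch.full([], False, ...) yields a plain Python bool literal, so the importer previously passed a !torch.bool value where the schema expects a number. The new helper bridges that gap by wrapping the literal in torch.aten.Int.bool (or torch.aten.Float.bool when the schema expects a float).

import torch

# Inspect the schema the FX importer compares the traced literal against.
# aten::full(SymInt[] size, Scalar fill_value, ...) -> Tensor
schema = torch.ops.aten.full.default._schema
fill_value = next(a for a in schema.arguments if a.name == "fill_value")
print(isinstance(fill_value.type, torch.NumberType))  # True: the schema wants a number
print(type(False))                                     # <class 'bool'>: the traced literal is a bool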
@@ -26,6 +26,7 @@ def export_and_import(
     hooks: Optional[FxImporterHooks] = None,
     decomposition_table: Optional[list] = None,
     func_name: str = "main",
+    enable_graph_printing: bool = False,
     **kwargs,
 ):
     context = ir.Context()
@@ -38,6 +39,8 @@ def export_and_import(
         decomposition_table = get_decomposition_table()
     if decomposition_table:
         prog = prog.run_decompositions(decomposition_table)
+    if enable_graph_printing:
+        prog.graph_module.print_readable()
     if experimental_support_mutation:
         if torch.__version__ < "2.3.0.dev20240207":
             warnings.warn("Mutable program import only supported on PyTorch 2.3+")
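A usage sketch of the new enable_graph_printing flag on export_and_import (the Relu module below is hypothetical and only for illustration): when the flag is set, the decomposed FX graph is dumped via GraphModule.print_readable() before it is imported to the Torch dialect.

import torch
from torch_mlir import fx

class Relu(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return torch.relu(x)

# Prints the decomposed FX graph to stdout, then returns the Torch-dialect module.
m = fx.export_and_import(Relu(), torch.randn(3, 4), enable_graph_printing=True)
print(m)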
@@ -53,7 +56,10 @@ def stateless_fx_import(
     fx_importer: Optional[FxImporter] = None,
     hooks: Optional[FxImporterHooks] = None,
     model_name: str = "main",
+    enable_graph_printing: bool = False,
 ):
+    if enable_graph_printing:
+        gm.print_readable()
     context = ir.Context()
     torch_d.register_dialect(context)
     if fx_importer is None:
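And a sketch of the same flag on stateless_fx_import, mirroring the flow of the existing test_stateless_fx_import (the backend and function names below are hypothetical): the aten-level graph produced by AOTAutograd is printed before import.

import torch
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import make_boxed_compiler
from torch_mlir import fx

@make_boxed_compiler
def import_backend(gm, example_inputs):  # hypothetical demo backend
    # Print the aten-level FX graph (enable_graph_printing), then emit the Torch-dialect module.
    m = fx.stateless_fx_import(gm, model_name="tanh_demo", enable_graph_printing=True)
    print(m)
    return gm  # fall back to running the traced graph eagerly

@torch.compile(backend=aot_autograd(fw_compiler=import_backend))
def f(x):
    return torch.tanh(x)

f(torch.randn(3, 4))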
@@ -14,6 +14,7 @@ from torch._dynamo.backends.common import aot_autograd
 from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_graph_name, set_model_name
 
 from torch_mlir import fx
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
 
 
 def run(f):
@@ -118,3 +119,24 @@ def test_stateless_fx_import():
         return torch.tanh(x)
 
     basic_forward(torch.randn(3, 4))
+
+
+@run
+# CHECK-LABEL: test_full
+# CHECK: %2 = torch.aten.fill.Scalar %1, %int0 : !torch.vtensor<[],i1>, !torch.int -> !torch.vtensor<[],i1>
+def test_full():
+    class Basic(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self):
+            return torch.full([], False, dtype=torch.bool, layout=torch.strided, device='cpu',
+                              pin_memory=False)
+
+    m = fx.export_and_import(Basic(), func_name="test_full", enable_graph_printing=True)
+    run_pipeline_with_repro_report(
+        m,
+        f"builtin.module(torch-simplification-pipeline)",
+        "torch-simplification-pipeline",
+    )
+    print(m)