diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index cb86406c5..c95df2504 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -1099,6 +1099,10 @@ class ContextCache:
             return self.get_vtensor_type(
                 val.size(), val.dtype, sparsity=sparsity, mutable=mutable
             )
+        elif isinstance(val, list) and all(
+            isinstance(x, TorchFakeTensor) for x in val
+        ):
+            return IrType.parse("!torch.list<vtensor>", context=self._c)
 
         # Note that None is a valid scalar here, so it is important that this
         # is always checked as the last fallback.
@@ -1227,6 +1231,7 @@ class GraphNodeImporter:
         "_v",
         "_symbol_to_value",
         "_multi_result_nodes",
+        "_unpack_list_values",
         "fx_importer",
     ]
 
@@ -1251,6 +1256,10 @@ class GraphNodeImporter:
         # Statically multi-result nodes which we have de-tupled are noted here.
         # They will have their getitem calls short-circuited.
         self._multi_result_nodes: Set[torch_fx.Node] = set()
+        # If an op returns a list, it needs to be unpacked entirely using
+        # prim.ListUnpack. Cache the unpacked results of these nodes so the
+        # list is unpacked only once instead of on every getitem.
+        self._unpack_list_values: Dict[torch_fx.Node, Tuple[Value, ...]] = {}
 
     def bind_node_value(
         self,
@@ -1420,29 +1429,7 @@ class GraphNodeImporter:
         elif op == "call_function":
             target = node.target
             if target == operator.getitem:
-                # Special case handling of getitem for when it is resolving
-                # against a function call that we know has returned multiple
-                # results. We short-circuit this case because we have modeled
-                # function calls to natively return multiple results vs tupling.
-                getitem_ref, getitem_index = node.args
-                if getitem_ref in self._multi_result_nodes:
-                    try:
-                        self.bind_node_value(
-                            node,
-                            self.resolve_node_value(getitem_ref, getitem_index),
-                        )
-                    except IndexError:
-                        raise RuntimeError(
-                            f"getitem de-aliasing failed. This likely "
-                            f"indicates a programmer error that usually "
-                            f"would have happened at runtime. Please "
-                            f"notify developers if this case happens "
-                            f"(at {loc})."
-                        )
-                else:
-                    raise NotImplementedError(
-                        f"General getitem access to non-multi-result ops"
-                    )
+                self._import_getitem(loc, node)
             elif target in SYMBOLIC_TORCH_OPS or (
                 is_symbolic(node.meta.get("val"))
                 and is_builtin_function_or_method(target)
@@ -2007,6 +1994,51 @@ class GraphNodeImporter:
         with loc:
             return cvt(arg, self, self._cc)
 
+    def _import_getitem(self, loc: Location, node: torch.fx.Node):
+        ref_node, index = node.args
+        if ref_node in self._multi_result_nodes:
+            # Special case handling of getitem for when it is resolving
+            # against a function call that we know has returned multiple
+            # results. We short-circuit this case because we have modeled
+            # function calls to natively return multiple results vs tupling.
+            try:
+                self.bind_node_value(
+                    node,
+                    self.resolve_node_value(ref_node, index),
+                )
+            except IndexError:
+                raise RuntimeError(
+                    f"getitem de-aliasing failed. This likely "
+                    f"indicates a programmer error that usually "
+                    f"would have happened at runtime. Please "
+                    f"notify developers if this case happens "
+                    f"(at {loc})."
+                )
+        else:
+            # Handle nodes that return a torch.list<...> at the MLIR level.
+            # NOTE: the length of the list must be knowable at compile time.
+            if ref_node not in self._unpack_list_values:
+                node_result = self.resolve_node_value(ref_node, 0)
+                if str(node_result.type) in TORCH_LIST_TYPES:
+                    result_types = [
+                        self._cc.value_info_to_type(v) for v in ref_node.meta["val"]
+                    ]
+                    operation = Operation.create(
+                        "torch.prim.ListUnpack",
+                        results=result_types,
+                        operands=[node_result],
+                        loc=loc,
+                    )
+                    self._unpack_list_values[ref_node] = tuple(operation.results)
+
+            try:
+                self.bind_node_value(node, self._unpack_list_values[ref_node][index])
+            except IndexError:
+                raise RuntimeError(
+                    f"getitem failed. "
+                    f"getitem only supports lists of known length. (at {loc})"
+                )
+
     def _unpack_node_result_types(
         self, node: torch.fx.Node, schema: FunctionSchema
     ) -> List[IrType]:
@@ -2337,6 +2369,10 @@ PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE = {
     "vtensor": "!torch.list<optional<vtensor>>",
 }
 
+TORCH_LIST_TYPES = set(PY_TYPE_TO_TORCH_LIST_TYPE.values()) | set(
+    PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE.values()
+)
+
 SCALAR_TYPE_TO_TORCH_MLIR_TYPE = {
     torch.SymInt: "!torch.int",
     torch.SymFloat: "!torch.float",
diff --git a/test/python/fx_importer/custom_op_test.py b/test/python/fx_importer/custom_op_test.py
index dbbc5ba05..d9ce7d609 100644
--- a/test/python/fx_importer/custom_op_test.py
+++ b/test/python/fx_importer/custom_op_test.py
@@ -84,3 +84,50 @@ def test_tanh_sigmoid_cat_custom_op():
         import_symbolic_shape_expressions=True,
     )
     print(m)
+
+
+@run
+# CHECK-LABEL: test_custom_op_array_output
+# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,3],f32>)
+# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = 10} : !torch.int
+# CHECK: %[[int:.+]] = torch.constant.int 4
+# CHECK: %[[V0:.+]] = torch.operator "torch.my_custom_library.array_output_op"(%[[int]], %[[ARG0]]) : (!torch.int, !torch.vtensor<[?,3],f32>) -> !torch.list<vtensor>
+# CHECK: %[[V1:.+]]:4 = torch.prim.ListUnpack %[[V0]] : !torch.list<vtensor> -> !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#0, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#1, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#2, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#3, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: return %[[V1]]#0, %[[V1]]#1, %[[V1]]#2, %[[V1]]#3 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>
+def test_custom_op_array_output():
+    m = Library("my_custom_library", "DEF")
+    m.define("array_output_op(int num_outs, Tensor a) -> Tensor[]")
+
+    @impl(m, "array_output_op", "CompositeExplicitAutograd")
+    def custom_op(num_outs, a):
+        return [a] * num_outs
+
+    @impl_abstract("my_custom_library::array_output_op")
+    def custom_op_meta(num_outs, a):
+        result = custom_op(num_outs, a)
+        return [torch.empty_like(t) for t in result]
+
+    class ArrayOutputCustomOp(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, a):
+            return torch.ops.my_custom_library.array_output_op(4, a)
+
+    dim = Dim("n", max=10)
+    dynamic_shapes = {
+        "a": {0: dim},
+    }
+
+    a = torch.rand(2, 3)
+    m = fx.export_and_import(
+        ArrayOutputCustomOp(),
+        a,
+        import_symbolic_shape_expressions=True,
+        dynamic_shapes=dynamic_shapes,
+    )
+    print(m)
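A minimal sketch of the case the `_unpack_list_values` cache exists for: two `getitem` nodes subscripting the same list-returning op should share a single `torch.prim.ListUnpack` rather than emitting one unpack per subscript. This module is illustrative only and is not part of the patch; it assumes the `my_custom_library.array_output_op` registration from the test above is in scope.

    # Illustrative sketch, not part of the patch. Assumes the
    # my_custom_library.array_output_op registration from the test above.
    import torch
    from torch import nn

    class TwoGetitemsModule(nn.Module):
        def forward(self, a):
            outs = torch.ops.my_custom_library.array_output_op(2, a)
            # outs[0] and outs[1] trace to two operator.getitem nodes. On
            # import, the first creates the torch.prim.ListUnpack and caches
            # its results; the second binds to the cached value at index 1.
            return outs[0] + outs[1]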