Allow custom ops to return an array of tensors (#3531)

This PR adds support to `fx_importer.py` for custom ops that return an array of tensors. As long as the length of the array is consistent across runs (i.e., it can be determined statically at import time), the patch works; the number of tensors returned does not need to be fixed by the op's schema.
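
For context, the user-facing pattern this enables looks roughly like the sketch below (the library and op names here are illustrative, not part of this PR; the real exercise of the feature is the new test at the bottom of this diff). The key requirement is that the meta/fake kernel produces a list whose length is a concrete Python int at trace time:

import torch
from torch.library import Library, impl, impl_abstract

# Hypothetical custom op whose schema returns a variable-length Tensor[].
_sketch_lib = Library("my_sketch_library", "DEF")
_sketch_lib.define("repeat_op(int num_outs, Tensor a) -> Tensor[]")

@impl(_sketch_lib, "repeat_op", "CompositeExplicitAutograd")
def _repeat_op(num_outs, a):
    return [a.clone() for _ in range(num_outs)]

@impl_abstract("my_sketch_library::repeat_op")
def _repeat_op_meta(num_outs, a):
    # num_outs is a concrete int at trace time, so the length of the
    # returned list is statically known to the importer even though the
    # schema alone (`-> Tensor[]`) does not fix it.
    return [torch.empty_like(a) for _ in range(num_outs)]

On import, the op's result is materialized as a single `!torch.list<vtensor>` value, and a `torch.prim.ListUnpack` is emitted to split it into the individual tensors that the traced `getitem` calls consume.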

CC @sjain-stanford
Matthew Francis-Landau 2024-07-14 14:54:23 -04:00 committed by GitHub
parent 7411ff2f69
commit fe9db78120
2 changed files with 106 additions and 23 deletions

python/torch_mlir/extras/fx_importer.py

@@ -1099,6 +1099,10 @@ class ContextCache:
             return self.get_vtensor_type(
                 val.size(), val.dtype, sparsity=sparsity, mutable=mutable
             )
+        elif isinstance(val, list) and all(
+            isinstance(x, TorchFakeTensor) for x in val
+        ):
+            return IrType.parse("!torch.list<vtensor>", context=self._c)
         # Note that None is a valid scalar here, so it is important that this
         # is always checked as the last fallback.
@@ -1227,6 +1231,7 @@ class GraphNodeImporter:
         "_v",
         "_symbol_to_value",
         "_multi_result_nodes",
+        "_unpack_list_values",
         "fx_importer",
     ]
@@ -1251,6 +1256,10 @@ class GraphNodeImporter:
         # Statically multi-result nodes which we have de-tupled are noted here.
         # They will have their getitem calls short-circuited.
         self._multi_result_nodes: Set[torch_fx.Node] = set()
+        # If an op returns a list, it must be unpacked entirely via
+        # prim.ListUnpack. Cache the unpacked results of those nodes here so
+        # the list is unpacked once, not every time getitem is used.
+        self._unpack_list_values: Dict[torch_fx.Node, Tuple[Value, ...]] = {}

     def bind_node_value(
         self,
@@ -1420,29 +1429,7 @@ class GraphNodeImporter:
         elif op == "call_function":
             target = node.target
             if target == operator.getitem:
-                # Special case handling of getitem for when it is resolving
-                # against a function call that we know has returned multiple
-                # results. We short-circuit this case because we have modeled
-                # function calls to natively return multiple results vs tupling.
-                getitem_ref, getitem_index = node.args
-                if getitem_ref in self._multi_result_nodes:
-                    try:
-                        self.bind_node_value(
-                            node,
-                            self.resolve_node_value(getitem_ref, getitem_index),
-                        )
-                    except IndexError:
-                        raise RuntimeError(
-                            f"getitem de-aliasing failed. This likely "
-                            f"indicates a programmer error that usually "
-                            f"would have happened at runtime. Please "
-                            f"notify developers if this case happens "
-                            f"(at {loc})."
-                        )
-                else:
-                    raise NotImplementedError(
-                        f"General getitem access to non-multi-result ops"
-                    )
+                self._import_getitem(loc, node)
             elif target in SYMBOLIC_TORCH_OPS or (
                 is_symbolic(node.meta.get("val"))
                 and is_builtin_function_or_method(target)
@@ -2007,6 +1994,51 @@ class GraphNodeImporter:
         with loc:
             return cvt(arg, self, self._cc)

+    def _import_getitem(self, loc: Location, node: torch.fx.Node):
+        ref_node, index = node.args
+        if ref_node in self._multi_result_nodes:
+            # Special case handling of getitem for when it is resolving
+            # against a function call that we know has returned multiple
+            # results. We short-circuit this case because we have modeled
+            # function calls to natively return multiple results vs tupling.
+            try:
+                self.bind_node_value(
+                    node,
+                    self.resolve_node_value(ref_node, index),
+                )
+            except IndexError:
+                raise RuntimeError(
+                    f"getitem de-aliasing failed. This likely "
+                    f"indicates a programmer error that usually "
+                    f"would have happened at runtime. Please "
+                    f"notify developers if this case happens "
+                    f"(at {loc})."
+                )
+        else:
+            # Handle nodes that return a torch.list<...> at the MLIR level.
+            # NOTE: the length of the list must be knowable at compile time.
+            if ref_node not in self._unpack_list_values:
+                node_result = self.resolve_node_value(ref_node, 0)
+                if str(node_result.type) in TORCH_LIST_TYPES:
+                    result_types = [
+                        self._cc.value_info_to_type(v) for v in ref_node.meta["val"]
+                    ]
+                    operation = Operation.create(
+                        "torch.prim.ListUnpack",
+                        results=result_types,
+                        operands=[node_result],
+                        loc=loc,
+                    )
+                    self._unpack_list_values[ref_node] = tuple(operation.results)
+            try:
+                self.bind_node_value(node, self._unpack_list_values[ref_node][index])
+            except (IndexError, KeyError):
+                raise RuntimeError(
+                    f"getitem failed: only lists of statically known "
+                    f"length are supported (at {loc})."
+                )
+
     def _unpack_node_result_types(
         self, node: torch.fx.Node, schema: FunctionSchema
     ) -> List[IrType]:
@@ -2337,6 +2369,10 @@ PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE = {
     "vtensor": "!torch.list<optional<vtensor>>",
 }

+TORCH_LIST_TYPES = set(PY_TYPE_TO_TORCH_LIST_TYPE.values()) | set(
+    PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE.values()
+)
+
 SCALAR_TYPE_TO_TORCH_MLIR_TYPE = {
     torch.SymInt: "!torch.int",
     torch.SymFloat: "!torch.float",
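
An aside on the caching in `_import_getitem` above: every `getitem` on the same list-producing node must resolve against one shared `torch.prim.ListUnpack`, so the unpacked results are memoized per producer node in `_unpack_list_values`. Below is a simplified, self-contained sketch of that memoization (not the importer code itself; `unpack_list` stands in for the `Operation.create("torch.prim.ListUnpack", ...)` call):

unpack_calls = 0
unpacked_cache = {}  # producer node -> tuple of unpacked result values

def unpack_list(producer):
    # Stand-in for emitting torch.prim.ListUnpack and taking its results.
    global unpack_calls
    unpack_calls += 1
    return tuple(f"{producer}#{i}" for i in range(4))

def get_list_element(producer, index):
    # Unpack on first access only; later getitems reuse the cached tuple.
    if producer not in unpacked_cache:
        unpacked_cache[producer] = unpack_list(producer)
    return unpacked_cache[producer][index]

assert get_list_element("op", 0) == "op#0"
assert get_list_element("op", 3) == "op#3"
assert unpack_calls == 1  # one ListUnpack, not one per getitem

Without the cache, each `getitem` would emit its own `ListUnpack` over the same list value, bloating the imported IR.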

test/python/fx_importer/custom_op_test.py

@@ -84,3 +84,50 @@ def test_tanh_sigmoid_cat_custom_op():
         import_symbolic_shape_expressions=True,
     )
     print(m)
+
+
+@run
+# CHECK-LABEL: test_custom_op_array_output
+# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,3],f32>)
+# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = 10} : !torch.int
+# CHECK: %[[int:.+]] = torch.constant.int 4
+# CHECK: %[[V0:.+]] = torch.operator "torch.my_custom_library.array_output_op"(%[[int]], %[[ARG0]]) : (!torch.int, !torch.vtensor<[?,3],f32>) -> !torch.list<vtensor>
+# CHECK: %[[V1:.+]]:4 = torch.prim.ListUnpack %[[V0]] : !torch.list<vtensor> -> !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#0, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#1, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#2, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#3, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: return %[[V1]]#0, %[[V1]]#1, %[[V1]]#2, %[[V1]]#3 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>
+def test_custom_op_array_output():
+    m = Library("my_custom_library", "DEF")
+    m.define("array_output_op(int num_outs, Tensor a) -> Tensor[]")
+
+    @impl(m, "array_output_op", "CompositeExplicitAutograd")
+    def custom_op(num_outs, a):
+        return [a] * num_outs
+
+    @impl_abstract("my_custom_library::array_output_op")
+    def custom_op_meta(num_outs, a):
+        result = custom_op(num_outs, a)
+        return [torch.empty_like(t) for t in result]
+
+    class ArrayOutputCustomOp(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, a):
+            return torch.ops.my_custom_library.array_output_op(4, a)
+
+    dim = Dim("n", max=10)
+    dynamic_shapes = {
+        "a": {0: dim},
+    }
+    a = torch.rand(2, 3)
+    m = fx.export_and_import(
+        ArrayOutputCustomOp(),
+        a,
+        import_symbolic_shape_expressions=True,
+        dynamic_shapes=dynamic_shapes,
+    )
+    print(m)