mirror of https://github.com/llvm/torch-mlir
Allow custom ops to return an array of tensors (#3531)
This PR adds support to `fx_importer.py` for handling custom ops that return an array of tensors. As long as the length of the array is consistent across runs (determined statically), then this patch will work. This does not require that the number of tensors returned is determined by the op's definition. CC @sjain-stanfordpull/3543/head
parent
7411ff2f69
commit
fe9db78120
|
@ -1099,6 +1099,10 @@ class ContextCache:
|
|||
return self.get_vtensor_type(
|
||||
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
|
||||
)
|
||||
elif isinstance(val, list) and all(
|
||||
isinstance(x, TorchFakeTensor) for x in val
|
||||
):
|
||||
return IrType.parse("!torch.list<vtensor>", context=self._c)
|
||||
|
||||
# Note that None is a valid scalar here, so it is important that this
|
||||
# is always checked as the last fallback.
|
||||
|
@ -1227,6 +1231,7 @@ class GraphNodeImporter:
|
|||
"_v",
|
||||
"_symbol_to_value",
|
||||
"_multi_result_nodes",
|
||||
"_unpack_list_values",
|
||||
"fx_importer",
|
||||
]
|
||||
|
||||
|
@ -1251,6 +1256,10 @@ class GraphNodeImporter:
|
|||
# Statically multi-result nodes which we have de-tupled are noted here.
|
||||
# They will have their getitem calls short-circuited.
|
||||
self._multi_result_nodes: Set[torch_fx.Node] = set()
|
||||
# If a OP returns a list, then it needs to be unpacked entirely using
|
||||
# prim.ListUnpack. Cache the result of these nodes so that it only
|
||||
# unpacks once instead of every time that getitem is used
|
||||
self._unpack_list_values: Dict[torch_fx.Node, Tuple[Value]] = {}
|
||||
|
||||
def bind_node_value(
|
||||
self,
|
||||
|
@ -1420,29 +1429,7 @@ class GraphNodeImporter:
|
|||
elif op == "call_function":
|
||||
target = node.target
|
||||
if target == operator.getitem:
|
||||
# Special case handling of getitem for when it is resolving
|
||||
# against a function call that we know has returned multiple
|
||||
# results. We short-circuit this case because we have modeled
|
||||
# function calls to natively return multiple results vs tupling.
|
||||
getitem_ref, getitem_index = node.args
|
||||
if getitem_ref in self._multi_result_nodes:
|
||||
try:
|
||||
self.bind_node_value(
|
||||
node,
|
||||
self.resolve_node_value(getitem_ref, getitem_index),
|
||||
)
|
||||
except IndexError:
|
||||
raise RuntimeError(
|
||||
f"getitem de-aliasing failed. This likely "
|
||||
f"indicates a programmer error that usually "
|
||||
f"would have happened at runtime. Please "
|
||||
f"notify developers if this case happens "
|
||||
f"(at {loc})."
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"General getitem access to non-multi-result ops"
|
||||
)
|
||||
self._import_getitem(loc, node)
|
||||
elif target in SYMBOLIC_TORCH_OPS or (
|
||||
is_symbolic(node.meta.get("val"))
|
||||
and is_builtin_function_or_method(target)
|
||||
|
@ -2007,6 +1994,51 @@ class GraphNodeImporter:
|
|||
with loc:
|
||||
return cvt(arg, self, self._cc)
|
||||
|
||||
def _import_getitem(self, loc: Location, node: torch.fx.Node):
|
||||
ref_node, index = node.args
|
||||
if ref_node in self._multi_result_nodes:
|
||||
# Special case handling of getitem for when it is resolving
|
||||
# against a function call that we know has returned multiple
|
||||
# results. We short-circuit this case because we have modeled
|
||||
# function calls to natively return multiple results vs tupling.
|
||||
try:
|
||||
self.bind_node_value(
|
||||
node,
|
||||
self.resolve_node_value(ref_node, index),
|
||||
)
|
||||
except IndexError:
|
||||
raise RuntimeError(
|
||||
f"getitem de-aliasing failed. This likely "
|
||||
f"indicates a programmer error that usually "
|
||||
f"would have happened at runtime. Please "
|
||||
f"notify developers if this case happens "
|
||||
f"(at {loc})."
|
||||
)
|
||||
else:
|
||||
# handle nodes that return a torch.list<...> at the MLIR level
|
||||
# NOTE: the length of the list must be knowable at compile time.
|
||||
if ref_node not in self._unpack_list_values:
|
||||
node_result = self.resolve_node_value(ref_node, 0)
|
||||
if str(node_result.type) in TORCH_LIST_TYPES:
|
||||
result_types = [
|
||||
self._cc.value_info_to_type(v) for v in ref_node.meta["val"]
|
||||
]
|
||||
operation = Operation.create(
|
||||
"torch.prim.ListUnpack",
|
||||
results=result_types,
|
||||
operands=[node_result],
|
||||
loc=loc,
|
||||
)
|
||||
self._unpack_list_values[ref_node] = tuple(operation.results)
|
||||
|
||||
try:
|
||||
self.bind_node_value(node, self._unpack_list_values[ref_node][index])
|
||||
except IndexError:
|
||||
raise RuntimeError(
|
||||
f"getitem failed. "
|
||||
f"getitem only supports lists of known length. (at {loc})"
|
||||
)
|
||||
|
||||
def _unpack_node_result_types(
|
||||
self, node: torch.fx.Node, schema: FunctionSchema
|
||||
) -> List[IrType]:
|
||||
|
@ -2337,6 +2369,10 @@ PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE = {
|
|||
"vtensor": "!torch.list<optional<vtensor>>",
|
||||
}
|
||||
|
||||
TORCH_LIST_TYPES = set(PY_TYPE_TO_TORCH_LIST_TYPE.values()) | set(
|
||||
PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE.values()
|
||||
)
|
||||
|
||||
SCALAR_TYPE_TO_TORCH_MLIR_TYPE = {
|
||||
torch.SymInt: "!torch.int",
|
||||
torch.SymFloat: "!torch.float",
|
||||
|
|
|
@ -84,3 +84,50 @@ def test_tanh_sigmoid_cat_custom_op():
|
|||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_custom_op_array_output
|
||||
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,3],f32>)
|
||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = 10} : !torch.int
|
||||
# CHECK: %[[int:.+]] = torch.constant.int 4
|
||||
# CHECK: %[[V0:.+]] = torch.operator "torch.my_custom_library.array_output_op"(%[[int]], %[[ARG0]]) : (!torch.int, !torch.vtensor<[?,3],f32>) -> !torch.list<vtensor>
|
||||
# CHECK: %[[V1:.+]]:4 = torch.prim.ListUnpack %[[V0]] : !torch.list<vtensor> -> !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[V1]]#0, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[V1]]#1, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[V1]]#2, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[V1]]#3, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
|
||||
# CHECK: return %[[V1]]#0, %[[V1]]#1, %[[V1]]#2, %[[V1]]#3 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>
|
||||
def test_custom_op_array_output():
|
||||
m = Library("my_custom_library", "DEF")
|
||||
m.define("array_output_op(int num_outs, Tensor a) -> Tensor[]")
|
||||
|
||||
@impl(m, "array_output_op", "CompositeExplicitAutograd")
|
||||
def custom_op(num_outs, a):
|
||||
return [a] * num_outs
|
||||
|
||||
@impl_abstract("my_custom_library::array_output_op")
|
||||
def custom_op_meta(num_outs, a):
|
||||
result = custom_op(num_outs, a)
|
||||
return [torch.empty_like(t) for t in result]
|
||||
|
||||
class ArrayOutputCustomOp(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ops.my_custom_library.array_output_op(4, a)
|
||||
|
||||
dim = Dim("n", max=10)
|
||||
dynamic_shapes = {
|
||||
"a": {0: dim},
|
||||
}
|
||||
|
||||
a = torch.rand(2, 3)
|
||||
m = fx.export_and_import(
|
||||
ArrayOutputCustomOp(),
|
||||
a,
|
||||
import_symbolic_shape_expressions=True,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
print(m)
|
||||
|
|
Loading…
Reference in New Issue