[ods] Allow all tensor returns to be optional. (#3082)

This was found while tracing backwards graphs: the convolution_backwards
op will return None if the first result is not needed. Confirmed by
defining a custom op with a `Tensor` return signature and having its
meta kernel return None.
pull/3070/head
Stella Laurenzo 2024-03-29 23:09:34 -07:00 committed by GitHub
parent 40008b025a
commit 6d680ff445
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 571 additions and 565 deletions

File diff suppressed because it is too large Load Diff

View File

@ -60,7 +60,13 @@ TORCH_NON_VALUE_TYPE_TO_ODS_TYPE = {
}
def get_ods_type(type: str, non_value: bool):
def get_ods_type(type: str, non_value: bool, *, is_result: bool = False):
# In torch signatures, it is legal to return None for a Tensor, so we
# silently upgrade Tensor results to Tensor?. I'm not sure this is written
# anywhere, but Torch does use it sometimes internally. As an example,
# the first return of a convolution_backwards may be returned as None.
if is_result and type == "Tensor":
type = "Tensor?"
# TODO: Increase precision on dict type modeling.
if type.startswith("Dict("):
type = "Dict"
@ -158,7 +164,7 @@ def raw_emit_op(operator: JitOperator,
p_td("Variadic<AnyTorchType>:$results")
else:
p_td(",\n".join([
f"""{get_ods_type(ret["type"], is_non_value_op)}:${ret["name"] or generic_result_name(e)}"""
f"""{get_ods_type(ret["type"], is_non_value_op, is_result=True)}:${ret["name"] or generic_result_name(e)}"""
for e, ret in enumerate(operator.returns)
]))
p_td(");")