mirror of https://github.com/llvm/torch-mlir
[ods] Allow all tensor returns to be optional. (#3082)
This was found while tracing backwards graphs: the convolution_backwards op will return None if the first result is not needed. Confirmed by defining a custom op with a `Tensor` return signature and having its meta kernel return None.pull/3070/head
parent
40008b025a
commit
6d680ff445
File diff suppressed because it is too large
Load Diff
|
@ -60,7 +60,13 @@ TORCH_NON_VALUE_TYPE_TO_ODS_TYPE = {
|
|||
}
|
||||
|
||||
|
||||
def get_ods_type(type: str, non_value: bool):
|
||||
def get_ods_type(type: str, non_value: bool, *, is_result: bool = False):
|
||||
# In torch signatures, it is legal to return None for a Tensor, so we
|
||||
# silently upgrade Tensor results to Tensor?. I'm not sure this is written
|
||||
# anywhere, but Torch does use it sometimes internally. As an example,
|
||||
# the first return of a convolution_backwards may be returned as None.
|
||||
if is_result and type == "Tensor":
|
||||
type = "Tensor?"
|
||||
# TODO: Increase precision on dict type modeling.
|
||||
if type.startswith("Dict("):
|
||||
type = "Dict"
|
||||
|
@ -158,7 +164,7 @@ def raw_emit_op(operator: JitOperator,
|
|||
p_td("Variadic<AnyTorchType>:$results")
|
||||
else:
|
||||
p_td(",\n".join([
|
||||
f"""{get_ods_type(ret["type"], is_non_value_op)}:${ret["name"] or generic_result_name(e)}"""
|
||||
f"""{get_ods_type(ret["type"], is_non_value_op, is_result=True)}:${ret["name"] or generic_result_name(e)}"""
|
||||
for e, ret in enumerate(operator.returns)
|
||||
]))
|
||||
p_td(");")
|
||||
|
|
Loading…
Reference in New Issue