Fix naming of results in ODS generator

This commit fixes the naming of results in the torch ODS generator
when dealing with multiple results. In particular, this commit appends
an index to each result name to guarantee that they are all unique.
pull/487/head
Ramiro Leal-Cavazos 2021-12-14 21:32:10 +00:00
parent 829cf8afc3
commit 707c113463
2 changed files with 13 additions and 8 deletions

View File

@ -1391,11 +1391,11 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
Torch_FloatType:$eps
);
let results = (outs
AnyTorchTensorType:$layer_norm,
AnyTorchTensorType:$mean,
AnyTorchTensorType:$variance
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2
);
let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `->` type($layer_norm) `,` type($mean) `,` type($variance)";
let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `->` type($result0) `,` type($result1) `,` type($result2)";
}
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [

View File

@ -297,6 +297,11 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
emitter = TextEmitter(f)
p = lambda *args: emitter.print(*args)
op_name, td_def_name = operator.get_mlir_names()
# Generate unique result names for ops with nameless results
multiple_results = len(operator.returns) > 1
generic_result_name = lambda i: "result" + (str(i) if multiple_results else "")
p(f"def {td_def_name} : Torch_Op<{emitter.quote(op_name)}, [")
with emitter.indent():
with emitter.indent():
@ -321,8 +326,8 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
p("Variadic<AnyTorchType>:$results")
else:
p(",\n".join([
f"""{get_ods_type(ret["type"])}:${ret["name"] or "result"}"""
for ret in operator.returns
f"""{get_ods_type(ret["type"])}:${ret["name"] or generic_result_name(e)}"""
for e, ret in enumerate(operator.returns)
]))
p(");")
@ -338,8 +343,8 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
assembly_result_types = "type($results)"
else:
assembly_result_types = " `,` ".join(
f"""type(${ret["name"] or "result"})"""
for ret in operator.returns)
f"""type(${ret["name"] or generic_result_name(e)})"""
for e, ret in enumerate(operator.returns))
if assembly_operand_types and assembly_result_types:
maybe_arrow = " `->` "
else: