mirror of https://github.com/llvm/torch-mlir
Fix naming of results in ODS generator
This commit fixes the naming of results in the torch ODS generator when dealing with multiple results. In particular, this commit appends an index to each result name to guarantee that they are all unique.pull/487/head
parent
829cf8afc3
commit
707c113463
|
@ -1391,11 +1391,11 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
|
|||
Torch_FloatType:$eps
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$layer_norm,
|
||||
AnyTorchTensorType:$mean,
|
||||
AnyTorchTensorType:$variance
|
||||
AnyTorchTensorType:$result0,
|
||||
AnyTorchTensorType:$result1,
|
||||
AnyTorchTensorType:$result2
|
||||
);
|
||||
let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `->` type($layer_norm) `,` type($mean) `,` type($variance)";
|
||||
let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `->` type($result0) `,` type($result1) `,` type($result2)";
|
||||
}
|
||||
|
||||
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
||||
|
|
|
@ -297,6 +297,11 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
|
|||
emitter = TextEmitter(f)
|
||||
p = lambda *args: emitter.print(*args)
|
||||
op_name, td_def_name = operator.get_mlir_names()
|
||||
|
||||
# Generate unique result names for ops with nameless results
|
||||
multiple_results = len(operator.returns) > 1
|
||||
generic_result_name = lambda i: "result" + (str(i) if multiple_results else "")
|
||||
|
||||
p(f"def {td_def_name} : Torch_Op<{emitter.quote(op_name)}, [")
|
||||
with emitter.indent():
|
||||
with emitter.indent():
|
||||
|
@ -321,8 +326,8 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
|
|||
p("Variadic<AnyTorchType>:$results")
|
||||
else:
|
||||
p(",\n".join([
|
||||
f"""{get_ods_type(ret["type"])}:${ret["name"] or "result"}"""
|
||||
for ret in operator.returns
|
||||
f"""{get_ods_type(ret["type"])}:${ret["name"] or generic_result_name(e)}"""
|
||||
for e, ret in enumerate(operator.returns)
|
||||
]))
|
||||
p(");")
|
||||
|
||||
|
@ -338,8 +343,8 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
|
|||
assembly_result_types = "type($results)"
|
||||
else:
|
||||
assembly_result_types = " `,` ".join(
|
||||
f"""type(${ret["name"] or "result"})"""
|
||||
for ret in operator.returns)
|
||||
f"""type(${ret["name"] or generic_result_name(e)})"""
|
||||
for e, ret in enumerate(operator.returns))
|
||||
if assembly_operand_types and assembly_result_types:
|
||||
maybe_arrow = " `->` "
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue