Add handling of namespaces to library generator (#2391)

When using custom ops, PyTorch will sometimes insert namespaces into the
abstract interpretation function name, in the format:
`__torch__.{namespace_1}.{namespace_2}...{op_name}`. The extra
namespaces are not part of the abstract interpretation function name,
so they need to be removed before generating the library of MLIR
snippets of abstract interpretation functions. This commit adds
support for removing the namespace information.
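One common way these extra namespaces arise (and the case exercised by the new test below) is running the import under `torch.multiprocessing` with the "spawn" start method: the child process re-imports the main module under the name `__mp_main__`, and that name then shows up in qualified names. A minimal sketch of that behavior, independent of torch-mlir and using a made-up function name:

    import torch.multiprocessing as mp

    def show_names():
        # In a child created with the "spawn" start method, the main module
        # is re-imported as "__mp_main__", so functions defined in it pick
        # up that extra namespace in their qualified names.
        print(__name__)               # "__mp_main__" in the child process
        print(show_names.__module__)  # also "__mp_main__" in the child

    if __name__ == "__main__":
        mp.set_start_method("spawn")
        p = mp.Process(target=show_names)
        p.start()
        p.join()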
Ramiro Leal-Cavazos 2023-08-11 09:56:19 -07:00 committed by GitHub
parent 23d7821afa
commit ff762100b8
2 changed files with 50 additions and 11 deletions


@@ -6,6 +6,7 @@
 import inspect
 import re
 from typing import List, Optional, Union, Any, Dict
+import codecs

 import torch
@@ -234,10 +235,22 @@ def generate_library(functions: Dict[str, Any]) -> str:
     # defined symbols. Since all of our shape functions conveniently have
     # a `〇` in them, we replace the torch namespace with our prefix. E.g.:
     # __torch__.aten〇add〇Scalar -> __torch_mlir_shape_fn.aten〇add〇Scalar
-    asm = re.sub(r"__torch__\.([^.(]+)\\E3\\80\\87([^.(]+)\\E3\\80\\A1([^.(\"]+)",
-                 r"__torch_mlir_\3_fn.\1\\E3\\80\\87\2",
+    # Encoding for: 〇
+    circle = r"\\E3\\80\\87"
+    # Encoding for: 〡
+    line = r"\\E3\\80\\A1"
+    name = r"[^.(]+"
+
+    # Sometimes PyTorch will insert namespaces to the function name in
+    # the format: `__torch__.{namespace_1}.{namespace_2}...{op_name}`
+    # The extra namespaces are not part of the abstract interpretation
+    # function name, so here we simply drop the extra namespaces.
+    namespace = fr"(?:{name}\.)"
+
+    asm = re.sub(fr'@"__torch__\.{namespace}*({name}){circle}({name}){line}({name})"',
+                 fr'@"__torch_mlir_\3_fn.\1{circle}\2"',
                  asm)

     # Put the `〇` back to a regular `.`.
-    asm = asm.replace("\\E3\\80\\87", ".")
+    asm = asm.replace(codecs.decode(circle, "unicode_escape"), ".")
     return asm
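To illustrate what the new substitution does, here is a standalone sketch; the assembly line and the `aten〇relu〡shape` symbol are made up for this example, and the real input is the full module assembly produced by the importer:

    import re

    # Hypothetical line of MLIR assembly. The importer escapes the non-ASCII
    # 〇 / 〡 characters in symbol names as \E3\80\87 / \E3\80\A1, and the
    # extra "__mp_main__" namespace appears before the op name.
    asm = r'func.func @"__torch__.__mp_main__.aten\E3\80\87relu\E3\80\A1shape"() {'

    circle = r"\\E3\\80\\87"   # regex matching the escaped 〇
    line = r"\\E3\\80\\A1"     # regex matching the escaped 〡
    name = r"[^.(]+"
    namespace = fr"(?:{name}\.)"

    asm = re.sub(fr'@"__torch__\.{namespace}*({name}){circle}({name}){line}({name})"',
                 fr'@"__torch_mlir_\3_fn.\1{circle}\2"',
                 asm)
    # The "__mp_main__." namespace is dropped and the suffix after 〡 selects
    # the function prefix:
    print(asm)  # func.func @"__torch_mlir_shape_fn.aten\E3\80\87relu"() {

The `asm.replace(codecs.decode(circle, "unicode_escape"), ".")` step in the diff then turns the remaining escaped `\E3\80\87` back into a regular `.`, which here would give `@"__torch_mlir_shape_fn.aten.relu"`.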


@@ -3,6 +3,7 @@ import tempfile
 from typing import List, Tuple

 import torch
+import torch.multiprocessing as mp
 import torch.utils.cpp_extension
 import torch_mlir
 from torch_mlir_e2e_test.annotations import export, annotate_args
@@ -51,15 +52,40 @@
 mod = CustomOpExampleModule()
 mod.eval()

-module = torch_mlir.compile(
-    mod,
-    torch.ones(3, 4),
-    output_type="torch",
-    backend_legal_ops=["goofy.identity"],
-    extra_library=extra_library,
-)
-
-print(module)
+def run():
+    mod = CustomOpExampleModule()
+    mod.eval()
+
+    module = torch_mlir.compile(
+        mod,
+        torch.ones(3, 4),
+        output_type="torch",
+        backend_legal_ops=["goofy.identity"],
+        extra_library=extra_library,
+    )
+
+    print(module)
+
+run()
+
+# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {
+# CHECK: func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+# CHECK: %{{.*}} = torch.constant.int 2
+# CHECK: %{{.*}} = torch.aten.mul.Scalar %{{.*}}, %{{.*}} : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
+# CHECK: %{{.*}} = torch.operator "goofy.identity"(%{{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
+# CHECK: return %1 : !torch.vtensor<[3,4],f32>
+# CHECK: }
+# CHECK: }
+
+# Using `torch.multiprocessing` adds extra namespaces to the abstract
+# interpretation functions when they are imported into MLIR:
+# `func @"__torch__.__mp_main__.{name}...`
+# This tests that the extra namespaces are removed correctly.
+if __name__ == "__main__":
+    mp.set_start_method("spawn")
+    p = mp.Process(target=run, args=())
+    p.start()
+    p.join()

 # CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {
 # CHECK: func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {