mirror of https://github.com/llvm/torch-mlir
Add handling of namespaces to library generator (#2391)
When using custom ops, sometimes PyTorch will insert namespaces to the abstract interpretation function name in the format: `__torch__.{namespace_1}.{namespace_2}...{op_name}`. The extra namespaces are not part of the abstract interpretation function name, so it needs to be removed before generating the library of MLIR snippets of abstract interpretation functions. This commit adds support for removing the namespace information.pull/2393/head snapshot-20230812.928
parent
23d7821afa
commit
ff762100b8
|
@ -6,6 +6,7 @@
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional, Union, Any, Dict
|
from typing import List, Optional, Union, Any, Dict
|
||||||
|
import codecs
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -234,10 +235,22 @@ def generate_library(functions: Dict[str, Any]) -> str:
|
||||||
# defined symbols. Since all of our shape functions conveniently have
|
# defined symbols. Since all of our shape functions conveniently have
|
||||||
# a `〇` in them, we replace the torch namespace with our prefix. E.g.:
|
# a `〇` in them, we replace the torch namespace with our prefix. E.g.:
|
||||||
# __torch__.aten〇add〇Scalar -> __torch_mlir_shape_fn.aten〇add〇Scalar
|
# __torch__.aten〇add〇Scalar -> __torch_mlir_shape_fn.aten〇add〇Scalar
|
||||||
asm = re.sub(r"__torch__\.([^.(]+)\\E3\\80\\87([^.(]+)\\E3\\80\\A1([^.(\"]+)",
|
|
||||||
r"__torch_mlir_\3_fn.\1\\E3\\80\\87\2",
|
# Encoding for: 〇
|
||||||
|
circle = r"\\E3\\80\\87"
|
||||||
|
# Encoding for: 〡
|
||||||
|
line = r"\\E3\\80\\A1"
|
||||||
|
name = r"[^.(]+"
|
||||||
|
# Sometimes PyTorch will insert namespaces to the function name in
|
||||||
|
# the format: `__torch__.{namespace_1}.{namespace_2}...{op_name}`
|
||||||
|
# The extra namespaces are not part of the abstract interpretation
|
||||||
|
# function name, so here we simply drop the extra namespaces.
|
||||||
|
namespace = fr"(?:{name}\.)"
|
||||||
|
|
||||||
|
asm = re.sub(fr'@"__torch__\.{namespace}*({name}){circle}({name}){line}({name})"',
|
||||||
|
fr'@"__torch_mlir_\3_fn.\1{circle}\2"',
|
||||||
asm)
|
asm)
|
||||||
|
|
||||||
# Put the `〇` back to a regular `.`.
|
# Put the `〇` back to a regular `.`.
|
||||||
asm = asm.replace("\\E3\\80\\87", ".")
|
asm = asm.replace(codecs.decode(circle, "unicode_escape"), ".")
|
||||||
return asm
|
return asm
|
||||||
|
|
|
@ -3,6 +3,7 @@ import tempfile
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
import torch.utils.cpp_extension
|
import torch.utils.cpp_extension
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
from torch_mlir_e2e_test.annotations import export, annotate_args
|
from torch_mlir_e2e_test.annotations import export, annotate_args
|
||||||
|
@ -51,6 +52,10 @@ class CustomOpExampleModule(torch.nn.Module):
|
||||||
mod = CustomOpExampleModule()
|
mod = CustomOpExampleModule()
|
||||||
mod.eval()
|
mod.eval()
|
||||||
|
|
||||||
|
def run():
|
||||||
|
mod = CustomOpExampleModule()
|
||||||
|
mod.eval()
|
||||||
|
|
||||||
module = torch_mlir.compile(
|
module = torch_mlir.compile(
|
||||||
mod,
|
mod,
|
||||||
torch.ones(3, 4),
|
torch.ones(3, 4),
|
||||||
|
@ -61,6 +66,27 @@ module = torch_mlir.compile(
|
||||||
|
|
||||||
print(module)
|
print(module)
|
||||||
|
|
||||||
|
run()
|
||||||
|
|
||||||
|
# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {
|
||||||
|
# CHECK: func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||||
|
# CHECK: %{{.*}} = torch.constant.int 2
|
||||||
|
# CHECK: %{{.*}} = torch.aten.mul.Scalar %{{.*}}, %{{.*}} : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
|
||||||
|
# CHECK: %{{.*}} = torch.operator "goofy.identity"(%{{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
|
||||||
|
# CHECK: return %1 : !torch.vtensor<[3,4],f32>
|
||||||
|
# CHECK: }
|
||||||
|
# CHECK: }
|
||||||
|
|
||||||
|
# Using `torch.multiprocessing` adds extra namespaces to the abstract
|
||||||
|
# interpretation functions when they are imported into MLIR:
|
||||||
|
# `func @"__torch__.__mp_main__.{name}...`
|
||||||
|
# This tests that the extra namespaces are removed correctly.
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mp.set_start_method("spawn")
|
||||||
|
p = mp.Process(target=run, args=())
|
||||||
|
p.start()
|
||||||
|
p.join()
|
||||||
|
|
||||||
# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {
|
# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {
|
||||||
# CHECK: func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
|
# CHECK: func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||||
# CHECK: %{{.*}} = torch.constant.int 2
|
# CHECK: %{{.*}} = torch.constant.int 2
|
||||||
|
|
Loading…
Reference in New Issue