mirror of https://github.com/llvm/torch-mlir
[NFC] Update black version (#3256)
* Update black version to support 3.11/3.12
* Reformat code
parent aed2cf3351
commit b2185195e8
@@ -11,7 +11,7 @@ repos:
       - id: check-yaml
       - id: check-added-large-files
   - repo: https://github.com/psf/black
-    rev: 22.10.0
+    rev: 24.4.2
     hooks:
      - id: black

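The only real change here is the pinned black revision; the rest of the commit is the result of re-running the hook on the tree. A minimal sketch (not part of the commit, and assuming pre-commit is installed) of how that reformat can be reproduced, using the hook id "black" from the config above:

import subprocess

# Re-run only the "black" hook from .pre-commit-config.yaml on every file.
subprocess.run(["pre-commit", "run", "black", "--all-files"], check=True)
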
@@ -2,6 +2,7 @@

 See https://github.com/llvm/torch-mlir/issues/1374
 """
+
 import argparse
 import json

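This one-line hunk is the first of several blank-line-only changes: black 24.x (as pinned here) enforces a single blank line between the closing quotes of a module docstring and the first statement. A tiny illustrative module (hypothetical contents):

"""A hypothetical module docstring."""

import argparse  # exactly one blank line now separates this from the docstring

parser = argparse.ArgumentParser()
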
@@ -3,6 +3,7 @@ from torch_mlir import torchscript

 from transformers import BertForMaskedLM

+
 # Wrap the bert model to avoid multiple returns problem
 class BertTinyWrapper(torch.nn.Module):
     def __init__(self) -> None:

@@ -257,9 +257,9 @@ class _FXGraphImporter:
         # FakeTensor's in case of a tuple return with multiple elements.
         self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
         self._module = ir.Module.create(ir.Location.unknown())
-        self._module.operation.attributes[
-            "torch.debug_module_name"
-        ] = ir.StringAttr.get(func_name)
+        self._module.operation.attributes["torch.debug_module_name"] = (
+            ir.StringAttr.get(func_name)
+        )
         function_type = _extract_function_type_from_graph(g)
         func = func_dialect.FuncOp(
             func_name,

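The pattern in this hunk recurs in the ONNX importer hunks further down: when an assignment is too long for one line, black 24.x prefers splitting on the right-hand side, wrapping the value in parentheses instead of breaking inside the subscript. A sketch with hypothetical names (the long key is only there to force the split):

module_attributes = {}


def make_string_attr(value):  # stand-in for ir.StringAttr.get
    return str(value)


# black 22.x broke long subscript assignments inside the brackets:
#     module_attributes[
#         "torch.debug_module_name"
#     ] = make_string_attr(func_name)
# black 24.x keeps the subscript on one line and parenthesizes the value:
module_attributes["torch.debug_module_name.hypothetical.long.key.suffix"] = (
    make_string_attr("MyHypotheticalModule")
)
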
@@ -285,9 +285,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
                 (ns, unqual + "_", overload if not is_functional_op else "")
             ),
             emitter_td,
-            traits=["IsTrailingUnderscoreInplaceVariant"]
-            if not is_functional_op
-            else [],
+            traits=(
+                ["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []
+            ),
         )

     # ==========================================================================

@@ -46,7 +46,7 @@ def convert_onnx(model, inputs):
     examples = []
     input_names = []
     dynamic_tensors = {}
-    for (index, arg) in enumerate(inputs):
+    for index, arg in enumerate(inputs):
         shape = map(lambda d: d if d >= 0 else 1, arg.shape)
         shape = tuple(shape)
         examples.append(torch.zeros(size=shape, dtype=arg.dtype))

@@ -55,7 +55,7 @@ def convert_onnx(model, inputs):
         input_names.append(input_name)

         dynamic_dims = {}
-        for (dimindex, dim) in enumerate(arg.shape):
+        for dimindex, dim in enumerate(arg.shape):
             if dim < 0:
                 dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)

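Both hunks in convert_onnx apply the same mechanical rule: black 24.x, as pinned in this commit, drops the redundant parentheses around tuple targets in for statements. A small runnable sketch:

shapes = [(0, 3), (1, -1)]

# Previously written as:  for (index, dim) in shapes:
for index, dim in shapes:  # the form black 24.x settles on
    print(index, dim if dim >= 0 else 1)
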
@@ -101,11 +101,13 @@ class RefBackendInvoker:
             def consume_return_funcs(*args):
                 self.result = tuple(
                     [
-                        arg
-                        if type in elemental_type_to_ctype
-                        else unranked_memref_to_numpy(
-                            arg, memref_type_to_np_dtype[type]
-                        )
+                        (
+                            arg
+                            if type in elemental_type_to_ctype
+                            else unranked_memref_to_numpy(
+                                arg, memref_type_to_np_dtype[type]
+                            )
+                        )
                         for arg, type in zip(args, ret_types)
                     ]
                 )

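Here the reformat wraps the multi-line conditional expression inside the list comprehension in its own parentheses. A compact, self-contained sketch of the same shape, with hypothetical names; black only adds the parentheses when the conditional actually has to span lines:

def to_native(results, scalar_kinds, convert):
    # Pass scalar kinds through unchanged; convert everything else. The
    # conditional below is what black 24.x parenthesizes once it has to
    # span multiple lines, as in the hunk above.
    return tuple(
        [
            (value if kind in scalar_kinds else convert(value, kind))
            for value, kind in results
        ]
    )


print(to_native([(1, "i32"), (2.5, "memref")], {"i32"}, lambda v, k: [v]))
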
@@ -803,9 +803,7 @@ class QuantizedReluInt32(torch.nn.Module):

 @register_test_case(module_factory=lambda: QuantizedReluInt32())
 def QuantizedReluInt32_basic(module, tu: TestUtils):
-    module.forward(
-        tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)
-    )
+    module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32))


 # ==============================================================================

@@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):

 # ==============================================================================

+
 # For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
 # For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
 class SliceScatterModule(torch.nn.Module):

@@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

 mb = ModuleBuilder()

+
 # CHECK: module attributes {torch.debug_module_name = "TestModule"}
 class TestModule(torch.nn.Module):
     def __init__(self):

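This hunk and the dozen near-identical ones that follow in the jit_ir_importer tests each gain a single blank line: black 24.x places the two blank lines required before a top-level definition above its leading comment, rather than leaving only one blank line between the preceding statement and the comment. A sketch of the resulting layout with stand-in names:

mb = object()  # stand-in for ModuleBuilder()


# CHECK-style comment that belongs to the definition below
class TestModule:
    def __init__(self):
        self.value = None
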
@@ -18,6 +18,7 @@ mb = ModuleBuilder()
 # `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
 # naively duplicating a Tensor retains the identity of the TensorImpl.

+
 # CHECK-LABEL: torch.class_type @__torch__.TestModule {
 class TestModule(torch.nn.Module):
     def __init__(self):

@@ -12,6 +12,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

 mb = ModuleBuilder()

+
 # CHECK-LABEL: torch.class_type @__torch__.TestModule {
 class TestModule(torch.nn.Module):
     def __init__(self):

@@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

 mb = ModuleBuilder()

+
 # CHECK-LABEL: func.func @__torch__.add3
 # Note that line-level debug information for parts unannotated in the Torch
 # graph are ascribed to the first op that carries source information. Presently

@@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

 mb = ModuleBuilder()

+
 # CHECK-LABEL: @__torch__.f
 @mb.import_function
 @torch.jit.script

@@ -11,6 +11,7 @@ import typing

 mb = ModuleBuilder()

+
 # CHECK-LABEL: func.func @__torch__.optional_return(
 # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
 # CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>

@@ -13,6 +13,7 @@ mb = ModuleBuilder()
 # else branch and making all defined values optional, so no special handling
 # is needed.

+
 # CHECK-LABEL: @__torch__.prim_If(
 # CHECK-SAME: %[[B:.*]]: !torch.bool,
 # CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int {

@@ -11,6 +11,7 @@ import typing

 mb = ModuleBuilder()

+
 # CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
 # CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
 # CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true

@@ -15,6 +15,7 @@ import typing

 mb = ModuleBuilder()

+
 # CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
 # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
 # CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor

@@ -13,6 +13,7 @@ from utils import create_script_function
 mb = ModuleBuilder()
 NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])

+
 # CHECK-LABEL: func.func @__torch__.tuple(
 # CHECK-SAME: %[[T0:.*]]: !torch.tensor,
 # CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->

@@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

 mb = ModuleBuilder()

+
 # CHECK: @__torch__.returns_bool
 @mb.import_function
 @torch.jit.script

@@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

 mb = ModuleBuilder()

+
 # CHECK: @__torch__.returns_none
 @mb.import_function
 @torch.jit.script

@@ -9,6 +9,7 @@ from torch._C import CompilationUnit

 # RUN: %PYTHON %s

+
 # Import TorchScript IR string as ScriptFunction.
 def create_script_function(func_name, ts_ir_str, **kwargs):
     cu = CompilationUnit()

@@ -1849,8 +1849,7 @@ def _emit_operation(

 # Opaque value to indicate something is empty. Used in cases where 'None'
 # may have a different meaning.
-class EmptyType:
-    ...
+class EmptyType: ...


 Empty = EmptyType()

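The EmptyType change here, and the OnnxImportError change just below, show another black 24.x rule: a class or function whose body is only `...` is collapsed onto a single line. A tiny sketch:

class PlaceholderError(Exception): ...


def not_yet_implemented(): ...


# The collapsed one-line form behaves exactly like the old two-line spelling.
print(isinstance(PlaceholderError("boom"), Exception))  # True
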
@@ -156,8 +156,7 @@ class GraphInfo:
         return ""


-class OnnxImportError(Exception):
-    ...
+class OnnxImportError(Exception): ...


 class NodeImporter:

@@ -235,22 +234,22 @@ class NodeImporter:
             else:
                 default_opset_version = opset_import.version
         if default_opset_version:
-            container_op.attributes[
-                "torch.onnx_meta.opset_version"
-            ] = IntegerAttr.get(i64_type, default_opset_version)
+            container_op.attributes["torch.onnx_meta.opset_version"] = (
+                IntegerAttr.get(i64_type, default_opset_version)
+            )
         if opset_versions:
-            container_op.attributes[
-                "torch.onnx_meta.opset_versions"
-            ] = DictAttr.get(opset_versions)
+            container_op.attributes["torch.onnx_meta.opset_versions"] = (
+                DictAttr.get(opset_versions)
+            )
         container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
             IntegerType.get_signed(64), m.ir_version
         )
         container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
             m.producer_name
         )
-        container_op.attributes[
-            "torch.onnx_meta.producer_version"
-        ] = StringAttr.get(m.producer_version)
+        container_op.attributes["torch.onnx_meta.producer_version"] = (
+            StringAttr.get(m.producer_version)
+        )

     def import_all(self, func=True):
         """Imports all nodes topologically."""

@@ -658,9 +657,11 @@ ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = {
         RankedTensorType.get(shape, IntegerType.get_signed(64)),
         IntegerAttr.get(
             IntegerType.get_signed(64),
-            int.from_bytes(tp.raw_data, "little", signed=True)
-            if tp.HasField("raw_data")
-            else tp.int64_data[0],
+            (
+                int.from_bytes(tp.raw_data, "little", signed=True)
+                if tp.HasField("raw_data")
+                else tp.int64_data[0]
+            ),
         ),
     ),
     # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB

@@ -703,7 +704,7 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
     ),
     onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
         np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
-    )
+    ),
     # Intentionally unsupported: STRING
 }
