[NFC] Update black version (#3256)

* Update black version to support 3.11/3.12
* Reformat code
pull/3257/head
penguin_wwy 2024-04-29 11:06:01 +08:00 committed by GitHub
parent aed2cf3351
commit b2185195e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 49 additions and 33 deletions

View File

@ -11,7 +11,7 @@ repos:
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 22.10.0
rev: 24.4.2
hooks:
- id: black

View File

@ -2,6 +2,7 @@
See https://github.com/llvm/torch-mlir/issues/1374
"""
import argparse
import json

View File

@ -3,6 +3,7 @@ from torch_mlir import torchscript
from transformers import BertForMaskedLM
# Wrap the bert model to avoid multiple returns problem
class BertTinyWrapper(torch.nn.Module):
def __init__(self) -> None:

View File

@ -257,9 +257,9 @@ class _FXGraphImporter:
# FakeTensor's in case of a tuple return with multiple elements.
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
self._module = ir.Module.create(ir.Location.unknown())
self._module.operation.attributes[
"torch.debug_module_name"
] = ir.StringAttr.get(func_name)
self._module.operation.attributes["torch.debug_module_name"] = (
ir.StringAttr.get(func_name)
)
function_type = _extract_function_type_from_graph(g)
func = func_dialect.FuncOp(
func_name,

View File

@ -285,9 +285,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
(ns, unqual + "_", overload if not is_functional_op else "")
),
emitter_td,
traits=["IsTrailingUnderscoreInplaceVariant"]
if not is_functional_op
else [],
traits=(
["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []
),
)
# ==========================================================================

View File

@ -46,7 +46,7 @@ def convert_onnx(model, inputs):
examples = []
input_names = []
dynamic_tensors = {}
for (index, arg) in enumerate(inputs):
for index, arg in enumerate(inputs):
shape = map(lambda d: d if d >= 0 else 1, arg.shape)
shape = tuple(shape)
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
@ -55,7 +55,7 @@ def convert_onnx(model, inputs):
input_names.append(input_name)
dynamic_dims = {}
for (dimindex, dim) in enumerate(arg.shape):
for dimindex, dim in enumerate(arg.shape):
if dim < 0:
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)

View File

@ -101,11 +101,13 @@ class RefBackendInvoker:
def consume_return_funcs(*args):
self.result = tuple(
[
(
arg
if type in elemental_type_to_ctype
else unranked_memref_to_numpy(
arg, memref_type_to_np_dtype[type]
)
)
for arg, type in zip(args, ret_types)
]
)

View File

@ -803,9 +803,7 @@ class QuantizedReluInt32(torch.nn.Module):
@register_test_case(module_factory=lambda: QuantizedReluInt32())
def QuantizedReluInt32_basic(module, tu: TestUtils):
module.forward(
tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)
)
module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32))
# ==============================================================================

View File

@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
class SliceScatterModule(torch.nn.Module):

View File

@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK: module attributes {torch.debug_module_name = "TestModule"}
class TestModule(torch.nn.Module):
def __init__(self):

View File

@ -18,6 +18,7 @@ mb = ModuleBuilder()
# `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
# naively duplicating a Tensor retains the identity of the TensorImpl.
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module):
def __init__(self):

View File

@ -12,6 +12,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module):
def __init__(self):

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.add3
# Note that line-level debug information for parts unannotated in the Torch
# graph are ascribed to the first op that carries source information. Presently

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK-LABEL: @__torch__.f
@mb.import_function
@torch.jit.script

View File

@ -11,6 +11,7 @@ import typing
mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>

View File

@ -13,6 +13,7 @@ mb = ModuleBuilder()
# else branch and making all defined values optional, so no special handling
# is needed.
# CHECK-LABEL: @__torch__.prim_If(
# CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int {

View File

@ -11,6 +11,7 @@ import typing
mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true

View File

@ -15,6 +15,7 @@ import typing
mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor

View File

@ -13,6 +13,7 @@ from utils import create_script_function
mb = ModuleBuilder()
NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])
# CHECK-LABEL: func.func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK: @__torch__.returns_bool
@mb.import_function
@torch.jit.script

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK: @__torch__.returns_none
@mb.import_function
@torch.jit.script

View File

@ -9,6 +9,7 @@ from torch._C import CompilationUnit
# RUN: %PYTHON %s
# Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str, **kwargs):
cu = CompilationUnit()

View File

@ -1849,8 +1849,7 @@ def _emit_operation(
# Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning.
class EmptyType:
...
class EmptyType: ...
Empty = EmptyType()

View File

@ -156,8 +156,7 @@ class GraphInfo:
return ""
class OnnxImportError(Exception):
...
class OnnxImportError(Exception): ...
class NodeImporter:
@ -235,22 +234,22 @@ class NodeImporter:
else:
default_opset_version = opset_import.version
if default_opset_version:
container_op.attributes[
"torch.onnx_meta.opset_version"
] = IntegerAttr.get(i64_type, default_opset_version)
container_op.attributes["torch.onnx_meta.opset_version"] = (
IntegerAttr.get(i64_type, default_opset_version)
)
if opset_versions:
container_op.attributes[
"torch.onnx_meta.opset_versions"
] = DictAttr.get(opset_versions)
container_op.attributes["torch.onnx_meta.opset_versions"] = (
DictAttr.get(opset_versions)
)
container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
IntegerType.get_signed(64), m.ir_version
)
container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
m.producer_name
)
container_op.attributes[
"torch.onnx_meta.producer_version"
] = StringAttr.get(m.producer_version)
container_op.attributes["torch.onnx_meta.producer_version"] = (
StringAttr.get(m.producer_version)
)
def import_all(self, func=True):
"""Imports all nodes topologically."""
@ -658,9 +657,11 @@ ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = {
RankedTensorType.get(shape, IntegerType.get_signed(64)),
IntegerAttr.get(
IntegerType.get_signed(64),
(
int.from_bytes(tp.raw_data, "little", signed=True)
if tp.HasField("raw_data")
else tp.int64_data[0],
else tp.int64_data[0]
),
),
),
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
@ -703,7 +704,7 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
),
onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
)
),
# Intentionally unsupported: STRING
}