[NFC] Update black version (#3256)

* Update black version to support 3.11/3.12
* Reformat code
pull/3257/head
penguin_wwy 2024-04-29 11:06:01 +08:00 committed by GitHub
parent aed2cf3351
commit b2185195e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 49 additions and 33 deletions

View File

@ -11,7 +11,7 @@ repos:
- id: check-yaml - id: check-yaml
- id: check-added-large-files - id: check-added-large-files
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 22.10.0 rev: 24.4.2
hooks: hooks:
- id: black - id: black

View File

@ -2,6 +2,7 @@
See https://github.com/llvm/torch-mlir/issues/1374 See https://github.com/llvm/torch-mlir/issues/1374
""" """
import argparse import argparse
import json import json

View File

@ -3,6 +3,7 @@ from torch_mlir import torchscript
from transformers import BertForMaskedLM from transformers import BertForMaskedLM
# Wrap the bert model to avoid multiple returns problem # Wrap the bert model to avoid multiple returns problem
class BertTinyWrapper(torch.nn.Module): class BertTinyWrapper(torch.nn.Module):
def __init__(self) -> None: def __init__(self) -> None:

View File

@ -257,9 +257,9 @@ class _FXGraphImporter:
# FakeTensor's in case of a tuple return with multiple elements. # FakeTensor's in case of a tuple return with multiple elements.
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {} self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
self._module = ir.Module.create(ir.Location.unknown()) self._module = ir.Module.create(ir.Location.unknown())
self._module.operation.attributes[ self._module.operation.attributes["torch.debug_module_name"] = (
"torch.debug_module_name" ir.StringAttr.get(func_name)
] = ir.StringAttr.get(func_name) )
function_type = _extract_function_type_from_graph(g) function_type = _extract_function_type_from_graph(g)
func = func_dialect.FuncOp( func = func_dialect.FuncOp(
func_name, func_name,

View File

@ -285,9 +285,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
(ns, unqual + "_", overload if not is_functional_op else "") (ns, unqual + "_", overload if not is_functional_op else "")
), ),
emitter_td, emitter_td,
traits=["IsTrailingUnderscoreInplaceVariant"] traits=(
if not is_functional_op ["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []
else [], ),
) )
# ========================================================================== # ==========================================================================

View File

@ -46,7 +46,7 @@ def convert_onnx(model, inputs):
examples = [] examples = []
input_names = [] input_names = []
dynamic_tensors = {} dynamic_tensors = {}
for (index, arg) in enumerate(inputs): for index, arg in enumerate(inputs):
shape = map(lambda d: d if d >= 0 else 1, arg.shape) shape = map(lambda d: d if d >= 0 else 1, arg.shape)
shape = tuple(shape) shape = tuple(shape)
examples.append(torch.zeros(size=shape, dtype=arg.dtype)) examples.append(torch.zeros(size=shape, dtype=arg.dtype))
@ -55,7 +55,7 @@ def convert_onnx(model, inputs):
input_names.append(input_name) input_names.append(input_name)
dynamic_dims = {} dynamic_dims = {}
for (dimindex, dim) in enumerate(arg.shape): for dimindex, dim in enumerate(arg.shape):
if dim < 0: if dim < 0:
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex) dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)

View File

@ -101,11 +101,13 @@ class RefBackendInvoker:
def consume_return_funcs(*args): def consume_return_funcs(*args):
self.result = tuple( self.result = tuple(
[ [
(
arg arg
if type in elemental_type_to_ctype if type in elemental_type_to_ctype
else unranked_memref_to_numpy( else unranked_memref_to_numpy(
arg, memref_type_to_np_dtype[type] arg, memref_type_to_np_dtype[type]
) )
)
for arg, type in zip(args, ret_types) for arg, type in zip(args, ret_types)
] ]
) )

View File

@ -803,9 +803,7 @@ class QuantizedReluInt32(torch.nn.Module):
@register_test_case(module_factory=lambda: QuantizedReluInt32()) @register_test_case(module_factory=lambda: QuantizedReluInt32())
def QuantizedReluInt32_basic(module, tu: TestUtils): def QuantizedReluInt32_basic(module, tu: TestUtils):
module.forward( module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32))
tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)
)
# ============================================================================== # ==============================================================================

View File

@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1). # For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index). # For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
class SliceScatterModule(torch.nn.Module): class SliceScatterModule(torch.nn.Module):

View File

@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK: module attributes {torch.debug_module_name = "TestModule"} # CHECK: module attributes {torch.debug_module_name = "TestModule"}
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -18,6 +18,7 @@ mb = ModuleBuilder()
# `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so # `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
# naively duplicating a Tensor retains the identity of the TensorImpl. # naively duplicating a Tensor retains the identity of the TensorImpl.
# CHECK-LABEL: torch.class_type @__torch__.TestModule { # CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -12,6 +12,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: torch.class_type @__torch__.TestModule { # CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.add3 # CHECK-LABEL: func.func @__torch__.add3
# Note that line-level debug information for parts unannotated in the Torch # Note that line-level debug information for parts unannotated in the Torch
# graph are ascribed to the first op that carries source information. Presently # graph are ascribed to the first op that carries source information. Presently

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: @__torch__.f # CHECK-LABEL: @__torch__.f
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script

View File

@ -11,6 +11,7 @@ import typing
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.optional_return( # CHECK-LABEL: func.func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> { # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int> # CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>

View File

@ -13,6 +13,7 @@ mb = ModuleBuilder()
# else branch and making all defined values optional, so no special handling # else branch and making all defined values optional, so no special handling
# is needed. # is needed.
# CHECK-LABEL: @__torch__.prim_If( # CHECK-LABEL: @__torch__.prim_If(
# CHECK-SAME: %[[B:.*]]: !torch.bool, # CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int { # CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int {

View File

@ -11,6 +11,7 @@ import typing
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.prim_Loop_forlike( # CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float { # CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true # CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true

View File

@ -15,6 +15,7 @@ import typing
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.prim_NumToTensor( # CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor { # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor # CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor

View File

@ -13,6 +13,7 @@ from utils import create_script_function
mb = ModuleBuilder() mb = ModuleBuilder()
NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])]) NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])
# CHECK-LABEL: func.func @__torch__.tuple( # CHECK-LABEL: func.func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor, # CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> # CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK: @__torch__.returns_bool # CHECK: @__torch__.returns_bool
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK: @__torch__.returns_none # CHECK: @__torch__.returns_none
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script

View File

@ -9,6 +9,7 @@ from torch._C import CompilationUnit
# RUN: %PYTHON %s # RUN: %PYTHON %s
# Import TorchScript IR string as ScriptFunction. # Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str, **kwargs): def create_script_function(func_name, ts_ir_str, **kwargs):
cu = CompilationUnit() cu = CompilationUnit()

View File

@ -1849,8 +1849,7 @@ def _emit_operation(
# Opaque value to indicate something is empty. Used in cases where 'None' # Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning. # may have a different meaning.
class EmptyType: class EmptyType: ...
...
Empty = EmptyType() Empty = EmptyType()

View File

@ -156,8 +156,7 @@ class GraphInfo:
return "" return ""
class OnnxImportError(Exception): class OnnxImportError(Exception): ...
...
class NodeImporter: class NodeImporter:
@ -235,22 +234,22 @@ class NodeImporter:
else: else:
default_opset_version = opset_import.version default_opset_version = opset_import.version
if default_opset_version: if default_opset_version:
container_op.attributes[ container_op.attributes["torch.onnx_meta.opset_version"] = (
"torch.onnx_meta.opset_version" IntegerAttr.get(i64_type, default_opset_version)
] = IntegerAttr.get(i64_type, default_opset_version) )
if opset_versions: if opset_versions:
container_op.attributes[ container_op.attributes["torch.onnx_meta.opset_versions"] = (
"torch.onnx_meta.opset_versions" DictAttr.get(opset_versions)
] = DictAttr.get(opset_versions) )
container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get( container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
IntegerType.get_signed(64), m.ir_version IntegerType.get_signed(64), m.ir_version
) )
container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get( container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
m.producer_name m.producer_name
) )
container_op.attributes[ container_op.attributes["torch.onnx_meta.producer_version"] = (
"torch.onnx_meta.producer_version" StringAttr.get(m.producer_version)
] = StringAttr.get(m.producer_version) )
def import_all(self, func=True): def import_all(self, func=True):
"""Imports all nodes topologically.""" """Imports all nodes topologically."""
@ -658,9 +657,11 @@ ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = {
RankedTensorType.get(shape, IntegerType.get_signed(64)), RankedTensorType.get(shape, IntegerType.get_signed(64)),
IntegerAttr.get( IntegerAttr.get(
IntegerType.get_signed(64), IntegerType.get_signed(64),
(
int.from_bytes(tp.raw_data, "little", signed=True) int.from_bytes(tp.raw_data, "little", signed=True)
if tp.HasField("raw_data") if tp.HasField("raw_data")
else tp.int64_data[0], else tp.int64_data[0]
),
), ),
), ),
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
@ -703,7 +704,7 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
), ),
onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get( onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
) ),
# Intentionally unsupported: STRING # Intentionally unsupported: STRING
} }