mirror of https://github.com/llvm/torch-mlir
[NFC] Update black version (#3256)
* Update black version to support 3.11/3.12 * Reformat codepull/3257/head
parent
aed2cf3351
commit
b2185195e8
|
@ -11,7 +11,7 @@ repos:
|
|||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.10.0
|
||||
rev: 24.4.2
|
||||
hooks:
|
||||
- id: black
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
See https://github.com/llvm/torch-mlir/issues/1374
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ from torch_mlir import torchscript
|
|||
|
||||
from transformers import BertForMaskedLM
|
||||
|
||||
|
||||
# Wrap the bert model to avoid multiple returns problem
|
||||
class BertTinyWrapper(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
@ -257,9 +257,9 @@ class _FXGraphImporter:
|
|||
# FakeTensor's in case of a tuple return with multiple elements.
|
||||
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
|
||||
self._module = ir.Module.create(ir.Location.unknown())
|
||||
self._module.operation.attributes[
|
||||
"torch.debug_module_name"
|
||||
] = ir.StringAttr.get(func_name)
|
||||
self._module.operation.attributes["torch.debug_module_name"] = (
|
||||
ir.StringAttr.get(func_name)
|
||||
)
|
||||
function_type = _extract_function_type_from_graph(g)
|
||||
func = func_dialect.FuncOp(
|
||||
func_name,
|
||||
|
|
|
@ -285,9 +285,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
(ns, unqual + "_", overload if not is_functional_op else "")
|
||||
),
|
||||
emitter_td,
|
||||
traits=["IsTrailingUnderscoreInplaceVariant"]
|
||||
if not is_functional_op
|
||||
else [],
|
||||
traits=(
|
||||
["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []
|
||||
),
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
|
|
|
@ -46,7 +46,7 @@ def convert_onnx(model, inputs):
|
|||
examples = []
|
||||
input_names = []
|
||||
dynamic_tensors = {}
|
||||
for (index, arg) in enumerate(inputs):
|
||||
for index, arg in enumerate(inputs):
|
||||
shape = map(lambda d: d if d >= 0 else 1, arg.shape)
|
||||
shape = tuple(shape)
|
||||
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
|
||||
|
@ -55,7 +55,7 @@ def convert_onnx(model, inputs):
|
|||
input_names.append(input_name)
|
||||
|
||||
dynamic_dims = {}
|
||||
for (dimindex, dim) in enumerate(arg.shape):
|
||||
for dimindex, dim in enumerate(arg.shape):
|
||||
if dim < 0:
|
||||
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)
|
||||
|
||||
|
|
|
@ -101,10 +101,12 @@ class RefBackendInvoker:
|
|||
def consume_return_funcs(*args):
|
||||
self.result = tuple(
|
||||
[
|
||||
arg
|
||||
if type in elemental_type_to_ctype
|
||||
else unranked_memref_to_numpy(
|
||||
arg, memref_type_to_np_dtype[type]
|
||||
(
|
||||
arg
|
||||
if type in elemental_type_to_ctype
|
||||
else unranked_memref_to_numpy(
|
||||
arg, memref_type_to_np_dtype[type]
|
||||
)
|
||||
)
|
||||
for arg, type in zip(args, ret_types)
|
||||
]
|
||||
|
|
|
@ -803,9 +803,7 @@ class QuantizedReluInt32(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: QuantizedReluInt32())
|
||||
def QuantizedReluInt32_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)
|
||||
)
|
||||
module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
|
@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
|
||||
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
|
||||
class SliceScatterModule(torch.nn.Module):
|
||||
|
|
|
@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
|||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
|
||||
# CHECK: module attributes {torch.debug_module_name = "TestModule"}
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
@ -18,6 +18,7 @@ mb = ModuleBuilder()
|
|||
# `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
|
||||
# naively duplicating a Tensor retains the identity of the TensorImpl.
|
||||
|
||||
|
||||
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
@ -12,6 +12,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
|||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
|
||||
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
|||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
|
||||
# CHECK-LABEL: func.func @__torch__.add3
|
||||
# Note that line-level debug information for parts unannotated in the Torch
|
||||
# graph are ascribed to the first op that carries source information. Presently
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
|||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
|
||||
# CHECK-LABEL: @__torch__.f
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
|
|
|
@ -11,6 +11,7 @@ import typing
|
|||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
|
||||
# CHECK-LABEL: func.func @__torch__.optional_return(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
|
||||
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>
|
||||
|
|
|
@ -13,6 +13,7 @@ mb = ModuleBuilder()
|
|||
# else branch and making all defined values optional, so no special handling
|
||||
# is needed.
|
||||
|
||||
|
||||
# CHECK-LABEL: @__torch__.prim_If(
|
||||
# CHECK-SAME: %[[B:.*]]: !torch.bool,
|
||||
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int {
|
||||
|
|
|
@ -11,6 +11,7 @@ import typing
|
|||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
|
||||
# CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
|
||||
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
|
||||
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true
|
||||
|
|
|
@ -15,6 +15,7 @@ import typing
|
|||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
|
||||
# CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
|
||||
|
|
|
@ -13,6 +13,7 @@ from utils import create_script_function
|
|||
mb = ModuleBuilder()
|
||||
NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])
|
||||
|
||||
|
||||
# CHECK-LABEL: func.func @__torch__.tuple(
|
||||
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
|
||||
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
|||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
|
||||
# CHECK: @__torch__.returns_bool
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
|||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
|
||||
# CHECK: @__torch__.returns_none
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch._C import CompilationUnit
|
|||
|
||||
# RUN: %PYTHON %s
|
||||
|
||||
|
||||
# Import TorchScript IR string as ScriptFunction.
|
||||
def create_script_function(func_name, ts_ir_str, **kwargs):
|
||||
cu = CompilationUnit()
|
||||
|
|
|
@ -1849,8 +1849,7 @@ def _emit_operation(
|
|||
|
||||
# Opaque value to indicate something is empty. Used in cases where 'None'
|
||||
# may have a different meaning.
|
||||
class EmptyType:
|
||||
...
|
||||
class EmptyType: ...
|
||||
|
||||
|
||||
Empty = EmptyType()
|
||||
|
|
|
@ -156,8 +156,7 @@ class GraphInfo:
|
|||
return ""
|
||||
|
||||
|
||||
class OnnxImportError(Exception):
|
||||
...
|
||||
class OnnxImportError(Exception): ...
|
||||
|
||||
|
||||
class NodeImporter:
|
||||
|
@ -235,22 +234,22 @@ class NodeImporter:
|
|||
else:
|
||||
default_opset_version = opset_import.version
|
||||
if default_opset_version:
|
||||
container_op.attributes[
|
||||
"torch.onnx_meta.opset_version"
|
||||
] = IntegerAttr.get(i64_type, default_opset_version)
|
||||
container_op.attributes["torch.onnx_meta.opset_version"] = (
|
||||
IntegerAttr.get(i64_type, default_opset_version)
|
||||
)
|
||||
if opset_versions:
|
||||
container_op.attributes[
|
||||
"torch.onnx_meta.opset_versions"
|
||||
] = DictAttr.get(opset_versions)
|
||||
container_op.attributes["torch.onnx_meta.opset_versions"] = (
|
||||
DictAttr.get(opset_versions)
|
||||
)
|
||||
container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
|
||||
IntegerType.get_signed(64), m.ir_version
|
||||
)
|
||||
container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
|
||||
m.producer_name
|
||||
)
|
||||
container_op.attributes[
|
||||
"torch.onnx_meta.producer_version"
|
||||
] = StringAttr.get(m.producer_version)
|
||||
container_op.attributes["torch.onnx_meta.producer_version"] = (
|
||||
StringAttr.get(m.producer_version)
|
||||
)
|
||||
|
||||
def import_all(self, func=True):
|
||||
"""Imports all nodes topologically."""
|
||||
|
@ -658,9 +657,11 @@ ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = {
|
|||
RankedTensorType.get(shape, IntegerType.get_signed(64)),
|
||||
IntegerAttr.get(
|
||||
IntegerType.get_signed(64),
|
||||
int.from_bytes(tp.raw_data, "little", signed=True)
|
||||
if tp.HasField("raw_data")
|
||||
else tp.int64_data[0],
|
||||
(
|
||||
int.from_bytes(tp.raw_data, "little", signed=True)
|
||||
if tp.HasField("raw_data")
|
||||
else tp.int64_data[0]
|
||||
),
|
||||
),
|
||||
),
|
||||
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
|
||||
|
@ -703,7 +704,7 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
|
|||
),
|
||||
onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
|
||||
np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
|
||||
)
|
||||
),
|
||||
# Intentionally unsupported: STRING
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue