diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72329026f..f2938e28e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: check-yaml - id: check-added-large-files - repo: https://github.com/psf/black - rev: 22.10.0 + rev: 24.4.2 hooks: - id: black diff --git a/build_tools/scrape_releases.py b/build_tools/scrape_releases.py index 88f19d92b..77aa41c15 100644 --- a/build_tools/scrape_releases.py +++ b/build_tools/scrape_releases.py @@ -2,6 +2,7 @@ See https://github.com/llvm/torch-mlir/issues/1374 """ + import argparse import json diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py index af2af2de3..840ec519d 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py @@ -3,6 +3,7 @@ from torch_mlir import torchscript from transformers import BertForMaskedLM + # Wrap the bert model to avoid multiple returns problem class BertTinyWrapper(torch.nn.Module): def __init__(self) -> None: diff --git a/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py b/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py index fcea14dc1..81908d801 100644 --- a/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py +++ b/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py @@ -257,9 +257,9 @@ class _FXGraphImporter: # FakeTensor's in case of a tuple return with multiple elements. self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {} self._module = ir.Module.create(ir.Location.unknown()) - self._module.operation.attributes[ - "torch.debug_module_name" - ] = ir.StringAttr.get(func_name) + self._module.operation.attributes["torch.debug_module_name"] = ( + ir.StringAttr.get(func_name) + ) function_type = _extract_function_type_from_graph(g) func = func_dialect.FuncOp( func_name, diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index ca867723c..eea8d31a9 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -285,9 +285,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): (ns, unqual + "_", overload if not is_functional_op else "") ), emitter_td, - traits=["IsTrailingUnderscoreInplaceVariant"] - if not is_functional_op - else [], + traits=( + ["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else [] + ), ) # ========================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py index 6fa845ab3..7f630074e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -46,7 +46,7 @@ def convert_onnx(model, inputs): examples = [] input_names = [] dynamic_tensors = {} - for (index, arg) in enumerate(inputs): + for index, arg in enumerate(inputs): shape = map(lambda d: d if d >= 0 else 1, arg.shape) shape = tuple(shape) examples.append(torch.zeros(size=shape, dtype=arg.dtype)) @@ -55,7 +55,7 @@ def convert_onnx(model, inputs): input_names.append(input_name) dynamic_dims = {} - for (dimindex, dim) in enumerate(arg.shape): + for dimindex, dim in enumerate(arg.shape): if dim < 0: dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex) diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index a1611a1e5..1e958a4d9 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -101,10 +101,12 @@ class RefBackendInvoker: def consume_return_funcs(*args): self.result = tuple( [ - arg - if type in elemental_type_to_ctype - else unranked_memref_to_numpy( - arg, memref_type_to_np_dtype[type] + ( + arg + if type in elemental_type_to_ctype + else unranked_memref_to_numpy( + arg, memref_type_to_np_dtype[type] + ) ) for arg, type in zip(args, ret_types) ] diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index d034e6d1f..8e2875842 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -803,9 +803,7 @@ class QuantizedReluInt32(torch.nn.Module): @register_test_case(module_factory=lambda: QuantizedReluInt32()) def QuantizedReluInt32_basic(module, tu: TestUtils): - module.forward( - tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32) - ) + module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)) # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py index 07f064de7..be2a80d84 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils): # ============================================================================== + # For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1). # For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index). class SliceScatterModule(torch.nn.Module): diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py index bd21c4e8b..5af1a6b89 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py @@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK: module attributes {torch.debug_module_name = "TestModule"} class TestModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py index 4c323ec01..4c325308b 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py @@ -18,6 +18,7 @@ mb = ModuleBuilder() # `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so # naively duplicating a Tensor retains the identity of the TensorImpl. + # CHECK-LABEL: torch.class_type @__torch__.TestModule { class TestModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py index e33985fac..df6f1736c 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py @@ -12,6 +12,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK-LABEL: torch.class_type @__torch__.TestModule { class TestModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py b/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py index 1bc258a42..7e8df49a0 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py @@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK-LABEL: func.func @__torch__.add3 # Note that line-level debug information for parts unannotated in the Torch # graph are ascribed to the first op that carries source information. Presently diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/elif.py b/projects/pt1/test/python/importer/jit_ir/node_import/elif.py index 5ee16e391..f3ee0a557 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/elif.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/elif.py @@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK-LABEL: @__torch__.f @mb.import_function @torch.jit.script diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py b/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py index 2acde08ca..f9505b91f 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py @@ -11,6 +11,7 @@ import typing mb = ModuleBuilder() + # CHECK-LABEL: func.func @__torch__.optional_return( # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional { # CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/if.py b/projects/pt1/test/python/importer/jit_ir/node_import/if.py index 86390f707..02cb8d9f0 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/if.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/if.py @@ -13,6 +13,7 @@ mb = ModuleBuilder() # else branch and making all defined values optional, so no special handling # is needed. + # CHECK-LABEL: @__torch__.prim_If( # CHECK-SAME: %[[B:.*]]: !torch.bool, # CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int { diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/loop.py b/projects/pt1/test/python/importer/jit_ir/node_import/loop.py index d432cd6ee..b28d63bb0 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/loop.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/loop.py @@ -11,6 +11,7 @@ import typing mb = ModuleBuilder() + # CHECK-LABEL: func.func @__torch__.prim_Loop_forlike( # CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float { # CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/prim.py b/projects/pt1/test/python/importer/jit_ir/node_import/prim.py index 66959257e..759292b6d 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/prim.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/prim.py @@ -15,6 +15,7 @@ import typing mb = ModuleBuilder() + # CHECK-LABEL: func.func @__torch__.prim_NumToTensor( # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor { # CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py b/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py index a1f06c390..b6a313cd4 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py @@ -13,6 +13,7 @@ from utils import create_script_function mb = ModuleBuilder() NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])]) + # CHECK-LABEL: func.func @__torch__.tuple( # CHECK-SAME: %[[T0:.*]]: !torch.tensor, # CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py b/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py index 0a27692fc..7cd4c3c16 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py @@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK: @__torch__.returns_bool @mb.import_function @torch.jit.script diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py b/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py index 16a3359da..b0358467c 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py @@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK: @__torch__.returns_none @mb.import_function @torch.jit.script diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/utils.py b/projects/pt1/test/python/importer/jit_ir/node_import/utils.py index 613ccb6a8..b06c38fdf 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/utils.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/utils.py @@ -9,6 +9,7 @@ from torch._C import CompilationUnit # RUN: %PYTHON %s + # Import TorchScript IR string as ScriptFunction. def create_script_function(func_name, ts_ir_str, **kwargs): cu = CompilationUnit() diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 9acf4ad03..24bda3f5b 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1849,8 +1849,7 @@ def _emit_operation( # Opaque value to indicate something is empty. Used in cases where 'None' # may have a different meaning. -class EmptyType: - ... +class EmptyType: ... Empty = EmptyType() diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index f1064f976..8d0e4cf5a 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -156,8 +156,7 @@ class GraphInfo: return "" -class OnnxImportError(Exception): - ... +class OnnxImportError(Exception): ... class NodeImporter: @@ -235,22 +234,22 @@ class NodeImporter: else: default_opset_version = opset_import.version if default_opset_version: - container_op.attributes[ - "torch.onnx_meta.opset_version" - ] = IntegerAttr.get(i64_type, default_opset_version) + container_op.attributes["torch.onnx_meta.opset_version"] = ( + IntegerAttr.get(i64_type, default_opset_version) + ) if opset_versions: - container_op.attributes[ - "torch.onnx_meta.opset_versions" - ] = DictAttr.get(opset_versions) + container_op.attributes["torch.onnx_meta.opset_versions"] = ( + DictAttr.get(opset_versions) + ) container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get( IntegerType.get_signed(64), m.ir_version ) container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get( m.producer_name ) - container_op.attributes[ - "torch.onnx_meta.producer_version" - ] = StringAttr.get(m.producer_version) + container_op.attributes["torch.onnx_meta.producer_version"] = ( + StringAttr.get(m.producer_version) + ) def import_all(self, func=True): """Imports all nodes topologically.""" @@ -658,9 +657,11 @@ ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = { RankedTensorType.get(shape, IntegerType.get_signed(64)), IntegerAttr.get( IntegerType.get_signed(64), - int.from_bytes(tp.raw_data, "little", signed=True) - if tp.HasField("raw_data") - else tp.int64_data[0], + ( + int.from_bytes(tp.raw_data, "little", signed=True) + if tp.HasField("raw_data") + else tp.int64_data[0] + ), ), ), # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB @@ -703,7 +704,7 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = { ), onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get( np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False - ) + ), # Intentionally unsupported: STRING }