From 6877302504cdb4a775705286820b2aee872826cc Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 27 Apr 2024 14:16:31 -0700 Subject: [PATCH] [NFC reformat] Applies pre-commit formatting to Python files. (#3244) This is a large change because prior to this point, Python files in the project were not consistently formatted. This reformats them all with black defaults. Based on experience with prior projects, if you have a dev/long-term branch with Python patches, you can minimize merge conflicts prior to rebasing to include this commit by running `black` on your modified Python files, squashing, and then rebasing/merging. --- build_tools/autogen_ltc_backend.py | 20 +- build_tools/scrape_releases.py | 11 +- projects/pt1/e2e_testing/main.py | 131 +- projects/pt1/e2e_testing/xfail_sets.py | 546 ++- projects/pt1/examples/ltc_backend_bert.py | 67 +- projects/pt1/examples/ltc_backend_mnist.py | 5 +- projects/pt1/examples/torchdynamo_resnet18.py | 44 +- projects/pt1/examples/torchscript_resnet18.py | 30 +- .../torchscript_resnet18_all_output_types.py | 8 +- .../torchscript_stablehlo_backend_resnet.py | 6 +- .../torchscript_stablehlo_backend_tinybert.py | 11 +- projects/pt1/python/test/annotations-sugar.py | 15 +- .../test/compile_api/already_scripted.py | 4 +- .../python/test/compile_api/already_traced.py | 6 +- .../test/compile_api/backend_legal_ops.py | 13 +- projects/pt1/python/test/compile_api/basic.py | 19 +- .../pt1/python/test/compile_api/make_fx.py | 14 +- .../test/compile_api/output_type_spec.py | 9 +- .../pt1/python/test/compile_api/tracing.py | 17 +- .../pt1/python/test/debug/lockstep_basic.py | 7 +- .../python/test/dynamo_fx_importer/basic.py | 15 +- projects/pt1/python/test/lit.cfg.py | 51 +- .../python/test/torchscript_e2e_test/basic.py | 2 +- .../compilation_failure.py | 2 +- .../torchscript_e2e_test/error_reports.py | 8 +- .../torchscript_e2e_test/non_tensor_values.py | 2 +- .../torchscript_e2e_test/runtime_failure.py | 2 +- .../test/torchscript_e2e_test/submodule.py | 4 +- .../python/torch_mlir/_dynamo_fx_importer.py | 146 +- .../_torch_mlir_custom_op_example/__init__.py | 3 +- projects/pt1/python/torch_mlir/_version.py | 1 + .../reference_lazy_backend/gen_dummy_lib.py | 12 +- projects/pt1/python/torch_mlir/dynamo.py | 23 +- .../torch_mlir/jit_ir_importer/__init__.py | 6 +- .../build_tools/library_generator.py | 73 +- .../jit_ir_importer/build_tools/registry.py | 70 +- .../build_tools/testing_framework.py | 138 +- .../build_tools/torch_ods_gen.py | 669 ++-- .../jit_ir_importer/build_tools/utils.py | 2 + .../torchscript_annotations.py | 30 +- projects/pt1/python/torch_mlir/torchscript.py | 103 +- .../python/torch_mlir_e2e_test/annotations.py | 4 +- .../configs/fx_importer_backend.py | 47 +- .../configs/lazy_tensor_core.py | 11 +- .../configs/linalg_on_tensors_backend.py | 11 +- .../configs/native_torch.py | 6 +- .../configs/onnx_backend.py | 22 +- .../configs/torchdynamo.py | 65 +- .../configs/torchscript.py | 8 +- .../configs/tosa_backend.py | 11 +- .../torch_mlir_e2e_test/configs/utils.py | 1 + .../torch_mlir_e2e_test/debug/lockstep.py | 22 +- .../python/torch_mlir_e2e_test/framework.py | 103 +- .../linalg_on_tensors_backends/abc.py | 5 +- .../linalg_on_tensors_backends/refbackend.py | 178 +- .../torch_mlir_e2e_test/onnx_backends/abc.py | 5 +- .../onnx_backends/linalg_on_tensors.py | 25 +- .../python/torch_mlir_e2e_test/registry.py | 13 +- .../python/torch_mlir_e2e_test/reporting.py | 159 +- .../stablehlo_backends/linalg_on_tensors.py | 19 +- .../test_suite/__init__.py | 1 + .../torch_mlir_e2e_test/test_suite/arange.py | 261 +- .../test_suite/backprop.py | 206 +- .../torch_mlir_e2e_test/test_suite/basic.py | 3114 ++++++++------- .../torch_mlir_e2e_test/test_suite/cast.py | 71 +- .../test_suite/constant_alloc.py | 1184 +++--- .../test_suite/control_flow.py | 26 +- .../torch_mlir_e2e_test/test_suite/conv.py | 1020 ++--- .../test_suite/custom_op_example.py | 11 +- .../test_suite/diagonal.py | 78 +- .../test_suite/elementwise.py | 3330 ++++++++++------- .../test_suite/elementwise_comparison.py | 546 ++- .../test_suite/gridsampler.py | 131 +- .../histogram_binning_calibration.py | 56 +- .../test_suite/index_select.py | 127 +- .../torch_mlir_e2e_test/test_suite/matmul.py | 504 ++- .../torch_mlir_e2e_test/test_suite/mlp.py | 52 +- .../test_suite/nll_loss.py | 783 ++-- .../test_suite/norm_like.py | 464 ++- .../torch_mlir_e2e_test/test_suite/padding.py | 74 +- .../torch_mlir_e2e_test/test_suite/pooling.py | 1344 +++---- .../test_suite/quantized_models.py | 54 +- .../test_suite/reduction.py | 1261 +++++-- .../test_suite/reshape_like.py | 895 +++-- .../test_suite/return_types.py | 61 +- .../torch_mlir_e2e_test/test_suite/rng.py | 434 ++- .../torch_mlir_e2e_test/test_suite/scalar.py | 270 +- .../test_suite/scalar_comparison.py | 104 +- .../torch_mlir_e2e_test/test_suite/scatter.py | 1245 +++--- .../test_suite/slice_like.py | 621 +-- .../torch_mlir_e2e_test/test_suite/squeeze.py | 132 +- .../torch_mlir_e2e_test/test_suite/stats.py | 568 +-- .../test_suite/threshold.py | 210 +- .../test_suite/type_conversion.py | 116 +- .../test_suite/type_promotion.py | 76 +- .../test_suite/vision_models.py | 42 +- .../torch_mlir_e2e_test/tosa_backends/abc.py | 5 +- .../tosa_backends/linalg_on_tensors.py | 43 +- .../pt1/python/torch_mlir_e2e_test/utils.py | 4 +- projects/pt1/test/lit.cfg.py | 56 +- .../test/python/custom_op_shape_dtype_fn.py | 12 +- .../importer/jit_ir/get_registered_ops.py | 2 +- .../ivalue_import/annotations/arg-error.py | 5 +- .../annotations/arg-tensor-type-bound.py | 18 +- .../annotations/class-annotator-repr.py | 19 +- .../ivalue_import/annotations/export-error.py | 8 +- .../annotations/export-recursive.py | 10 +- .../ivalue_import/annotations/export.py | 9 +- .../jit_ir/ivalue_import/debug-module-name.py | 1 + .../importer/jit_ir/ivalue_import/dict.py | 6 +- .../functions-that-call-methods.py | 9 +- .../jit_ir/ivalue_import/functions.py | 4 + .../importer/jit_ir/ivalue_import/list.py | 3 + .../jit_ir/ivalue_import/methods-derefine.py | 27 +- .../jit_ir/ivalue_import/methods-locations.py | 15 +- .../importer/jit_ir/ivalue_import/methods.py | 11 +- .../object-identity-error-submodule.py | 4 +- .../ivalue_import/object-identity-error.py | 3 +- .../jit_ir/ivalue_import/object-identity.py | 3 +- .../importer/jit_ir/ivalue_import/prim.py | 3 + .../jit_ir/ivalue_import/primitives.py | 2 + .../jit_ir/ivalue_import/quantization.py | 8 +- .../importer/jit_ir/ivalue_import/strings.py | 3 + .../jit_ir/ivalue_import/submodules-select.py | 5 +- .../jit_ir/ivalue_import/submodules.py | 3 + .../ivalue_import/tensors-value-semantics.py | 4 +- .../importer/jit_ir/ivalue_import/tensors.py | 12 +- .../importer/jit_ir/ivalue_import/tuple.py | 3 + .../importer/jit_ir/node_import/classes.py | 3 + .../importer/jit_ir/node_import/debug-info.py | 9 +- .../importer/jit_ir/node_import/dict.py | 9 +- .../importer/jit_ir/node_import/elif.py | 1 + .../importer/jit_ir/node_import/errors.py | 20 +- .../function-block-arg-adjustment.py | 9 +- .../jit_ir/node_import/function-derefine.py | 2 + .../python/importer/jit_ir/node_import/if.py | 2 + .../importer/jit_ir/node_import/list.py | 6 +- .../importer/jit_ir/node_import/loop.py | 3 + .../importer/jit_ir/node_import/prim.py | 47 +- .../importer/jit_ir/node_import/tuple.py | 24 +- .../importer/jit_ir/node_import/types-bool.py | 1 + .../importer/jit_ir/node_import/types-none.py | 1 + .../jit_ir/node_import/unimplemented.py | 2 + .../importer/jit_ir/node_import/union.py | 6 +- projects/pt1/test/python/smoketest.py | 12 +- python/torch_mlir/compiler_utils.py | 35 +- python/torch_mlir/extras/fx_decomp_util.py | 1 + python/torch_mlir/extras/fx_importer.py | 4 +- python/torch_mlir/extras/onnx_importer.py | 75 +- .../torch_mlir/tools/import_onnx/__main__.py | 2 +- setup.py | 28 +- test/lit.cfg.py | 56 +- test/python/compile.py | 12 +- test/python/fx_importer/basic_test.py | 28 +- test/python/fx_importer/sparse_test.py | 2 +- .../onnx_importer/_torch_mlir_config.py | 2 + .../python/onnx_importer/command_line_test.py | 67 +- .../python/onnx_importer/import_smoke_test.py | 1 + utils/bazel/overlay_directories.py | 101 +- 159 files changed, 13557 insertions(+), 9866 deletions(-) diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 6f6fd5e89..13753a6d5 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -30,6 +30,7 @@ if not TORCH_INCLUDE_DIR.is_dir(): TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve() TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent + def reindent(text, prefix=""): return indent(dedent(text), prefix) @@ -75,7 +76,11 @@ class GenMlirLazyIr(torchgen.dest.GenLazyIR): ) # Only create this variable if it's used to avoid Wunused-variable - operand_idx_counter = "size_t i = 0;" if "i++" in (emplace_arguments_str + emplace_kwarguments) else "" + operand_idx_counter = ( + "size_t i = 0;" + if "i++" in (emplace_arguments_str + emplace_kwarguments) + else "" + ) return reindent( f""" @@ -111,12 +116,16 @@ class GenTorchMlirLTC: ) assert self.torch_ops_file.exists() self.binary_dir = Path(binary_dir) - assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}" + assert ( + self.binary_dir.is_dir() + ), f"Binary directory not found: {self.binary_dir}" self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml") self.backend_path = TORCH_MLIR_DIR.joinpath( "projects", "ltc", "csrc", "base_lazy_backend" ) - assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}" + assert ( + self.backend_path.is_dir() + ), f"Backend path not found: {self.backend_path}" self.generated_path = self.binary_dir.joinpath( "projects", "ltc", "csrc", "base_lazy_backend", "generated" ) @@ -168,8 +177,9 @@ class GenTorchMlirLTC: if ts_native_yaml_path.exists(): ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), yaml.CLoader) else: - logging.warning(f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}") - + logging.warning( + f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}" + ) parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path) self.native_functions = parsed_yaml.native_functions diff --git a/build_tools/scrape_releases.py b/build_tools/scrape_releases.py index b8c7265d3..88f19d92b 100644 --- a/build_tools/scrape_releases.py +++ b/build_tools/scrape_releases.py @@ -9,19 +9,20 @@ import requests # Parse arguments parser = argparse.ArgumentParser() -parser.add_argument('owner', type=str) -parser.add_argument('repo', type=str) +parser.add_argument("owner", type=str) +parser.add_argument("repo", type=str) args = parser.parse_args() # Get releases response = requests.get( - f"https://api.github.com/repos/{args.owner}/{args.repo}/releases") + f"https://api.github.com/repos/{args.owner}/{args.repo}/releases" +) body = json.loads(response.content) # Parse releases releases = [] for row in body: - for asset in row['assets']: + for asset in row["assets"]: releases.append((asset["name"], asset["browser_download_url"])) # Output HTML @@ -33,4 +34,4 @@ for name, url in releases: html += f" {name}
\n" html += """ """ -print(html) \ No newline at end of file +print(html) diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py index 5a61b50db..1ec7aa43f 100644 --- a/projects/pt1/e2e_testing/main.py +++ b/projects/pt1/e2e_testing/main.py @@ -25,10 +25,18 @@ from torch_mlir_e2e_test.configs import ( FxImporterTestConfig, ) -from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend -from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import LinalgOnTensorsOnnxBackend -from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( + RefBackendLinalgOnTensorsBackend, +) +from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import ( + LinalgOnTensorsOnnxBackend, +) +from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import ( + LinalgOnTensorsTosaBackend, +) +from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import ( + LinalgOnTensorsStablehloBackend, +) from .xfail_sets import ( LINALG_XFAIL_SET, @@ -51,13 +59,28 @@ from .xfail_sets import ( # Import tests to register them in the global registry. from torch_mlir_e2e_test.test_suite import register_all_tests + register_all_tests() + def _get_argparse(): - config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", - "torchdynamo", "onnx", "fx_importer", "fx_importer_stablehlo"] + config_choices = [ + "native_torch", + "torchscript", + "linalg", + "stablehlo", + "make_fx_tosa", + "tosa", + "lazy_tensor_core", + "torchdynamo", + "onnx", + "fx_importer", + "fx_importer_stablehlo", + ] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") - parser.add_argument("-c", "--config", + parser.add_argument( + "-c", + "--config", choices=config_choices, default="linalg", help=f""" @@ -72,34 +95,52 @@ Meaning of options: "onnx": export to the model via onnx and reimport using the torch-onnx-to-torch path. "fx_importer": run the model through the fx importer frontend and execute the graph using Linalg-on-Tensors. "fx_importer_stablehlo": run the model through the fx importer frontend and execute the graph using Stablehlo backend. -""") - parser.add_argument("-f", "--filter", default=".*", help=""" +""", + ) + parser.add_argument( + "-f", + "--filter", + default=".*", + help=""" Regular expression specifying which tests to include in this run. -""") - parser.add_argument("-v", "--verbose", - default=False, - action="store_true", - help="report test results with additional detail") - parser.add_argument("-s", "--sequential", - default=False, - action="store_true", - help="""Run tests sequentially rather than in parallel. +""", + ) + parser.add_argument( + "-v", + "--verbose", + default=False, + action="store_true", + help="report test results with additional detail", + ) + parser.add_argument( + "-s", + "--sequential", + default=False, + action="store_true", + help="""Run tests sequentially rather than in parallel. This can be useful for debugging, since it runs the tests in the same process, -which make it easier to attach a debugger or get a stack trace.""") - parser.add_argument("--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed", - metavar="TEST", type=str, nargs="+", - help="A set of tests to not attempt to run, since they crash and cannot be XFAILed.") - parser.add_argument("--ignore_failures", - default=False, - action="store_true", - help="return exit code 0 even if the test fails to unblock pipeline") +which make it easier to attach a debugger or get a stack trace.""", + ) + parser.add_argument( + "--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed", + metavar="TEST", + type=str, + nargs="+", + help="A set of tests to not attempt to run, since they crash and cannot be XFAILed.", + ) + parser.add_argument( + "--ignore_failures", + default=False, + action="store_true", + help="return exit code 0 even if the test fails to unblock pipeline", + ) return parser + def main(): args = _get_argparse().parse_args() - all_test_unique_names = set( - test.unique_name for test in GLOBAL_TEST_REGISTRY) + all_test_unique_names = set(test.unique_name for test in GLOBAL_TEST_REGISTRY) # Find the selected config. if args.config == "linalg": @@ -147,23 +188,26 @@ def main(): xfail_set = ONNX_XFAIL_SET crashing_set = ONNX_CRASHING_SET - do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []).union(crashing_set) - available_tests = [test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt] + do_not_attempt = set( + args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or [] + ).union(crashing_set) + available_tests = [ + test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt + ] if args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed is not None: for arg in args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed: if arg not in all_test_unique_names: - print(f"ERROR: --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument '{arg}' is not a valid test name") + print( + f"ERROR: --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument '{arg}' is not a valid test name" + ) sys.exit(1) # Find the selected tests, and emit a diagnostic if none are found. tests = [ - test for test in available_tests - if re.match(args.filter, test.unique_name) + test for test in available_tests if re.match(args.filter, test.unique_name) ] if len(tests) == 0: - print( - f"ERROR: the provided filter {args.filter!r} does not match any tests" - ) + print(f"ERROR: the provided filter {args.filter!r} does not match any tests") print("The available tests are:") for test in available_tests: print(test.unique_name) @@ -175,18 +219,25 @@ def main(): # Report the test results. failed = report_results(results, xfail_set, args.verbose, args.config) if args.config == "torchdynamo": - print("\033[91mWarning: the TorchScript based dynamo support is deprecated. " - "The config for torchdynamo is planned to be removed in the future.\033[0m") + print( + "\033[91mWarning: the TorchScript based dynamo support is deprecated. " + "The config for torchdynamo is planned to be removed in the future.\033[0m" + ) if args.ignore_failures: sys.exit(0) sys.exit(1 if failed else 0) + def _suppress_warnings(): import warnings + # Ignore warning due to Python bug: # https://stackoverflow.com/questions/4964101/pep-3118-warning-when-using-ctypes-array-as-numpy-array - warnings.filterwarnings("ignore", - message="A builtin ctypes object gave a PEP3118 format string that does not match its itemsize") + warnings.filterwarnings( + "ignore", + message="A builtin ctypes object gave a PEP3118 format string that does not match its itemsize", + ) + if __name__ == "__main__": _suppress_warnings() diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 62c6cd777..1b5bc1c94 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -31,21 +31,17 @@ LINALG_CRASHING_SET = { TORCHDYNAMO_XFAIL_SET = { #### General TorchDynamo/PyTorch errors - # torch._dynamo.exc.Unsupported: Tensor.item "CumsumModule_basic", - # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0 # RuntimeError: Failed running call_function aten.convolution_backward(... # https://github.com/pytorch/pytorch/issues/89629 "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2D_basic", - # Size result mismatch (exposed by downstream canonicalizer # on incompatabile casts). # https://github.com/pytorch/pytorch/issues/119407 "ConvolutionBackwardModule2DStrided_basic", - # RuntimeError: Index tensor must have the same number of dimensions as self tensor # RuntimeError: Failed running call_function aten.nll_loss_backward(... # https://github.com/pytorch/pytorch/issues/89630 @@ -59,196 +55,159 @@ TORCHDYNAMO_XFAIL_SET = { # RuntimeError: Failed running call_function aten.uniform(... # https://github.com/pytorch/torchdynamo/issues/1954 "UniformNoCorrelationModule_basic", - #### Torch-MLIR internal compiler errors - # These are probably due to slightly different ops being recorded by # torchdynamo vs. torchscript. - # No upstream decompositions. # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor) # See also: https://github.com/pytorch/torchdynamo/issues/327 "AtenEmbeddingBagSumExample_basic", - # error: unsupported by backend contract: tensor with unknown rank # note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32> "ElementwisePreluModule_basic", # error: torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised: AssertionError: Unregistered operation: torch.aten._prelu_kernel "ElementwisePreluStaticModule_basic", - - #ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777) + # ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777) "UpSampleNearest2dDynamicFactor_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", - #ERROR: value (-56) is not equal to golden value (200) + # ERROR: value (-56) is not equal to golden value (200) "AtenIntTensorByteDtypeModule_basic", # ERROR: assert isinstance(e, FakeTensor) "ElementwiseAddScalar_NumToTensorFloat_Module_basic", # ERROR: assert isinstance(e, FakeTensor) "RsubInt0d_NumToTensor_Module_basic", - # ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::squeeze.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa. "PrimsSqueezeModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic", "SplitDimStaticModule_basic", "SplitDimDynamicModule_basic", - # ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::view_of.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa. "PrimsViewOfModule_basic", "PrimsViewOfZeroRankModule_basic", - # See https://github.com/llvm/torch-mlir/pull/2040 and corresponding upstream issue # https://github.com/pytorch/pytorch/issues/99752. # torch._dynamo.exc.Unsupported: call_function BuiltinVariable(bool) [TensorVariable()] {} - 'TensorToBoolZeroRank_basic', - 'TensorToBool_basic', - + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} - 'AtenSubFloatModule_basic', - 'AtenMulFloatModule_basic', - 'BoolFloatFalseModule_basic', - 'BoolFloatTrueModule_basic', - 'CeilFloatModule_basic', - 'DivFloatModule_basic', - 'GeFloatIntModule_basic', - 'GeFloatModule_basic', - 'GtFloatIntModule_basic', - 'NeFloatIntModule_basic', - 'SubFloatModule_basic', - 'MulFloatModule_basic', - 'TensorToFloatZeroRank_basic', - 'TensorToFloat_basic', + "AtenSubFloatModule_basic", + "AtenMulFloatModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "CeilFloatModule_basic", + "DivFloatModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GtFloatIntModule_basic", + "NeFloatIntModule_basic", + "SubFloatModule_basic", + "MulFloatModule_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} - - # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} - 'AddIntModule_basic', - 'AtenIntTensorCharDtypeModule_basic', - 'BoolIntFalseModule_basic', - 'BoolIntTrueModule_basic', - 'DivIntModule_basic', - 'EqIntModule_basic', - 'GeIntModule_basic', - 'GtIntModule_basic', - 'MulIntModule_basic', - 'NeIntModule_basic', - 'SqrtIntModule_basic', - 'SubIntModule_basic', - 'TensorToIntZeroRank_basic', - 'TensorToInt_basic', - 'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic', - 'ViewCollapseDynamicWithAtenSizeIntModule_basic', + "AddIntModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "DivIntModule_basic", + "EqIntModule_basic", + "GeIntModule_basic", + "GtIntModule_basic", + "MulIntModule_basic", + "NeIntModule_basic", + "SqrtIntModule_basic", + "SubIntModule_basic", + "TensorToIntZeroRank_basic", + "TensorToInt_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} - # ERROR: torch._dynamo.exc.Unsupported: Tensor.item - 'AtenItemIntOpModule_basic', - 'AtenItemFpOpModule_basic', - + "AtenItemIntOpModule_basic", + "AtenItemFpOpModule_basic", # ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)} - 'SortIntListReverse_basic', - + "SortIntListReverse_basic", # ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {} - 'SortIntList_basic', - + "SortIntList_basic", # START tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default - 'AtenFloatScalarModule_basic', - 'AtenIntBoolOpModule_basic', - 'QuantizedMLP_basic', - 'QuantizedSingleLayer_basic', - 'QuantizedBatchedInputSingleLayer_basic', - 'QuantizedNoLayer_basic', - 'ScalarImplicitFloatModule_basic', - 'ScalarImplicitIntModule_basic', + "AtenFloatScalarModule_basic", + "AtenIntBoolOpModule_basic", + "QuantizedMLP_basic", + "QuantizedSingleLayer_basic", + "QuantizedBatchedInputSingleLayer_basic", + "QuantizedNoLayer_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", # END tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default - # START tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default - 'BincountMinlengthModule_basic', - 'BincountModule_basic', - 'BincountStaticSizeModule_basic', + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", # END tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default - # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool - 'BoolFloatConstantModule_basic', - 'BoolIntConstantModule_basic', - + "BoolFloatConstantModule_basic", + "BoolIntConstantModule_basic", # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size "ViewSizeFromOtherTensor_basic", - # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__ - 'ContainsIntList_False', - 'ContainsIntList_True', - + "ContainsIntList_False", + "ContainsIntList_True", # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all - 'AllBoolFalseModule_basic', - 'AllBoolTrueModule_basic', - + "AllBoolFalseModule_basic", + "AllBoolTrueModule_basic", # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any - 'AnyBoolFalseModule_basic', - 'AnyBoolTrueModule_basic', - + "AnyBoolFalseModule_basic", + "AnyBoolTrueModule_basic", # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt - 'SqrtIntConstantModule_basic', - + "SqrtIntConstantModule_basic", # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size - 'BroadcastDynamicDimModule_basic', - + "BroadcastDynamicDimModule_basic", # START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int - 'AtenIntBoolOpConstFalseModule_basic', - 'AtenIntBoolOpConstTrueModule_basic', - 'IntFloatModule_basic', - 'PowIntFloatModule_basic', + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "IntFloatModule_basic", + "PowIntFloatModule_basic", # END tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int - # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len - 'LenStrModule_basic', - + "LenStrModule_basic", # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel - 'NumelModule_basic', - 'NumelZeroRankModule_basic', - + "NumelModule_basic", + "NumelZeroRankModule_basic", # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max - 'PrimMaxIntModule_basic', - + "PrimMaxIntModule_basic", # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min - 'PrimMinIntModule_basic', - 'PrimMinIntDynamicModule_basic', - + "PrimMinIntModule_basic", + "PrimMinIntDynamicModule_basic", # START tests failing due to: empty graph in dynamo - 'IsFloatingPointFloat_True', - 'IsFloatingPointInt_False', - 'TorchPrimLoopForLikeModule_basic', - 'TorchPrimLoopWhileLikeModule_basic', + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", + "TorchPrimLoopForLikeModule_basic", + "TorchPrimLoopWhileLikeModule_basic", "ScalarConstantTupleModule_basic", # END tests failing due to: empty graph in dynamo - # ERROR due to: backend never runs because of empty frame - 'ConstantBoolParameterModule_basic', - + "ConstantBoolParameterModule_basic", # START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' "UpSampleNearest2dDynamicSize_basic", "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticSize_basic", "UpSampleNearest2d_basic", # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' - # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' "ElementwiseAddScalarFloatModule_basic", # END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' - # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", "HBC_basic", - # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' "ElementwiseDivScalarModule_basic", - # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' "ElementwiseAtenDivIntScalarModule_basic", - # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' "ElementwiseSubScalarFloatModule_basic", "ElementwiseSubScalarIntModule_basic", - # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode "ElementwiseAtenFloorDivideScalarNegativeModule_basic", "ElementwiseAtenFloorDivideScalarModule_basic", @@ -258,57 +217,43 @@ TORCHDYNAMO_XFAIL_SET = { "ElementwiseDivScalarRoundingModeTruncStaticModule_basic", "ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic", "ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic", - # ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' "AdaptiveAvgPool1dStaticLargerOutput_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool2dDynamic_basic", "AdaptiveAvgPool2dDynamicNoBatch_basic", - # ERROR: Exception: Unsupported op: get_attr "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", "TensorFloatModule_basic", "TensorIntModule_basic", - # START tests failing due to: complex floating point ops # END tests failing due to: complex floating point ops - # ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int "UnbindIntListUnpack_Module_basic", "UnbindIntGetItem_Module_basic", - # ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} "ScatterValueFloatModule_basic", # ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} "ScatterValueIntModule_basic", - # AssertionError: Unregistered operation: torch.aten._unsafe_index_put "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", - # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - # AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu "ScaledDotProductAttentionDifferentModule_basic", - # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only "AtenEmbeddingBagStaticModule_basic", - # Lowering not present for this case "ElementwiseToDtypeI64ToUI8Module_basic", - # torch._dynamo.exc.TorchRuntimeError: Failed running call_function (*(FakeTensor(..., size=(3, 4), dtype=torch.int8), 3, 2), **{}): Tensor with dtype torch.int64 is not the expected dtype of torch.int8! "ElementwiseAddScalarInt8Module_basic", - # ERROR: dtype (torch.int64) is not equal to golden dtype (torch.float32) "ThresholdBackward2dMixedModule_basic", - # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4])) "ArangeStartOutViewModule_basic", - # Dynamo does not support tracing quantized tensors "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", @@ -327,13 +272,10 @@ TORCHDYNAMO_XFAIL_SET = { "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", "Conv2dQInt8Module_basic", - # Dynamo not supporting conv_tbc "ConvTbcModule_basic", - "FloatImplicitModule_basic", "IntImplicitModule_basic", - # Others "ExponentialModule_basic", "GridSamplerBasic1_basic", @@ -383,142 +325,141 @@ TORCHDYNAMO_CRASHING_SET = { "MaxPool3dModule_basic", "MaxPool3dStaticCeilModeTrueModule_basic", "MaxPool3dStaticModule_basic", - # Looks like incorrect fx graph conversion "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", } FX_IMPORTER_XFAIL_SET = { - 'AllBoolFalseModule_basic', - 'AllBoolTrueModule_basic', - 'AnyBoolFalseModule_basic', - 'AnyBoolTrueModule_basic', - 'ArangeStartOutViewModule_basic', - 'AtenEmbeddingBagStaticModule_basic', - 'AtenEmbeddingBagSumExample_basic', - 'AtenFloatScalarModule_basic', - 'AtenIntBoolOpConstFalseModule_basic', - 'AtenIntBoolOpConstTrueModule_basic', - 'AtenIntBoolOpModule_basic', - 'AtenItemFpOpModule_basic', - 'AtenMatmulQMixedSigni8Transpose_basic', - 'AtenMatmulQMixedSigni8_basic', - 'AtenMatmulQint8MV_basic', - 'AtenMatmulQint8_basic', - 'AtenMatmulQint8VM_basic', - 'AtenMatmulQint8VV_basic', - 'AtenMmQMixedSigni8_basic', - 'AtenMmQint8_basic', - 'AtenMmQuint8_basic', + "AllBoolFalseModule_basic", + "AllBoolTrueModule_basic", + "AnyBoolFalseModule_basic", + "AnyBoolTrueModule_basic", + "ArangeStartOutViewModule_basic", + "AtenEmbeddingBagStaticModule_basic", + "AtenEmbeddingBagSumExample_basic", + "AtenFloatScalarModule_basic", + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "AtenIntBoolOpModule_basic", + "AtenItemFpOpModule_basic", + "AtenMatmulQMixedSigni8Transpose_basic", + "AtenMatmulQMixedSigni8_basic", + "AtenMatmulQint8MV_basic", + "AtenMatmulQint8_basic", + "AtenMatmulQint8VM_basic", + "AtenMatmulQint8VV_basic", + "AtenMmQMixedSigni8_basic", + "AtenMmQint8_basic", + "AtenMmQuint8_basic", "QuantizedReluInt32_basic", "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", - 'AtenSubFloatModule_basic', - 'BincountMinlengthModule_basic', - 'BincountModule_basic', - 'BincountStaticSizeModule_basic', - 'BoolFloatConstantModule_basic', - 'BoolFloatFalseModule_basic', - 'BoolFloatTrueModule_basic', - 'BoolIntConstantModule_basic', - 'BoolIntFalseModule_basic', - 'BoolIntTrueModule_basic', - 'BroadcastDynamicDimModule_basic', - 'CeilFloatModule_basic', - 'ConstantBoolParameterModule_basic', - 'ContainsIntList_False', - 'ContainsIntList_True', - 'Conv2dQInt8Module_basic', - 'Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier', - 'ConvTbcModule_basic', - 'ConvolutionBackwardModule2DPadded_basic', - 'ConvolutionBackwardModule2DStrided_basic', - 'ConvolutionBackwardModule2D_basic', - 'CumsumModule_basic', - 'DivFloatModule_basic', - 'DivIntModule_basic', - 'ElementwiseAddScalar_NumToTensorFloat_Module_basic', - 'ElementwiseDequantizePerChannelModule_basic', - 'ElementwiseDequantizePerTensorModule_basic', - 'ElementwiseQuantizePerTensorModule_basic', - 'ElementwiseQuantizePerTensorUIntModule_basic', - 'ElementwiseToDtypeI64ToUI8Module_basic', - 'EqIntModule_basic', - 'FakeQuantizePerTensorAffineDynamicShapeModule_basic', - 'FakeQuantizePerTensorAffineModule_basic', - 'FakeQuantizePerTensorAffineRoundToEvenModule_basic', - 'FloatImplicitModule_basic', - 'GeFloatIntModule_basic', - 'GeFloatModule_basic', - 'GeIntModule_basic', - 'GtFloatIntModule_basic', - 'GtIntModule_basic', - 'IntFloatModule_basic', - 'IntImplicitModule_basic', - 'IsFloatingPointFloat_True', - 'IsFloatingPointInt_False', - 'LenStrModule_basic', - 'MaxPool3dCeilModeTrueModule_basic', - 'MaxPool3dEmptyStrideStaticModule_basic', - 'MaxPool3dLargeDatadModule_basic', - 'MaxPool3dModuleRandomSimple_basic', - 'MaxPool3dModule_basic', - 'MaxPool3dStaticCeilModeTrueModule_basic', - 'MaxPool3dStaticModule_basic', - 'MulFloatModule_basic', - 'NativeGroupNormBackwardModule_basic', - 'NeFloatIntModule_basic', - 'NeIntModule_basic', - 'NllLossModuleBackward1DMeanWeight_basic', - 'NllLossModuleBackward1DMean_basic', - 'NllLossModuleBackward1DSumWeight_basic', - 'NllLossModuleBackward1DSum_basic', - 'NllLossModuleBackward1DWeight_basic', - 'NllLossModuleBackward1D_basic', - 'NumToTensorFloatModule_basic', - 'NumToTensorIntModule_basic', - 'NumelModule_basic', - 'NumelZeroRankModule_basic', - 'PowIntFloatModule_basic', - 'PrimMaxIntModule_basic', - 'PrimMinIntDynamicModule_basic', - 'PrimMinIntModule_basic', - 'PrimsSqueezeEmptyDimensionsModule_basic', - 'PrimsSqueezeModule_basic', - 'PrimsViewOfModule_basic', - 'PrimsViewOfZeroRankModule_basic', - 'QuantizedBatchedInputSingleLayer_basic', - 'QuantizedMLP_basic', - 'QuantizedNoLayer_basic', - 'QuantizedSingleLayer_basic', - 'ReduceMaxAlongDimUnsignedInt_basic', - 'ReduceMinAlongDimUnsignedInt_basic', - 'RsubInt0d_NumToTensor_Module_basic', - 'ScalarConstantTupleModule_basic', - 'ScalarImplicitFloatModule_basic', - 'SortIntListReverse_basic', - 'SortIntList_basic', - 'SplitDimDynamicModule_basic', - 'SplitDimStaticModule_basic', - 'SqrtIntConstantModule_basic', - 'SqrtIntModule_basic', - 'SubFloatModule_basic', - 'TModuleRank0_basic', - 'TensorToBoolZeroRank_basic', - 'TensorToBool_basic', - 'TensorToFloatZeroRank_basic', - 'TensorToFloat_basic', - 'TestMultipleTensorAndPrimitiveTypesReturn_basic', - 'ThresholdBackward2dMixedModule_basic', - 'TorchPrimLoopForLikeModule_basic', - 'TorchPrimLoopWhileLikeModule_basic', - 'UnbindIntGetItem_Module_basic', - 'UnbindIntListUnpack_Module_basic', - 'UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic', - 'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic', - 'UpSampleNearest2dDynamicFactor_basic', - 'ViewCollapseDynamicWithAtenSizeIntModule_basic', - 'ViewSizeFromOtherTensor_basic', + "AtenSubFloatModule_basic", + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "BroadcastDynamicDimModule_basic", + "CeilFloatModule_basic", + "ConstantBoolParameterModule_basic", + "ContainsIntList_False", + "ContainsIntList_True", + "Conv2dQInt8Module_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "ConvTbcModule_basic", + "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStrided_basic", + "ConvolutionBackwardModule2D_basic", + "CumsumModule_basic", + "DivFloatModule_basic", + "DivIntModule_basic", + "ElementwiseAddScalar_NumToTensorFloat_Module_basic", + "ElementwiseDequantizePerChannelModule_basic", + "ElementwiseDequantizePerTensorModule_basic", + "ElementwiseQuantizePerTensorModule_basic", + "ElementwiseQuantizePerTensorUIntModule_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", + "EqIntModule_basic", + "FakeQuantizePerTensorAffineDynamicShapeModule_basic", + "FakeQuantizePerTensorAffineModule_basic", + "FakeQuantizePerTensorAffineRoundToEvenModule_basic", + "FloatImplicitModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "IntFloatModule_basic", + "IntImplicitModule_basic", + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", + "LenStrModule_basic", + "MaxPool3dCeilModeTrueModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + "MaxPool3dStaticCeilModeTrueModule_basic", + "MaxPool3dStaticModule_basic", + "MulFloatModule_basic", + "NativeGroupNormBackwardModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "NllLossModuleBackward1DMeanWeight_basic", + "NllLossModuleBackward1DMean_basic", + "NllLossModuleBackward1DSumWeight_basic", + "NllLossModuleBackward1DSum_basic", + "NllLossModuleBackward1DWeight_basic", + "NllLossModuleBackward1D_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "PowIntFloatModule_basic", + "PrimMaxIntModule_basic", + "PrimMinIntDynamicModule_basic", + "PrimMinIntModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsSqueezeModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "QuantizedBatchedInputSingleLayer_basic", + "QuantizedMLP_basic", + "QuantizedNoLayer_basic", + "QuantizedSingleLayer_basic", + "ReduceMaxAlongDimUnsignedInt_basic", + "ReduceMinAlongDimUnsignedInt_basic", + "RsubInt0d_NumToTensor_Module_basic", + "ScalarConstantTupleModule_basic", + "ScalarImplicitFloatModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "SplitDimDynamicModule_basic", + "SplitDimStaticModule_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "SubFloatModule_basic", + "TModuleRank0_basic", + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "ThresholdBackward2dMixedModule_basic", + "TorchPrimLoopForLikeModule_basic", + "TorchPrimLoopWhileLikeModule_basic", + "UnbindIntGetItem_Module_basic", + "UnbindIntListUnpack_Module_basic", + "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "UpSampleNearest2dDynamicFactor_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "ViewSizeFromOtherTensor_basic", } FX_IMPORTER_CRASHING_SET = { @@ -1492,7 +1433,7 @@ STABLEHLO_PASS_SET = { "ElementwiseTruncModule_basic", } -STABLEHLO_CRASHING_SET = { +STABLEHLO_CRASHING_SET = { "AtenEmbeddingBagSumExample_basic", } @@ -1941,43 +1882,43 @@ TOSA_PASS_SET = { "LinspaceModule_basic", "LinspaceOneSizeModule_basic", "LinspaceTwoSizeModule_basic", - "TorchPrimLoopForLikeTensorArgModule_basic" + "TorchPrimLoopForLikeTensorArgModule_basic", } -MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { -### Tests additionally passing in make_fx_tosa - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dStaticEvenMultiple_basic", - "CosineSimilarityModule_basic", - "NativeGroupNormBackwardModule_basic", - "ReduceFrobeniusNormKeepDimModule_basic", - "ReduceFrobeniusNormModule_basic", - "SliceWholeTensorModule_basic", - "TensorFloatModule_basic", - "TensorIntModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", - "RepeatInterleaveSelfIntModule_basic", - "TorchPrimLoopForLikeTensorArgModule_basic", - "ViewSizeDimFollowedByCollapsedOnesModule_basic", - "ViewSizeDimFollowedByExpandedOnesModule_basic", - "ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic", - "ViewSizeDimLedByCollapsedOnesModule_basic", - "ViewSizeFromOtherTensor_basic", -}) - { -### Test failing in make_fx_tosa but not in tosa - +MAKE_FX_TOSA_PASS_SET = ( + TOSA_PASS_SET + | { + ### Tests additionally passing in make_fx_tosa + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "CosineSimilarityModule_basic", + "NativeGroupNormBackwardModule_basic", + "ReduceFrobeniusNormKeepDimModule_basic", + "ReduceFrobeniusNormModule_basic", + "SliceWholeTensorModule_basic", + "TensorFloatModule_basic", + "TensorIntModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "RepeatInterleaveSelfIntModule_basic", + "TorchPrimLoopForLikeTensorArgModule_basic", + "ViewSizeDimFollowedByCollapsedOnesModule_basic", + "ViewSizeDimFollowedByExpandedOnesModule_basic", + "ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic", + "ViewSizeDimLedByCollapsedOnesModule_basic", + "ViewSizeFromOtherTensor_basic", + } +) - { + ### Test failing in make_fx_tosa but not in tosa # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", "MatmulStaticBroadcast_basic", - # failed to legalize operation 'torch.aten.max_pool2d_with_indices "MaxPool2dEmptyStrideStaticModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", "MaxPool2dStaticModule_basic", "ResNet18StaticModule_basic", - # Unimplemented operator 'aten._index_put_impl_.hacked_twin' "IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DIntNonAccumulateModule_basic", @@ -1986,18 +1927,14 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { # failed to legalize operation 'torch.aten.to.dtype' that was explicitly marked illegal "AtenEyeModuleInt2D_basic", "AtenEyeMModuleInt2D_basic", - "Conv2dBiasNoPaddingModule_basic", "Conv2dNoPaddingModule_basic", "Conv2dWithPaddingDilationStrideModule_basic", "Conv2dWithPaddingModule_basic", - "AtenInstanceNormModule_basic", - # failed to legalize operation 'torch.operator' "ElementwisePreluModule_basic", - "ElementwisePreluStaticModule_basic", - + "ElementwisePreluStaticModule_basic", # Shape Related failures "PrimListUnpackNumMismatchModule_basic", "ReshapeExpandModule_basic", @@ -2019,8 +1956,7 @@ LTC_CRASHING_SET = { } LTC_XFAIL_SET = { - "TorchPrimLoopForLikeTensorArgModule_basic" - "CollapseAllDimensionsModule_basic", + "TorchPrimLoopForLikeTensorArgModule_basic" "CollapseAllDimensionsModule_basic", "CollapseRank1DynamicModule_basic", "CollapseStaticModule_basic", "CollapsePartialDynamicModule_basic", @@ -2162,7 +2098,6 @@ LTC_XFAIL_SET = { ONNX_XFAIL_SET = { # Failure - cast error "PermuteNegativeIndexModule_basic", - # Failure - expand multiple dynamic dims "EmbeddingModuleF16_basic", "EmbeddingModuleI32_basic", @@ -2174,7 +2109,6 @@ ONNX_XFAIL_SET = { "IndexTensorMultiInputContiguousOneDimDynamic_basic", "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", "IndexTensorSelectDimModule_basic", - # Failure - incorrect numerics "AvgPool2dDivisorOverrideModule_basic", "BroadcastDynamicDimModule_basic", @@ -2211,14 +2145,12 @@ ONNX_XFAIL_SET = { "StdCorrectionLargeInputModule_basic", "TupleModule_basic", "VarCorrectionLargeInputModule_basic", - # Failure - incorrect shape "ArangeStartOutDtypeModule_basic", "ArangeStartOutViewModule_basic", "MoveDimIntNegativeIndexModule_basic", "ReduceL3NormKeepDimModule_basic", "ViewSizeFromOtherTensor_basic", - # Failure - onnx_export "AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", @@ -2619,10 +2551,8 @@ ONNX_XFAIL_SET = { "_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic", "_SoftmaxModule_basic", - # Failure - onnx_lowering: onnx.AveragePool "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", - # Failure - onnx_lowering: onnx.If "DiagonalModule_basic", "DiagonalModule_nonsquare", @@ -2633,12 +2563,10 @@ ONNX_XFAIL_SET = { "DiagonalModule_with_offset", "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", - # Failure - onnx_lowering: onnx.MaxPool "MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", "MaxPool2dWithIndicesStaticModule_basic", - # Failure - onnx_lowering: onnx.ReduceProd "ReduceProdFloatModule_basic", "ReduceProdDtypeFloatModule_basic", @@ -2646,7 +2574,6 @@ ONNX_XFAIL_SET = { "ReduceProdUnsignedIntModule_basic", "ReduceProdSignedIntModule_basic", "ReduceProdDtypeIntModule_basic", - # ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64) "RandnDtypeDeviceModule_basic", "RandnGeneratorF64Module_basic", @@ -2656,21 +2583,17 @@ ONNX_XFAIL_SET = { "BernoulliFloatModule_basic", "BernoulliPModule_basic", "BernoulliTensorModule_basic", - # Failure - onnx_lowering: onnx.ReduceProd "ReduceProdDimIntFloatModule_basic", - # Failure - onnx_lowering: onnx.Resize "UpSampleNearest2dDynamicSize_basic", "UpSampleNearest2dStaticSize_basic", - # Failure - onnx_lowering: onnx.ScatterElements "ScatterReduceFloatMaxModuleIncludeSelf", "ScatterReduceFloatMinModuleIncludeSelf", "ScatterReduceIntMaxModuleIncludeSelf", "ScatterReduceIntMinModuleIncludeSelf", "ScatterValueFloatModule_basic", - # Failure - onnx_lowering: onnx.ScatterND "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatNonAccumulateModule_basic", @@ -2696,14 +2619,11 @@ ONNX_XFAIL_SET = { "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", "IndexPutHackedTwin3DIntAccumulateModule_basic", "IndexPutHackedTwin3DIntNonAccumulateModule_basic", - # Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", - # RuntimeError: unsupported input type: Device "PrimsIotaModule_basic", - # Failure - unknown "BernoulliModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", @@ -2741,14 +2661,14 @@ if torch_version_for_comparison() >= version.parse("2.4.0.dev"): "ReduceL1NormWithDTypeModule_basic", } -if torch_version_for_comparison() < version.parse('2.3.0.dev'): +if torch_version_for_comparison() < version.parse("2.3.0.dev"): ONNX_XFAIL_SET = ONNX_XFAIL_SET | { # ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120])) "RepeatInterleaveSelfIntNoDimModule_basic", } -ONNX_CRASHING_SET = { +ONNX_CRASHING_SET = { "FakeQuantizePerTensorAffineModule_basic", "FakeQuantizePerTensorAffineDynamicShapeModule_basic", "ElementwisePreluModule_basic", diff --git a/projects/pt1/examples/ltc_backend_bert.py b/projects/pt1/examples/ltc_backend_bert.py index 048c74233..0b4e6fc7f 100644 --- a/projects/pt1/examples/ltc_backend_bert.py +++ b/projects/pt1/examples/ltc_backend_bert.py @@ -23,34 +23,43 @@ import torch._lazy from datasets import load_dataset from datasets.dataset_dict import DatasetDict from torch.utils.data import DataLoader -from transformers import BertForSequenceClassification, \ - BertConfig, BertTokenizer, AdamW, get_scheduler +from transformers import ( + BertForSequenceClassification, + BertConfig, + BertTokenizer, + AdamW, + get_scheduler, +) def tokenize_dataset(dataset: DatasetDict) -> DatasetDict: - tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") def tokenize_function(examples): - return tokenizer(examples["text"], padding="max_length", - truncation=True) + return tokenizer(examples["text"], padding="max_length", truncation=True) tokenized_datasets = dataset.map(tokenize_function, batched=True) - tokenized_datasets = tokenized_datasets.remove_columns(['text']) - tokenized_datasets = tokenized_datasets.rename_column('label', 'labels') - tokenized_datasets.set_format('torch') + tokenized_datasets = tokenized_datasets.remove_columns(["text"]) + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + tokenized_datasets.set_format("torch") return tokenized_datasets -def train(model: BertForSequenceClassification, - num_epochs: int, - num_training_steps: int, - train_dataloader: DataLoader, - device: torch.device) -> List[torch.Tensor]: +def train( + model: BertForSequenceClassification, + num_epochs: int, + num_training_steps: int, + train_dataloader: DataLoader, + device: torch.device, +) -> List[torch.Tensor]: optimizer = AdamW(model.parameters(), lr=5e-5) - lr_scheduler = get_scheduler('linear', optimizer=optimizer, - num_warmup_steps=0, - num_training_steps=num_training_steps) + lr_scheduler = get_scheduler( + "linear", + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=num_training_steps, + ) model.train() losses = [] @@ -66,14 +75,14 @@ def train(model: BertForSequenceClassification, lr_scheduler.step() optimizer.zero_grad() - if 'lazy' in str(model.device): + if "lazy" in str(model.device): print("Calling Mark Step") torch._lazy.mark_step() return losses -def main(device='lazy', full_size=False): +def main(device="lazy", full_size=False): """ Load model to specified device. Ensure that any backends have been initialized by this point. @@ -82,15 +91,14 @@ def main(device='lazy', full_size=False): """ torch.manual_seed(0) - tokenized_datasets = tokenize_dataset(load_dataset('imdb')) - small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \ - .select(range(2)) + tokenized_datasets = tokenize_dataset(load_dataset("imdb")) + small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(2)) - train_dataloader = DataLoader(small_train_dataset, shuffle=True, - batch_size=8) + train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8) if full_size: - model = BertForSequenceClassification.from_pretrained('bert-base-cased', - num_labels=2) + model = BertForSequenceClassification.from_pretrained( + "bert-base-cased", num_labels=2 + ) else: configuration = BertConfig( vocab_size=28996, @@ -98,7 +106,7 @@ def main(device='lazy', full_size=False): num_hidden_layers=1, num_attention_heads=2, intermediate_size=32, - hidden_act='gelu', + hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, max_position_embeddings=512, @@ -113,12 +121,12 @@ def main(device='lazy', full_size=False): losses = train(model, num_epochs, num_training_steps, train_dataloader, device) # Get debug information from LTC - if 'torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND' in sys.modules: + if "torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND" in sys.modules: computation = lazy_backend.get_latest_computation() if computation: print(computation.debug_string()) - print('Loss: ', losses) + print("Loss: ", losses) return model, losses @@ -136,7 +144,7 @@ if __name__ == "__main__": parser.add_argument( "-f", "--full_size", - action='store_true', + action="store_true", default=False, help="Use full sized BERT model instead of one with smaller parameterization", ) @@ -145,6 +153,7 @@ if __name__ == "__main__": if args.device in ("TS", "MLIR_EXAMPLE"): if args.device == "TS": import torch._lazy.ts_backend + torch._lazy.ts_backend.init() elif args.device == "MLIR_EXAMPLE": diff --git a/projects/pt1/examples/ltc_backend_mnist.py b/projects/pt1/examples/ltc_backend_mnist.py index bdc9edd09..ed71ddd17 100644 --- a/projects/pt1/examples/ltc_backend_mnist.py +++ b/projects/pt1/examples/ltc_backend_mnist.py @@ -13,7 +13,7 @@ import torch._lazy import torch.nn.functional as F -def main(device='lazy'): +def main(device="lazy"): """ Load model to specified device. Ensure that any backends have been initialized by this point. @@ -65,7 +65,7 @@ def main(device='lazy'): torch._lazy.mark_step() # Get debug information from LTC - if 'torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND' in sys.modules: + if "torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND" in sys.modules: computation = lazy_backend.get_latest_computation() if computation: print(computation.debug_string()) @@ -90,6 +90,7 @@ if __name__ == "__main__": if args.device in ("TS", "MLIR_EXAMPLE"): if args.device == "TS": import torch._lazy.ts_backend + torch._lazy.ts_backend.init() elif args.device == "MLIR_EXAMPLE": diff --git a/projects/pt1/examples/torchdynamo_resnet18.py b/projects/pt1/examples/torchdynamo_resnet18.py index 377c632da..76602d4ba 100644 --- a/projects/pt1/examples/torchdynamo_resnet18.py +++ b/projects/pt1/examples/torchdynamo_resnet18.py @@ -21,19 +21,18 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend def load_and_preprocess_image(url: str): headers = { - 'User-Agent': - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36' + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36" } - img = Image.open(requests.get(url, headers=headers, - stream=True).raw).convert("RGB") + img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB") # preprocessing pipeline - preprocess = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - ]) + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) img_preprocessed = preprocess(img) return torch.unsqueeze(img_preprocessed, 0) @@ -62,17 +61,23 @@ def predictions(torch_func, jit_func, img, labels): print("torch-mlir prediction") print(prediction) -image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg" + +image_url = ( + "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg" +) print("load image from " + image_url, file=sys.stderr) img = load_and_preprocess_image(image_url) labels = load_labels() + @make_simple_dynamo_backend -def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule, - example_inputs: List[torch.Tensor]): +def refbackend_torchdynamo_backend( + fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor] +): mlir_module = torchscript.compile( - fx_graph, example_inputs, output_type="linalg-on-tensors") + fx_graph, example_inputs, output_type="linalg-on-tensors" + ) backend = refbackend.RefBackendLinalgOnTensorsBackend() compiled = backend.compile(mlir_module) loaded = backend.load(compiled) @@ -85,10 +90,17 @@ def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule, else: result = tuple(torch.from_numpy(x) for x in result) return result + return compiled_callable + resnet18 = models.resnet18(pretrained=True) resnet18.train(False) dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18) -predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).detach().numpy(), img, labels) +predictions( + resnet18.forward, + lambda x: dynamo_callable(torch.from_numpy(x)).detach().numpy(), + img, + labels, +) diff --git a/projects/pt1/examples/torchscript_resnet18.py b/projects/pt1/examples/torchscript_resnet18.py index 62e5eda6c..0cc5b5dda 100644 --- a/projects/pt1/examples/torchscript_resnet18.py +++ b/projects/pt1/examples/torchscript_resnet18.py @@ -18,19 +18,18 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend def load_and_preprocess_image(url: str): headers = { - 'User-Agent': - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36' + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36" } - img = Image.open(requests.get(url, headers=headers, - stream=True).raw).convert("RGB") + img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB") # preprocessing pipeline - preprocess = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - ]) + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) img_preprocessed = preprocess(img) return torch.unsqueeze(img_preprocessed, 0) @@ -59,7 +58,10 @@ def predictions(torch_func, jit_func, img, labels): print("torch-mlir prediction") print(prediction) -image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg" + +image_url = ( + "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg" +) print("load image from " + image_url, file=sys.stderr) img = load_and_preprocess_image(image_url) @@ -67,7 +69,9 @@ labels = load_labels() resnet18 = models.resnet18(pretrained=True) resnet18.train(False) -module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors") +module = torchscript.compile( + resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors" +) backend = refbackend.RefBackendLinalgOnTensorsBackend() compiled = backend.compile(module) jit_module = backend.load(compiled) diff --git a/projects/pt1/examples/torchscript_resnet18_all_output_types.py b/projects/pt1/examples/torchscript_resnet18_all_output_types.py index 70a920550..720db2cb6 100644 --- a/projects/pt1/examples/torchscript_resnet18_all_output_types.py +++ b/projects/pt1/examples/torchscript_resnet18_all_output_types.py @@ -13,8 +13,12 @@ resnet18.eval() module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10)) -module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors") -print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10)) +module = torchscript.compile( + resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors" +) +print( + "LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10) +) # TODO: Debug why this is so slow. module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa") print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10)) diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py b/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py index e42828ed7..db281fc8e 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py @@ -4,10 +4,12 @@ from torch_mlir import torchscript model = models.resnet18(pretrained=True) model.eval() -data = torch.randn(2,3,200,200) +data = torch.randn(2, 3, 200, 200) out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir" -module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False) +module = torchscript.compile( + model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False +) with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: outf.write(str(module)) diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py index c68daf12d..af2af2de3 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py @@ -7,17 +7,22 @@ from transformers import BertForMaskedLM class BertTinyWrapper(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False) - + self.bert = BertForMaskedLM.from_pretrained( + "prajjwal1/bert-tiny", return_dict=False + ) + def forward(self, data): return self.bert(data)[0] + model = BertTinyWrapper() model.eval() data = torch.randint(30522, (2, 128)) out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir" -module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True) +module = torchscript.compile( + model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True +) with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: outf.write(str(module)) diff --git a/projects/pt1/python/test/annotations-sugar.py b/projects/pt1/python/test/annotations-sugar.py index e540e84b9..3e85d1e47 100644 --- a/projects/pt1/python/test/annotations-sugar.py +++ b/projects/pt1/python/test/annotations-sugar.py @@ -11,18 +11,23 @@ from torch_mlir_e2e_test.annotations import annotate_args, export from torch_mlir.jit_ir_importer import ClassAnnotator from torch_mlir.jit_ir_importer.torchscript_annotations import extract_annotations + class MmModule(torch.nn.Module): def __init__(self): super().__init__() + @export - @annotate_args([ - None, - ([3, 4], torch.float32, False), - ([4, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.float32, False), + ([4, 5], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.mm(lhs, rhs) + module = MmModule() annotator = ClassAnnotator() extract_annotations(module, torch.jit.script(module), annotator) diff --git a/projects/pt1/python/test/compile_api/already_scripted.py b/projects/pt1/python/test/compile_api/already_scripted.py index 7d9720727..f63b08b6a 100644 --- a/projects/pt1/python/test/compile_api/already_scripted.py +++ b/projects/pt1/python/test/compile_api/already_scripted.py @@ -26,6 +26,8 @@ print(torchscript.compile(scripted, example_args)) scripted = torch.jit.script(BasicModule()) try: # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. - torchscript.compile(scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3))) + torchscript.compile( + scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)) + ) except Exception as e: print(e) diff --git a/projects/pt1/python/test/compile_api/already_traced.py b/projects/pt1/python/test/compile_api/already_traced.py index 32f7b5653..ed77e09c0 100644 --- a/projects/pt1/python/test/compile_api/already_traced.py +++ b/projects/pt1/python/test/compile_api/already_traced.py @@ -8,10 +8,12 @@ import torch from torch_mlir import torchscript + class BasicModule(torch.nn.Module): def forward(self, x): return torch.ops.aten.sin(x) + example_arg = torch.ones(2, 3) example_args = torchscript.ExampleArgs.get(example_arg) @@ -23,6 +25,8 @@ print(torchscript.compile(traced, example_args)) traced = torch.jit.trace(BasicModule(), example_arg) try: # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. - torchscript.compile(traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg)) + torchscript.compile( + traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg) + ) except Exception as e: print(e) diff --git a/projects/pt1/python/test/compile_api/backend_legal_ops.py b/projects/pt1/python/test/compile_api/backend_legal_ops.py index 64ebf7a52..a1e9b9823 100644 --- a/projects/pt1/python/test/compile_api/backend_legal_ops.py +++ b/projects/pt1/python/test/compile_api/backend_legal_ops.py @@ -9,15 +9,24 @@ import torch from torch_mlir import torchscript + class AddmmModule(torch.nn.Module): def __init__(self): super().__init__() + def forward(self, x, y, z): return torch.ops.aten.addmm(x, y, z) + example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)] -print(torchscript.compile(AddmmModule(), example_args, - output_type="torch", backend_legal_ops=["aten.addmm"])) +print( + torchscript.compile( + AddmmModule(), + example_args, + output_type="torch", + backend_legal_ops=["aten.addmm"], + ) +) # CHECK-LABEL: @forward # CHECK: torch.aten.addmm diff --git a/projects/pt1/python/test/compile_api/basic.py b/projects/pt1/python/test/compile_api/basic.py index 0c516b620..075b022ad 100644 --- a/projects/pt1/python/test/compile_api/basic.py +++ b/projects/pt1/python/test/compile_api/basic.py @@ -9,12 +9,15 @@ import torch from torch_mlir import torchscript + class TanhModule(torch.nn.Module): def __init__(self): super().__init__() + def forward(self, x): return torch.ops.aten.tanh(x) + tanh_example_input = torch.ones(2, 3) # Simplest case: One example argument. @@ -35,16 +38,23 @@ print(torchscript.compile(TanhModule(), placeholder)) # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32> # Basic smoke test for the raw output type. -print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW)) +print( + torchscript.compile( + TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW + ) +) # CHECK: torch.nn_module { # CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule"> + class MmModule(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, lhs, rhs ): + + def forward(self, lhs, rhs): return torch.ops.aten.mm(lhs, rhs) + # N > 1 inputs. mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)] print(torchscript.compile(MmModule(), mm_example_inputs)) @@ -52,7 +62,10 @@ print(torchscript.compile(MmModule(), mm_example_inputs)) # CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32> # Mixes Tensor's and TensorPlaceholder's. -mm_dynamic_inputs = [mm_example_inputs[0], torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])] +mm_dynamic_inputs = [ + mm_example_inputs[0], + torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1]), +] print(torchscript.compile(MmModule(), mm_dynamic_inputs)) # CHECK-LABEL: @forward # CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32> diff --git a/projects/pt1/python/test/compile_api/make_fx.py b/projects/pt1/python/test/compile_api/make_fx.py index ec859d86e..b29ee21cc 100644 --- a/projects/pt1/python/test/compile_api/make_fx.py +++ b/projects/pt1/python/test/compile_api/make_fx.py @@ -10,13 +10,21 @@ import torch from torch_mlir import torchscript + def simple(x): return x * x -example_input = torch.randn(1,) -graph = functorch.make_fx(simple)(torch.randn(1,)) + +example_input = torch.randn( + 1, +) +graph = functorch.make_fx(simple)( + torch.randn( + 1, + ) +) # Simplest case: One example argument. print(torchscript.compile(graph, example_input)) # CHECK-LABEL: @forward -# CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32> \ No newline at end of file +# CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32> diff --git a/projects/pt1/python/test/compile_api/output_type_spec.py b/projects/pt1/python/test/compile_api/output_type_spec.py index 92ed1e425..1767e23dd 100644 --- a/projects/pt1/python/test/compile_api/output_type_spec.py +++ b/projects/pt1/python/test/compile_api/output_type_spec.py @@ -9,15 +9,22 @@ import torch from torch_mlir import torchscript + class TanhModule(torch.nn.Module): def __init__(self): super().__init__() + def forward(self, x): return torch.ops.aten.tanh(x) + tanh_example_input = torch.ones(2, 3) -print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH)) +print( + torchscript.compile( + TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH + ) +) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch")) diff --git a/projects/pt1/python/test/compile_api/tracing.py b/projects/pt1/python/test/compile_api/tracing.py index bbf652f07..fff6a0e62 100644 --- a/projects/pt1/python/test/compile_api/tracing.py +++ b/projects/pt1/python/test/compile_api/tracing.py @@ -14,6 +14,7 @@ class TanhModule(torch.nn.Module): def forward(self, x): return torch.ops.aten.tanh(x) + tanh_example_input = torch.ones(2, 3) # Simplest case: One example argument. @@ -32,10 +33,12 @@ print(torchscript.compile(TanhModule(), [tanh_example_input], use_tracing=True)) # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # TensorPlaceholder support. -placeholder = torchscript.TensorPlaceholder.like( - tanh_example_input, dynamic_axes=[1]) -print(torchscript.compile(TanhModule(), [placeholder], - use_tracing=True, ignore_traced_shapes=True)) +placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1]) +print( + torchscript.compile( + TanhModule(), [placeholder], use_tracing=True, ignore_traced_shapes=True + ) +) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32> @@ -55,18 +58,18 @@ except Exception as e: class DictModule(torch.nn.Module): def forward(self, x): - return x['a'] * 2.0 + return x["a"] * 2.0 try: # CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}' - torchscript.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True) + torchscript.compile(DictModule(), {"a": torch.tensor(3.0)}, use_tracing=True) except Exception as e: print(e) try: # CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}' - torchscript.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True) + torchscript.compile(DictModule(), [{"a": torch.tensor(3.0)}], use_tracing=True) except Exception as e: print(e) diff --git a/projects/pt1/python/test/debug/lockstep_basic.py b/projects/pt1/python/test/debug/lockstep_basic.py index 560ed965e..22d412bd4 100644 --- a/projects/pt1/python/test/debug/lockstep_basic.py +++ b/projects/pt1/python/test/debug/lockstep_basic.py @@ -15,8 +15,9 @@ from torch_mlir_e2e_test.debug.lockstep import make_lockstep_debug_backend @make_simple_dynamo_backend @make_lockstep_debug_backend() -def miscompile_div_as_mul_backend(gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor]): +def miscompile_div_as_mul_backend( + gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor] +): # Copy `gm` and rewrite `div` to `mul`. new_g = torch.fx.Graph() new_g.output(new_g.graph_copy(gm.graph, {})) @@ -41,7 +42,7 @@ def f(x, y): return a, b, c -args = (torch.tensor([1., 2., 3.]), torch.tensor([4., 5., 6.])) +args = (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])) try: print(f(*args)) except AssertionError as e: diff --git a/projects/pt1/python/test/dynamo_fx_importer/basic.py b/projects/pt1/python/test/dynamo_fx_importer/basic.py index cea2f639f..f7957527d 100644 --- a/projects/pt1/python/test/dynamo_fx_importer/basic.py +++ b/projects/pt1/python/test/dynamo_fx_importer/basic.py @@ -11,7 +11,11 @@ import torch import torch.fx import torch._dynamo as dynamo from torch._dynamo.backends.common import aot_autograd -from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_compilation_context, set_model_name +from torch._functorch.aot_autograd import ( + make_boxed_compiler, + get_aot_compilation_context, + set_model_name, +) from torch_mlir.compiler_utils import TorchMlirCompilerError from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func @@ -19,8 +23,9 @@ from torch_mlir_e2e_test.configs.torchdynamo import jit @make_boxed_compiler -def my_aot_autograd_backend(gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor]): +def my_aot_autograd_backend( + gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor] +): print(gm.graph) *_, model_name, nth_graph = get_aot_compilation_context() mlir_module = import_fx_graph_as_func(gm.graph, model_name) @@ -59,9 +64,7 @@ basic(torch.randn(3, 4)) # CHECK: return %[[RANDN]] : !torch.vtensor<[3,4],f16> @dynamo.optimize(my_backend) def literals_list_device_int_none_dtype(): - return torch.ops.aten.randn([3, 4], - device=torch.device("cpu"), - dtype=torch.float16) + return torch.ops.aten.randn([3, 4], device=torch.device("cpu"), dtype=torch.float16) set_model_name("literals_list_device_int_none_dtype") diff --git a/projects/pt1/python/test/lit.cfg.py b/projects/pt1/python/test/lit.cfg.py index f0423c46f..0e6d132fa 100644 --- a/projects/pt1/python/test/lit.cfg.py +++ b/projects/pt1/python/test/lit.cfg.py @@ -19,65 +19,74 @@ from lit.llvm.subst import FindTool # Configuration file for the 'lit' test runner. # name: The name of this test suite. -config.name = 'TORCH_MLIR_PYTHON' +config.name = "TORCH_MLIR_PYTHON" config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) -if 'TEST_SRC_PATH' in os.environ: - config.environment['TEST_SRC_PATH'] = os.environ['TEST_SRC_PATH'] +if "TEST_SRC_PATH" in os.environ: + config.environment["TEST_SRC_PATH"] = os.environ["TEST_SRC_PATH"] # path to our python operation library -config.environment['TEST_BUILD_PATH'] = os.path.join(config.torch_mlir_obj_root) +config.environment["TEST_BUILD_PATH"] = os.path.join(config.torch_mlir_obj_root) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.py'] +config.suffixes = [".py"] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. -config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test') +config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test") # On Windows the path to python could contains spaces in which case it needs to # be provided in quotes. This is the equivalent of how %python is setup in # llvm/utils/lit/lit/llvm/config.py. if "Windows" in config.host_os: - config.python_executable = '"%s"' % (config.python_executable) + config.python_executable = '"%s"' % (config.python_executable) -config.substitutions.append(('%PATH%', config.environment['PATH'])) -config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) -config.substitutions.append(('%PYTHON', config.python_executable)) +config.substitutions.append(("%PATH%", config.environment["PATH"])) +config.substitutions.append(("%shlibext", config.llvm_shlib_ext)) +config.substitutions.append(("%PYTHON", config.python_executable)) -llvm_config.with_system_environment( - ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) +llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"]) llvm_config.use_default_substitutions() # excludes: A list of directories to exclude from the testsuite. The 'Inputs' # subdirectories contain auxiliary inputs for various tests in their parent # directories. -config.excludes = ['lit.cfg.py', 'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt'] +config.excludes = [ + "lit.cfg.py", + "Inputs", + "Examples", + "CMakeLists.txt", + "README.txt", + "LICENSE.txt", +] if not bool(int(os.environ.get("TORCH_MLIR_ENABLE_LTC", 0))): - config.excludes.append("lazy_backend") + config.excludes.append("lazy_backend") # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. -config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test') -config.torch_mlir_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin') +config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test") +config.torch_mlir_tools_dir = os.path.join(config.torch_mlir_obj_root, "bin") # Tweak the PATH to include the tools dir. -llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) -llvm_config.with_environment('PYTHONPATH', [ - os.path.join(config.torch_mlir_python_packages_dir, 'torch_mlir'), +llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True) +llvm_config.with_environment( + "PYTHONPATH", + [ + os.path.join(config.torch_mlir_python_packages_dir, "torch_mlir"), ], - append_path=True) + append_path=True, +) tool_dirs = [config.torch_mlir_tools_dir, config.llvm_tools_dir] tools = [ - 'torch-mlir-opt', + "torch-mlir-opt", ] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/projects/pt1/python/test/torchscript_e2e_test/basic.py b/projects/pt1/python/test/torchscript_e2e_test/basic.py index fa3f6f297..83a73900d 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/basic.py +++ b/projects/pt1/python/test/torchscript_e2e_test/basic.py @@ -40,5 +40,5 @@ def main(): report_results(results, set(), verbose=True) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py b/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py index 9b9091452..4aac65df8 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py +++ b/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py @@ -43,5 +43,5 @@ def main(): report_results(results, set(), verbose=True, config="myconfig") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/projects/pt1/python/test/torchscript_e2e_test/error_reports.py b/projects/pt1/python/test/torchscript_e2e_test/error_reports.py index f33212859..f6c949c3d 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/error_reports.py +++ b/projects/pt1/python/test/torchscript_e2e_test/error_reports.py @@ -122,7 +122,7 @@ class ErroneousModule(torch.nn.Module): @torch.jit.export def test_tensor_value_mismatch(self): if torch.jit.is_scripting(): - return torch.tensor([1., 2., 3.]) + return torch.tensor([1.0, 2.0, 3.0]) else: return torch.tensor([1.5, 2.5, 3.5]) @@ -132,9 +132,9 @@ class ErroneousModule(torch.nn.Module): @torch.jit.export def test_tensor_shape_mismatch(self): if torch.jit.is_scripting(): - return torch.tensor([1., 2.]) + return torch.tensor([1.0, 2.0]) else: - return torch.tensor([1., 2., 3.]) + return torch.tensor([1.0, 2.0, 3.0]) @register_test_case(module_factory=lambda: ErroneousModule()) @@ -157,5 +157,5 @@ def main(): report_results(results, set(), verbose=True) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py b/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py index a1c8c5adf..8991229f0 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py +++ b/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py @@ -51,5 +51,5 @@ def main(): report_results(results, set()) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py b/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py index 3581c1b6d..b7609bc2c 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py +++ b/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py @@ -39,5 +39,5 @@ def main(): report_results(results, set(), verbose=True) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/projects/pt1/python/test/torchscript_e2e_test/submodule.py b/projects/pt1/python/test/torchscript_e2e_test/submodule.py index c88ad53b3..ae7ee878b 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/submodule.py +++ b/projects/pt1/python/test/torchscript_e2e_test/submodule.py @@ -12,6 +12,7 @@ from torch_mlir_e2e_test.reporting import report_results from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY from torch_mlir_e2e_test.configs import TorchScriptTestConfig + class Submodule2(torch.nn.Module): def __init__(self): super().__init__() @@ -19,6 +20,7 @@ class Submodule2(torch.nn.Module): def forward(self, lhs, rhs): return torch.mm(lhs, rhs) + class Submodule(torch.nn.Module): def __init__(self): super().__init__() @@ -43,5 +45,5 @@ def main(): report_results(results, set()) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py b/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py index be87e775d..fcea14dc1 100644 --- a/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py +++ b/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. import pdb + # This file implements a pure-Python importer from a restricted subset of # FX IR into MLIR. # @@ -77,8 +78,12 @@ def _verify_fx_graph_conforms_to_subset(g: torch.fx.Graph): if len(node.args) != len(node.target._schema.arguments): assert len(node.args) < len(node.target._schema.arguments) for i, argument in enumerate( - node.target._schema.arguments[len(node.args):]): - if not argument.has_default_value() and argument.name not in node.kwargs: + node.target._schema.arguments[len(node.args) :] + ): + if ( + not argument.has_default_value() + and argument.name not in node.kwargs + ): raise Exception( f"Unsupported: missing default value for argument {i} in schema for {node.target}" ) @@ -152,12 +157,12 @@ def _convert_dtype_to_mlir_type(dtype: torch.dtype) -> str: if dtype == torch.complex128: return "complex" - raise Exception(f"Unsupported dtype: {dtype}") def _import_fake_tensor_as_mlir_type( - fake_tensor: torch._subclasses.FakeTensor) -> ir.Type: + fake_tensor: torch._subclasses.FakeTensor, +) -> ir.Type: # TODO: Find story for how to get dynamically shaped tensors here. shape = ",".join(str(d) for d in fake_tensor.shape) dtype = _convert_dtype_to_mlir_type(fake_tensor.dtype) @@ -178,7 +183,8 @@ def _extract_function_type_from_graph(g: torch.fx.Graph) -> ir.FunctionType: if node.op == "output": # TODO(DNS): Test this or add verifier that it can't happen. result_types = torch.fx.map_arg( - node.args[0], lambda n: _mlir_types_for_node(n)[0]) + node.args[0], lambda n: _mlir_types_for_node(n)[0] + ) # Note: We import directly to the backend contract -- multiple results # are modeled with func.func native multiple results rather than as a # singleton value / tuple. @@ -191,64 +197,40 @@ def _extract_function_type_from_graph(g: torch.fx.Graph) -> ir.FunctionType: DTYPE_TO_INT = { # TODO(DNS): Fill in from AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS - torch.uint8: - 0, - torch.int8: - 1, - torch.int16: - 2, - torch.int32: - 3, - torch.int64: - 4, - torch.float16: - 5, - torch.float32: - 6, - torch.float64: - 7, + torch.uint8: 0, + torch.int8: 1, + torch.int16: 2, + torch.int32: 3, + torch.int64: 4, + torch.float16: 5, + torch.float32: 6, + torch.float64: 7, # torch.complex_half 8 - torch.complex64: - 9, - torch.complex128: - 10, - torch.bool: - 11, - torch.qint8: - 12, - torch.quint8: - 13, + torch.complex64: 9, + torch.complex128: 10, + torch.bool: 11, + torch.qint8: 12, + torch.quint8: 13, # torch.qint32 14 - torch.bfloat16: - 15, + torch.bfloat16: 15, } MEMORY_FORMAT_TO_INT = { # https://github.com/pytorch/pytorch/c10/core/MemoryFormat.h#L28 - torch.contiguous_format: - 0, - torch.preserve_format: - 1, - torch.channels_last: - 2, - torch.channels_last_3d: - 3, + torch.contiguous_format: 0, + torch.preserve_format: 1, + torch.channels_last: 2, + torch.channels_last_3d: 3, } LAYOUT_TO_INT = { # https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_layouts.cpp - torch.strided: - 0, - torch.sparse_coo: - 1, - torch.sparse_csr: - 2, - torch.sparse_csc: - 3, - torch.sparse_bsr: - 4, - torch.sparse_bsc: - 5, + torch.strided: 0, + torch.sparse_coo: 1, + torch.sparse_csr: 2, + torch.sparse_csc: 3, + torch.sparse_bsr: 4, + torch.sparse_bsc: 5, } @@ -264,7 +246,6 @@ def _mlir_location_for_node(node: torch.fx.Node) -> ir.Location: class _FXGraphImporter: - def __init__(self, g: torch.fx.Graph, func_name: str): self._g = g self._func_name = func_name @@ -277,7 +258,8 @@ class _FXGraphImporter: self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {} self._module = ir.Module.create(ir.Location.unknown()) self._module.operation.attributes[ - "torch.debug_module_name"] = ir.StringAttr.get(func_name) + "torch.debug_module_name" + ] = ir.StringAttr.get(func_name) function_type = _extract_function_type_from_graph(g) func = func_dialect.FuncOp( func_name, @@ -285,8 +267,7 @@ class _FXGraphImporter: loc=ir.Location.unknown(), # TODO: Can we do better? ip=ir.InsertionPoint(self._module.body), ) - self._body_block = ir.Block.create_at_start(func.body, - function_type.inputs) + self._body_block = ir.Block.create_at_start(func.body, function_type.inputs) def import_graph(self) -> ir.Module: with ir.InsertionPoint(self._body_block): @@ -294,14 +275,15 @@ class _FXGraphImporter: for node in self._g.nodes: with _mlir_location_for_node(node): if node.op == "placeholder": - self._env[( - node, 0 - )] = self._body_block.arguments[num_placeholders_seen] + self._env[(node, 0)] = self._body_block.arguments[ + num_placeholders_seen + ] num_placeholders_seen += 1 if node.op == "call_function": if node.target is operator.getitem: - self._env[(node, 0)] = self._env[(node.args[0], - node.args[1])] + self._env[(node, 0)] = self._env[ + (node.args[0], node.args[1]) + ] else: self._import_op_overload_call(node) if node.op == "output": @@ -309,9 +291,7 @@ class _FXGraphImporter: # a tuple of return values (without the single-element special # case) # DNS: Test or verify no literals as results. - operands = [ - self._import_argument(arg) for arg in node.args[0] - ] + operands = [self._import_argument(arg) for arg in node.args[0]] func_dialect.ReturnOp(operands) return self._module @@ -328,7 +308,8 @@ class _FXGraphImporter: # DNS: Unregistered ops assert ir.Context.current.is_registered_operation( - mlir_op_name), f"Unregistered operation: {mlir_op_name}" + mlir_op_name + ), f"Unregistered operation: {mlir_op_name}" # Construct the Operation. result_types = _mlir_types_for_node(node) @@ -352,9 +333,9 @@ class _FXGraphImporter: for i, value in enumerate(operation.results): self._env[(node, i)] = value - def _import_argument(self, - arg: torch.fx.node.Argument, - expected_type_for_literal=None) -> ir.Value: + def _import_argument( + self, arg: torch.fx.node.Argument, expected_type_for_literal=None + ) -> ir.Value: """Import an FX `Argument`, which is analogous to an MLIR `Value`. Args: @@ -371,22 +352,21 @@ class _FXGraphImporter: assert expected_type_for_literal is not None return self._import_literal(arg, expected_type_for_literal) - def _import_literal(self, arg: torch.fx.node.Argument, - expected_type) -> ir.Value: + def _import_literal(self, arg: torch.fx.node.Argument, expected_type) -> ir.Value: if arg is None: return torch_dialect.ConstantNoneOp().result if isinstance(expected_type, torch.OptionalType): return self._import_argument(arg, expected_type.getElementType()) if isinstance(arg, bool): return torch_dialect.ConstantBoolOp( - ir.IntegerAttr.get(ir.IntegerType.get_signless(1), arg)).result + ir.IntegerAttr.get(ir.IntegerType.get_signless(1), arg) + ).result if isinstance(arg, int): return torch_dialect.ConstantIntOp( - ir.IntegerAttr.get(ir.IntegerType.get_signless(64), - arg)).result + ir.IntegerAttr.get(ir.IntegerType.get_signless(64), arg) + ).result if isinstance(arg, float): - return torch_dialect.ConstantFloatOp( - ir.FloatAttr.get_f64(arg)).result + return torch_dialect.ConstantFloatOp(ir.FloatAttr.get_f64(arg)).result if isinstance(arg, str): return torch_dialect.ConstantStrOp(ir.StringAttr.get(arg)).result if isinstance(arg, torch.dtype): @@ -394,12 +374,10 @@ class _FXGraphImporter: return self._import_argument(DTYPE_TO_INT[arg], expected_type) if isinstance(arg, torch.device): # TODO(DNS): Device index? arg.index - return torch_dialect.ConstantDeviceOp(ir.StringAttr.get( - arg.type)).result + return torch_dialect.ConstantDeviceOp(ir.StringAttr.get(arg.type)).result if isinstance(arg, torch.memory_format): assert isinstance(expected_type, torch.IntType) - return self._import_argument(MEMORY_FORMAT_TO_INT[arg], - expected_type) + return self._import_argument(MEMORY_FORMAT_TO_INT[arg], expected_type) if isinstance(arg, torch.layout): assert isinstance(expected_type, torch.IntType) return self._import_argument(LAYOUT_TO_INT[arg], expected_type) @@ -409,14 +387,14 @@ class _FXGraphImporter: if isinstance(element_type, torch.TensorType): assert all( torch.fx.node.map_aggregate( - arg, lambda a: _is_valid_meta_val(a.meta.get("val")))) + arg, lambda a: _is_valid_meta_val(a.meta.get("val")) + ) + ) els = [self._env[e, 0] for e in arg] else: element_type = _torch_type_to_mlir_type(element_type) - els = [ - self._import_argument(e, element_type) for e in arg - ] + els = [self._import_argument(e, element_type) for e in arg] # import pydevd_pycharm # pydevd_pycharm.settrace('localhost', port=8888, stdoutToServer=True, stderrToServer=True) diff --git a/projects/pt1/python/torch_mlir/_torch_mlir_custom_op_example/__init__.py b/projects/pt1/python/torch_mlir/_torch_mlir_custom_op_example/__init__.py index 6a0eedf1d..bc14a8c01 100644 --- a/projects/pt1/python/torch_mlir/_torch_mlir_custom_op_example/__init__.py +++ b/projects/pt1/python/torch_mlir/_torch_mlir_custom_op_example/__init__.py @@ -3,6 +3,5 @@ import torch # Register _torch_mlir_custom_op_example.identity as a side-effect of importing. current_dir = os.path.dirname(os.path.abspath(__file__)) -lib = os.path.join(*[current_dir, 'libtorch_mlir_custom_op_example.so']) +lib = os.path.join(*[current_dir, "libtorch_mlir_custom_op_example.so"]) torch.ops.load_library(lib) - diff --git a/projects/pt1/python/torch_mlir/_version.py b/projects/pt1/python/torch_mlir/_version.py index 7ebc01dd9..fd2c3157e 100644 --- a/projects/pt1/python/torch_mlir/_version.py +++ b/projects/pt1/python/torch_mlir/_version.py @@ -6,6 +6,7 @@ from packaging import version import torch + def torch_version_for_comparison(): # Ignore +cpu, +cu117m, etc. in comparisons return version.parse(torch.__version__.split("+", 1)[0]) diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py index 34c9e6190..5265fed49 100755 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py @@ -4,20 +4,20 @@ import sys import os -if __name__ == '__main__': +if __name__ == "__main__": path = sys.argv[1] # dummy script path file_name = sys.argv[2] # dummy script - contents = ''' -# This file was automatically generated due to LTC being disabled in build. - + contents = """ +# This file was automatically generated due to LTC being disabled in build. + class LazyTensorCoreTestConfig: def __init__(self): assert False, "LTC is not enabled. Check the value of `TORCH_MLIR_ENABLE_LTC`" - ''' + """ if not os.path.exists(path): os.makedirs(path) - with open(os.path.join(path, file_name + '.py'), 'w') as file: + with open(os.path.join(path, file_name + ".py"), "w") as file: file.write(contents) diff --git a/projects/pt1/python/torch_mlir/dynamo.py b/projects/pt1/python/torch_mlir/dynamo.py index 1b78b2a06..2c339be98 100644 --- a/projects/pt1/python/torch_mlir/dynamo.py +++ b/projects/pt1/python/torch_mlir/dynamo.py @@ -13,6 +13,7 @@ from torch._dynamo.backends.common import aot_autograd import functorch import warnings + # https://github.com/pytorch/pytorch/issues/89064 warnings.filterwarnings("ignore", module="torch.jit._check") @@ -91,8 +92,7 @@ def _adjust_calling_convention(gm: torch.fx.GraphModule) -> bool: did_convert_list_to_tuple = False for node in gm.graph.nodes: if node.op == "output": - assert len(node.args) == 1, \ - "Output node must have a single argument" + assert len(node.args) == 1, "Output node must have a single argument" node_arg = node.args[0] if isinstance(node_arg, tuple): if len(node_arg) == 1: @@ -106,7 +106,7 @@ def _adjust_calling_convention(gm: torch.fx.GraphModule) -> bool: did_convert_list_to_tuple = True break else: - node.args= (tuple(node_arg),) + node.args = (tuple(node_arg),) did_convert_list_to_tuple = True break @@ -129,10 +129,12 @@ def make_simple_dynamo_backend(user_backend): Returns: A function with the signature used by TorchDynamo backends. """ - def wrapper_backend(gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor]): - did_unwrap_single_element, did_convert_list_to_tuple = \ - _adjust_calling_convention(gm) + + def wrapper_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + ( + did_unwrap_single_element, + did_convert_list_to_tuple, + ) = _adjust_calling_convention(gm) strip_overloads(gm) user_callable = user_backend(gm, example_inputs) @@ -147,6 +149,9 @@ def make_simple_dynamo_backend(user_backend): if did_convert_list_to_tuple: result = list(result) return result + return dynamo_callable - return aot_autograd(fw_compiler=wrapper_backend, - decompositions=_get_decomposition_table) + + return aot_autograd( + fw_compiler=wrapper_backend, decompositions=_get_decomposition_table + ) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/__init__.py b/projects/pt1/python/torch_mlir/jit_ir_importer/__init__.py index b5a49561a..00574548c 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/__init__.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/__init__.py @@ -18,7 +18,7 @@ from .._mlir_libs._jit_ir_importer import * from ..dialects import torch as _unused_torch_dialect __all__ = [ - "debug_trace_to_stderr", - "ModuleBuilder", - "ClassAnnotator", + "debug_trace_to_stderr", + "ModuleBuilder", + "ClassAnnotator", ] diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/library_generator.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/library_generator.py index 6cd19643a..5285aa0d0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/library_generator.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/library_generator.py @@ -15,24 +15,31 @@ from torch_mlir.passmanager import PassManager from .registry import Registry + def all_integer_dtypes() -> List[int]: return [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64] + def is_integer_dtype(dtype: int) -> bool: return dtype in all_integer_dtypes() + def all_complex_dtypes() -> List[int]: return [torch.complex64, torch.complex128] + def is_complex_dtype(dtype: int) -> bool: return dtype in all_complex_dtypes() + def all_float_dtypes() -> List[int]: return [torch.float16, torch.bfloat16, torch.float32, torch.float64] + def is_float_dtype(dtype: int) -> bool: return dtype in all_float_dtypes() + def get_priority_of_dtype(dtype: int) -> int: # If a loop is used to iterate over a list of sorted dtypes, TorchScript # produces a loop with INT64_MAX max trip count, which causes problems @@ -64,6 +71,7 @@ def get_priority_of_dtype(dtype: int) -> int: return 11 assert False, "Cannot determine priority of dtype" + def get_dtype_of_scalar(scalar: Union[int, float, complex]) -> int: # This is hacky. `NumToTensor` is the only PyTorch op for scalars # that when `jit.script`ed converts a float scalar to a tensor @@ -83,6 +91,7 @@ def get_dtype_of_scalar(scalar: Union[int, float, complex]) -> int: # op. return torch.ops.prim.NumToTensor(scalar).dtype + # When we import into torch-mlir, only the calls to # `__torch_mlir_internal_promote_dtypes` are used to generate the # `torch.promote_dtypes` ops. Therefore, to avoid generating extra @@ -97,23 +106,29 @@ def _get_scalar_with_dtype(dtype: torch.dtype) -> Union[int, float]: else: raise ValueError(f"Unhandled dtype: {dtype}") + @torch.jit.ignore -def _promote_scalar_tensor(scalar_dtype: torch.dtype, tensor_rank: int, - tensor_dtype: torch.dtype) -> torch.dtype: +def _promote_scalar_tensor( + scalar_dtype: torch.dtype, tensor_rank: int, tensor_dtype: torch.dtype +) -> torch.dtype: scalar = _get_scalar_with_dtype(scalar_dtype) tensor = torch.rand([1] * tensor_rank).to(tensor_dtype) return torch.result_type(scalar, tensor) + @torch.jit.ignore -def _promote_tensor_tensor(lhs_rank: int, lhs_dtype: torch.dtype, - rhs_rank: int, rhs_dtype: torch.dtype) -> torch.dtype: +def _promote_tensor_tensor( + lhs_rank: int, lhs_dtype: torch.dtype, rhs_rank: int, rhs_dtype: torch.dtype +) -> torch.dtype: lhs_tensor = torch.rand([1] * lhs_rank).to(lhs_dtype) rhs_tensor = torch.rand([1] * rhs_rank).to(rhs_dtype) return torch.result_type(lhs_tensor, rhs_tensor) + @torch.jit.ignore -def _promote_scalar_scalar(lhs_dtype: torch.dtype, - rhs_dtype: torch.dtype) -> torch.dtype: +def _promote_scalar_scalar( + lhs_dtype: torch.dtype, rhs_dtype: torch.dtype +) -> torch.dtype: # When `torch.result_type` is used on two scalars, the result # dtype is the dtype one would expect for an op with signature # (Scalar, Scalar) -> (Tensor). However, once a module gets @@ -122,15 +137,17 @@ def _promote_scalar_scalar(lhs_dtype: torch.dtype, # dtype, we use the tensor-tensor promotion rules. return _promote_tensor_tensor(0, lhs_dtype, 0, rhs_dtype) -def promote_dtypes(ranks: List[Optional[int]], - dtypes: List[torch.dtype]) -> torch.dtype: - """Apply PyTorch dtype promotion rules and return the result type. - """ + +def promote_dtypes( + ranks: List[Optional[int]], dtypes: List[torch.dtype] +) -> torch.dtype: + """Apply PyTorch dtype promotion rules and return the result type.""" return __torch_mlir_internal_promote_dtypes(ranks, dtypes) -def __torch_mlir_internal_promote_dtypes(ranks: List[Optional[int]], - dtypes: List[torch.dtype] - ) -> torch.dtype: + +def __torch_mlir_internal_promote_dtypes( + ranks: List[Optional[int]], dtypes: List[torch.dtype] +) -> torch.dtype: """Apply PyTorch dtype promotion rules and return the result type. This function serves two purposes: @@ -145,18 +162,18 @@ def __torch_mlir_internal_promote_dtypes(ranks: List[Optional[int]], if lhs_optional_rank is None and rhs_optional_rank is None: lhs_dtype = _promote_scalar_scalar(lhs_dtype, rhs_dtype) elif lhs_optional_rank is None and rhs_optional_rank is not None: - lhs_dtype = _promote_scalar_tensor( - lhs_dtype, rhs_optional_rank, rhs_dtype) + lhs_dtype = _promote_scalar_tensor(lhs_dtype, rhs_optional_rank, rhs_dtype) lhs_optional_rank = rhs_optional_rank elif lhs_optional_rank is not None and rhs_optional_rank is None: - lhs_dtype = _promote_scalar_tensor( - rhs_dtype, lhs_optional_rank, lhs_dtype) + lhs_dtype = _promote_scalar_tensor(rhs_dtype, lhs_optional_rank, lhs_dtype) elif lhs_optional_rank is not None and rhs_optional_rank is not None: lhs_dtype = _promote_tensor_tensor( - lhs_optional_rank, lhs_dtype, rhs_optional_rank, rhs_dtype) + lhs_optional_rank, lhs_dtype, rhs_optional_rank, rhs_dtype + ) lhs_optional_rank = max(lhs_optional_rank, rhs_optional_rank) return lhs_dtype + def not_present_in_registry(f): """Decorator for abstract interpretation functions not present in the registry. @@ -175,6 +192,7 @@ def not_present_in_registry(f): f._not_present_in_registry = None return f + def _verify_signature_matches_registry(f, registry: Registry): source = inspect.getsource(f) signature = None @@ -183,7 +201,9 @@ def _verify_signature_matches_registry(f, registry: Registry): signature = line break assert signature is not None, f"Could not find signature for {f.__name__}" - assert "〡" in signature, f"Malformed signature {signature}. Signature missing the character `〡`" + assert ( + "〡" in signature + ), f"Malformed signature {signature}. Signature missing the character `〡`" function_name, function_kind = f.__name__.split("〡") atoms = function_name.split("〇") if len(atoms) == 2: @@ -203,7 +223,10 @@ def _verify_signature_matches_registry(f, registry: Registry): else: raise ValueError(f"Invalid Op signature function kind: '{function_kind}'") if signature != expected_signature: - raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}") + raise ValueError( + f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}" + ) + def generate_library(functions: Dict[str, Any]) -> str: """Convert all op functions in `functions` into MLIR.""" @@ -245,11 +268,13 @@ def generate_library(functions: Dict[str, Any]) -> str: # the format: `__torch__.{namespace_1}.{namespace_2}...{op_name}` # The extra namespaces are not part of the abstract interpretation # function name, so here we simply drop the extra namespaces. - namespace = fr"(?:{name}\.)" + namespace = rf"(?:{name}\.)" - asm = re.sub(fr'@"__torch__\.{namespace}*({name}){circle}({name}){line}({name})"', - fr'@"__torch_mlir_\3_fn.\1{circle}\2"', - asm) + asm = re.sub( + rf'@"__torch__\.{namespace}*({name}){circle}({name}){line}({name})"', + rf'@"__torch_mlir_\3_fn.\1{circle}\2"', + asm, + ) # Put the `〇` back to a regular `.`. asm = asm.replace(codecs.decode(circle, "unicode_escape"), ".") diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/registry.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/registry.py index cf24be8cc..ec8317270 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/registry.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/registry.py @@ -1,4 +1,3 @@ - # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -15,13 +14,17 @@ import difflib from .utils import TextEmitter # Note that this utility exists only in the c-extension. -from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error +from torch_mlir._mlir_libs._jit_ir_importer import ( + get_registered_ops, +) # pytype: disable=import-error + def _rename_python_keyword_parameter_name(parameter_name: str) -> str: if parameter_name == "from": - parameter_name = "from_" # Avoid using a Python keyword. + parameter_name = "from_" # Avoid using a Python keyword. return parameter_name + def _get_default_value(arg: "SIG_ATTR_TYPE") -> str: default = "" if "default_debug" in arg: @@ -40,8 +43,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str: if default_list == "[]": default_debug = "()" else: - default_debug = default_list.replace( - "[", "(").replace("]", ",)") + default_debug = default_list.replace("[", "(").replace("]", ",)") elif arg["pytype"] == "str": default_debug = repr(arg["default_debug"]).replace("'", '"') else: @@ -49,6 +51,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str: default = f" = {default_debug}" return default + def _pytype_to_fn_pytype_common(pytype: str) -> str: if "number" in pytype: return pytype.replace("number", "Union[int, float, complex]") @@ -65,6 +68,7 @@ def _pytype_to_fn_pytype_common(pytype: str) -> str: return "Any" return pytype + def _pytype_to_shape_fn_pytype(pytype: str) -> str: """Convert a JitOperator pytype to the type relevant in shape functions. @@ -86,6 +90,7 @@ def _pytype_to_shape_fn_pytype(pytype: str) -> str: return pytype.replace("Tensor", "List[int]") return _pytype_to_fn_pytype_common(pytype) + def _pytype_to_dtype_fn_pytype(pytype: str) -> str: """Convert a JitOperator pytype to the type relevant in dtype functions. @@ -97,13 +102,15 @@ def _pytype_to_dtype_fn_pytype(pytype: str) -> str: return pytype.replace("Tensor", "Tuple[int, int]") return _pytype_to_fn_pytype_common(pytype) + def _pytype_to_decomposition_fn_pytype(pytype: str) -> str: - """Convert a JitOperator pytype to the type relevant in decomposition functions. - """ + """Convert a JitOperator pytype to the type relevant in decomposition functions.""" return _pytype_to_fn_pytype_common(pytype) + class JitOperator: """Information about a single registered `torch::jit::Operator`""" + def __init__(self, op_info: "OP_INFO_DICT"): """Create a JitOperator from the raw OP_INFO_DICT extracted from the PyTorch JIT operator registry. @@ -170,6 +177,7 @@ class JitOperator: are useful in the repr for cross referencing, and it's useful to have them in a single point of truth. """ + def uppercase_first_letter(s): if not s: return s @@ -184,15 +192,19 @@ class JitOperator: for op_name_atom in op_name_atoms: for s in op_name_atom.split("_"): op_class_name_atoms.append(s if s else "_") - cpp_class_name = "".join( - uppercase_first_letter(s) for s in op_class_name_atoms) + "Op" + cpp_class_name = ( + "".join(uppercase_first_letter(s) for s in op_class_name_atoms) + "Op" + ) # Disallow leading underscores in C++ to avoid illegal names. cpp_class_name = cpp_class_name.lstrip("_") return op_name, cpp_class_name - def _get_function_signature(self, function_kind: str, - parameter_decl_builder: Callable[["SIG_ATTR_TYPE"], str], - ret_decl_builder: Callable[["SIG_ATTR_TYPE"], str]) -> str: + def _get_function_signature( + self, + function_kind: str, + parameter_decl_builder: Callable[["SIG_ATTR_TYPE"], str], + ret_decl_builder: Callable[["SIG_ATTR_TYPE"], str], + ) -> str: mlir_op_name, _ = self.get_mlir_names() # Replace `.` with a valid Python identifier character. # `〇` vaguely looks like `.`. @@ -219,6 +231,7 @@ class JitOperator: ops have extra default arguments and stuff that are tedious to write out right. """ + def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: pytype = _pytype_to_shape_fn_pytype(arg["pytype"]) default = _get_default_value(arg) @@ -226,10 +239,11 @@ class JitOperator: return f"{parameter_name}: {pytype}{default}" def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: - return _pytype_to_shape_fn_pytype(arg["pytype"]) + return _pytype_to_shape_fn_pytype(arg["pytype"]) return self._get_function_signature( - "shape", parameter_decl_builder, ret_decl_builder) + "shape", parameter_decl_builder, ret_decl_builder + ) def get_dtype_function_signature(self): """Gets the Python function signature for this op's dtype function. @@ -239,6 +253,7 @@ class JitOperator: ops have extra default arguments and stuff that are tedious to write out right. """ + def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: pytype = _pytype_to_dtype_fn_pytype(arg["pytype"]) default = _get_default_value(arg) @@ -257,7 +272,8 @@ class JitOperator: return _pytype_to_dtype_fn_pytype(arg["pytype"]) return self._get_function_signature( - "dtype", parameter_decl_builder, ret_decl_builder) + "dtype", parameter_decl_builder, ret_decl_builder + ) def get_decomposition_function_signature(self): """Gets the Python function signature for this op's decomposition function. @@ -267,6 +283,7 @@ class JitOperator: ops have extra default arguments and stuff that are tedious to write out right. """ + def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: pytype = _pytype_to_decomposition_fn_pytype(arg["pytype"]) default = _get_default_value(arg) @@ -277,7 +294,8 @@ class JitOperator: return _pytype_to_decomposition_fn_pytype(arg["pytype"]) return self._get_function_signature( - "decomposition", parameter_decl_builder, ret_decl_builder) + "decomposition", parameter_decl_builder, ret_decl_builder + ) def get_has_value_semantics_function_signature(self): """Gets the Python function signature for this op's has_value_semantics function. @@ -287,6 +305,7 @@ class JitOperator: ops have extra default arguments and stuff that are tedious to write out right. """ + def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: return "" @@ -294,7 +313,8 @@ class JitOperator: return "" return self._get_function_signature( - "has_value_semantics", parameter_decl_builder, ret_decl_builder) + "has_value_semantics", parameter_decl_builder, ret_decl_builder + ) def __repr__(self): f = io.StringIO() @@ -318,7 +338,9 @@ class JitOperator: p(f"is_mutable = {self.is_mutable}") if any(ret["type"] == "Tensor" for ret in self.returns): p(f"shape_function_signature = {self.get_shape_function_signature()}") - p(f"decomposition_function_signature = {self.get_decomposition_function_signature()}") + p( + f"decomposition_function_signature = {self.get_decomposition_function_signature()}" + ) if any(ret["type"] in ["Tensor", "Scalar"] for ret in self.returns): p(f"dtype_function_signature = {self.get_dtype_function_signature()}") @@ -354,7 +376,9 @@ class JitOperator: # Note that this is different from MLIR's NoSideEffect which is much # stronger (for example, it cannot be applied to ops that might emit errors # when operand shapes mismatch). - if any("alias_info" in x for x in itertools.chain(self.arguments, self.returns)): + if any( + "alias_info" in x for x in itertools.chain(self.arguments, self.returns) + ): return False # It seems the FunctionSchema of "prim::unchecked_cast : (t) -> (t)" has # incorrect alias information. The result can alias with other tensors @@ -363,8 +387,10 @@ class JitOperator: return False # The `is` operator compares object identity, so it does not have # value semantics. - if self.unique_key in ("aten::__is__ : (t1, t2) -> (bool)", - "aten::__isnot__ : (t1, t2) -> (bool)"): + if self.unique_key in ( + "aten::__is__ : (t1, t2) -> (bool)", + "aten::__isnot__ : (t1, t2) -> (bool)", + ): return False return True @@ -390,6 +416,7 @@ class JitOperator: class Registry: """An indexed collection of JitOperators""" + def __init__(self, operators: List[JitOperator]): self.by_unique_key = {} self.by_triple = {} @@ -434,4 +461,3 @@ SIGLIST_TYPE = List[SIG_ATTR_TYPE] # - Tuple[str] (e.g. {'name': ('aten::size', 'int')} ) # - SIGLIST_TYPE (e.g. {'arguments': [...], 'returns': [...]} ) OP_INFO_DICT = Dict[str, Union[bool, Tuple[str], SIGLIST_TYPE]] - diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py index 30a7d6387..cbeb38d66 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py @@ -45,6 +45,7 @@ from torch import Tensor # The typical iteration flow is to add invocations to the list and then re-run # `build_tools/update_abstract_interp_lib.sh` to re-run the tests. + class TensorOfShape: """Symbolic placeholder for a tensor argument to an operation. @@ -60,30 +61,40 @@ class TensorOfShape: This class also tracks a dtype of the tensor, since some ops require a specific dtype. """ - def __init__(self, *shape: int, dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None): + + def __init__( + self, + *shape: int, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ): self.shape = list(shape) self.dtype = dtype self.device = "meta" if device is None else device + def __repr__(self): args_str = ", ".join(repr(x) for x in self.shape) return f"TensorOfShape({args_str}, dtype={self.dtype}, device={self.device})" + def LongTensorOfShape(*args, **kwargs): """Helper for indicating a TensorOfShape with integer type.""" return TensorOfShape(*args, **kwargs, dtype=torch.long) + def NonZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None): """Helper for indicating a non-zero dim tensor with custom type.""" return TensorOfShape(1, dtype=dtype, device=device) + def ZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None): """Helper for indicating a zero dim tensor with custom type.""" return TensorOfShape(dtype=dtype, device=device) + def _recursively_transform_tensor_args( - o: Any, - tensor_transformer: Callable[[TensorOfShape], Any]) -> Any: + o: Any, tensor_transformer: Callable[[TensorOfShape], Any] +) -> Any: """Replace `TensorOfShape` with the result of `tensor_transformer`""" if o is None or isinstance(o, (float, int, str)): return o @@ -92,9 +103,12 @@ def _recursively_transform_tensor_args( if isinstance(o, list): return [_recursively_transform_tensor_args(x, tensor_transformer) for x in o] if isinstance(o, tuple): - return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o) + return tuple( + _recursively_transform_tensor_args(x, tensor_transformer) for x in o + ) raise Exception(f"Unhandled type {type(o)}") + class Invocation: """Representation of a single op invocation (i.e. list of args to the op). @@ -111,6 +125,7 @@ class Invocation: exception for greater precision when interpreting errors raised during testing. """ + def __init__(self, *args: Any, **kwargs: Any): self.args = list(args) # We assume kwargs don't contain tensors, so they don't need any @@ -134,14 +149,12 @@ class Invocation: # are ok since they make it a bit easier to write some shape # functions. tensor_transformer = lambda o: list(o.shape) - return _recursively_transform_tensor_args( - self.args, tensor_transformer) + return _recursively_transform_tensor_args(self.args, tensor_transformer) def to_dtype_function_args(self): """Gets positional arguments appropriate for a dtype function.""" tensor_transformer = lambda o: (len(o.shape), o.dtype) - return _recursively_transform_tensor_args( - self.args, tensor_transformer) + return _recursively_transform_tensor_args(self.args, tensor_transformer) def to_real_op_args(self): """Gets positional arguments appropriate for the real op.""" @@ -155,6 +168,7 @@ class Invocation: kwargs_str = ", " + ", ".join(f"{k}={v}" for k, v in self.kwargs.items()) return f"Invocation({args_str}{kwargs_str})" + class ErrorInvocation(Invocation): """An Invocation that raises an exception. @@ -165,9 +179,11 @@ class ErrorInvocation(Invocation): spurioiusly make the two appear to "agree" that an exception needs to be raised). """ + def is_expected_to_raise_exception(self) -> bool: return True + def _normalize_multiple_results_to_list(t: Any): """Returns a flat list of results. @@ -182,9 +198,13 @@ def _normalize_multiple_results_to_list(t: Any): return [t] raise ValueError(f"Unexpected type {type(t)}") + def _report(f, invocation: Invocation, error_message: str): - fn_type = f.__name__.split("〡")[-1] - raise ValueError(f"For {fn_type} function {f.__name__!r} with invocation {invocation}: {error_message}") + fn_type = f.__name__.split("〡")[-1] + raise ValueError( + f"For {fn_type} function {f.__name__!r} with invocation {invocation}: {error_message}" + ) + def _get_fn_and_golden_results(f, invocation: List[Invocation]): """Run the invocation on the library function and torch op. @@ -201,36 +221,64 @@ def _get_fn_and_golden_results(f, invocation: List[Invocation]): op = getattr(getattr(getattr(torch.ops, ns), unqual), overload) fn_error, op_error, fn_results, golden_results = None, None, None, None try: - fn_results = _normalize_multiple_results_to_list(f( - *(getattr(invocation, f"to_{fn_type}_function_args")()), - **invocation.kwargs)) + fn_results = _normalize_multiple_results_to_list( + f( + *(getattr(invocation, f"to_{fn_type}_function_args")()), + **invocation.kwargs, + ) + ) except Exception as e: fn_error = f"{e}" try: - golden_results = _normalize_multiple_results_to_list(op( - *invocation.to_real_op_args(), - **invocation.kwargs)) + golden_results = _normalize_multiple_results_to_list( + op(*invocation.to_real_op_args(), **invocation.kwargs) + ) except Exception as e: op_error = f"{e}" # Check for error behavior. if invocation.is_expected_to_raise_exception(): if fn_error is None and op_error is None: - _report(f, invocation, f"Expected to raise an exception, but neither {fn_type} function or op raised an exception") + _report( + f, + invocation, + f"Expected to raise an exception, but neither {fn_type} function or op raised an exception", + ) if fn_error is None: - _report(f, invocation, f"Op raised error {op_error!r}, but shape/dtype function did not.") + _report( + f, + invocation, + f"Op raised error {op_error!r}, but shape/dtype function did not.", + ) if op_error is None: - _report(f, invocation, f"{fn_type} function raised error {fn_error!r}, but op did not.") + _report( + f, + invocation, + f"{fn_type} function raised error {fn_error!r}, but op did not.", + ) else: if fn_error is not None and op_error is not None: - _report(f, invocation, f"Both {fn_type} function and op raised errors, but were not expected to. {fn_type} function raised error {fn_error!r} and op raised error {op_error!r}.") + _report( + f, + invocation, + f"Both {fn_type} function and op raised errors, but were not expected to. {fn_type} function raised error {fn_error!r} and op raised error {op_error!r}.", + ) if fn_error is not None: - _report(f, invocation, f"{fn_type} function raised error {fn_error!r} but op did not raise any error.") + _report( + f, + invocation, + f"{fn_type} function raised error {fn_error!r} but op did not raise any error.", + ) if op_error is not None: - _report(f, invocation, f"Op raised error {op_error!r} but {fn_type} function did not raise any error.") + _report( + f, + invocation, + f"Op raised error {op_error!r} but {fn_type} function did not raise any error.", + ) return fn_results, golden_results + def check_shape_function(invocations: List[Invocation]): """Decorator that automatically tests a shape function. @@ -238,6 +286,7 @@ def check_shape_function(invocations: List[Invocation]): `〇` instead of `.`, is tested against the corresponding op in `torch.ops.*` function using the given invocations. """ + def decorator(f): for invocation in invocations: result_shapes, golden_results = _get_fn_and_golden_results(f, invocation) @@ -245,18 +294,34 @@ def check_shape_function(invocations: List[Invocation]): continue # Check for matching results. if len(result_shapes) != len(golden_results): - _report(f, invocation, f"Expected {len(golden_results)} result shapes, got {len(result_shapes)}") + _report( + f, + invocation, + f"Expected {len(golden_results)} result shapes, got {len(result_shapes)}", + ) for result_shape, golden_result in zip(result_shapes, golden_results): result_rank = len(result_shape) golden_rank = len(golden_result.shape) if result_rank != golden_rank: - _report(f, invocation, f"Expected result rank {golden_rank}, got {result_rank}") - for dimension_size, golden_dimension_size in zip(result_shape, golden_result.shape): + _report( + f, + invocation, + f"Expected result rank {golden_rank}, got {result_rank}", + ) + for dimension_size, golden_dimension_size in zip( + result_shape, golden_result.shape + ): if dimension_size != golden_dimension_size: - _report(f, invocation, f"Expected result shape {golden_result.shape}, got {result_shape}") + _report( + f, + invocation, + f"Expected result shape {golden_result.shape}, got {result_shape}", + ) return f + return decorator + @torch.jit.script def _convert_dtype_to_int(dtype: torch.dtype) -> int: """Convert a PyTorch `dtype` into its underlying `int` representation. @@ -266,6 +331,7 @@ def _convert_dtype_to_int(dtype: torch.dtype) -> int: """ return dtype + def check_dtype_function(invocations: List[Invocation]): """Decorator that automatically tests a dtype function. @@ -273,6 +339,7 @@ def check_dtype_function(invocations: List[Invocation]): `〇` instead of `.`, is tested against the corresponding op in `torch.ops.*` function using the given invocations. """ + def decorator(f): for invocation in invocations: result_dtypes, golden_results = _get_fn_and_golden_results(f, invocation) @@ -280,7 +347,11 @@ def check_dtype_function(invocations: List[Invocation]): continue if len(result_dtypes) != len(golden_results): - _report(f, invocation, f"Expected {len(golden_results)} result dtypes, got {len(result_dtypes)}") + _report( + f, + invocation, + f"Expected {len(golden_results)} result dtypes, got {len(result_dtypes)}", + ) for result_dtype, golden_result in zip(result_dtypes, golden_results): if isinstance(golden_result, torch.Tensor): golden_dtype = golden_result.dtype @@ -294,7 +365,14 @@ def check_dtype_function(invocations: List[Invocation]): # support returning the default `int` value, the comparisons of # the result and golden dtypes are done using their underlying # `int` representation. - if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int(golden_dtype): - _report(f, invocation, f"Expected result dtype {golden_dtype}, got {result_dtype}") + if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int( + golden_dtype + ): + _report( + f, + invocation, + f"Expected result dtype {golden_dtype}, got {result_dtype}", + ) return f + return decorator diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index dd14c8bdb..7d50923f7 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -69,30 +69,36 @@ def get_ods_type(type: str, non_value: bool, *, is_result: bool = False): type = "Tensor?" # TODO: Increase precision on dict type modeling. if type.startswith("Dict("): - type = "Dict" + type = "Dict" if non_value: - ods_type = TORCH_NON_VALUE_TYPE_TO_ODS_TYPE.get(type) or TORCH_TYPE_TO_ODS_TYPE.get(type) + ods_type = TORCH_NON_VALUE_TYPE_TO_ODS_TYPE.get( + type + ) or TORCH_TYPE_TO_ODS_TYPE.get(type) else: ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type) if ods_type is None: raise Exception( - f"{type!r} not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!") + f"{type!r} not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!" + ) return ods_type def _name_thunk() -> None: - # Strictly exists for _get_main_module_name to harvest its __module__. - pass + # Strictly exists for _get_main_module_name to harvest its __module__. + pass + + def _get_main_module_name() -> str: # If a Python module is loaded interactively or as part of a module # directory, it uses a BuiltinImporter. If loaded from a file, it uses # the SourceFileLoader. These two objects have different attributes. loader = sys.modules["__main__"].__loader__ try: - return loader.name # pytype: disable=attribute-error + return loader.name # pytype: disable=attribute-error except AttributeError: return _name_thunk.__module__ + ODS_BANNER = f"""//===-------------------------------------------------------*- tablegen -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. @@ -117,10 +123,15 @@ ODS_BANNER = f"""//===-------------------------------------------------------*- """ -def raw_emit_op(operator: JitOperator, - emitter_td: TextEmitter, - *, traits: List[str], - has_folder: bool, has_canonicalizer: bool, has_verifier: bool): +def raw_emit_op( + operator: JitOperator, + emitter_td: TextEmitter, + *, + traits: List[str], + has_folder: bool, + has_canonicalizer: bool, + has_verifier: bool, +): """Emit the ODS for a JitOperator to a textual file. This is the lowest level of emission and is responsible for low-level @@ -138,8 +149,7 @@ def raw_emit_op(operator: JitOperator, def generic_result_name(i): return "result" + (str(i) if multiple_results else "") - p_td( - f"def Torch_{cpp_class_name} : Torch_Op<{emitter_td.quote(op_name)}, [") + p_td(f"def Torch_{cpp_class_name} : Torch_Op<{emitter_td.quote(op_name)}, [") with emitter_td.indent(): with emitter_td.indent(): p_td(",\n".join(traits)) @@ -153,20 +163,28 @@ def raw_emit_op(operator: JitOperator, if operator.is_vararg: p_td("Variadic:$operands") else: - p_td(",\n".join([ - f"""{get_ods_type(arg["type"], is_non_value_op)}:${arg["name"]}""" - for arg in operator.arguments - ])) + p_td( + ",\n".join( + [ + f"""{get_ods_type(arg["type"], is_non_value_op)}:${arg["name"]}""" + for arg in operator.arguments + ] + ) + ) p_td(");") p_td(f"let results = (outs") with emitter_td.indent(): if operator.is_varret: p_td("Variadic:$results") else: - p_td(",\n".join([ - f"""{get_ods_type(ret["type"], is_non_value_op, is_result=True)}:${ret["name"] or generic_result_name(e)}""" - for e, ret in enumerate(operator.returns) - ])) + p_td( + ",\n".join( + [ + f"""{get_ods_type(ret["type"], is_non_value_op, is_result=True)}:${ret["name"] or generic_result_name(e)}""" + for e, ret in enumerate(operator.returns) + ] + ) + ) p_td(");") if operator.is_vararg or operator.is_varret: @@ -174,16 +192,19 @@ def raw_emit_op(operator: JitOperator, assembly_operands = "`(` $operands `)`" assembly_operand_types = "qualified(type($operands))" else: - assembly_operands = " `,` ".join("$" + arg["name"] - for arg in operator.arguments) + assembly_operands = " `,` ".join( + "$" + arg["name"] for arg in operator.arguments + ) assembly_operand_types = " `,` ".join( - f"""qualified(type(${arg["name"]}))""" for arg in operator.arguments) + f"""qualified(type(${arg["name"]}))""" for arg in operator.arguments + ) if operator.is_varret: assembly_result_types = "qualified(type($results))" else: assembly_result_types = " `,` ".join( f"""qualified(type(${ret["name"] or generic_result_name(e)}))""" - for e, ret in enumerate(operator.returns)) + for e, ret in enumerate(operator.returns) + ) if assembly_operand_types and assembly_result_types: maybe_arrow = " `->` " else: @@ -192,7 +213,8 @@ def raw_emit_op(operator: JitOperator, p_td(f"let assemblyFormat = {emitter_td.quote(assembly_format)};") else: p_td(f"let hasCustomAssemblyFormat = 1;") - p_td(f"""let extraClassDefinition = [{{ + p_td( + f"""let extraClassDefinition = [{{ ParseResult {cpp_class_name}::parse(OpAsmParser &parser, OperationState &result) {{ return parseDefaultTorchOp(parser, result, {len(operator.arguments)}, {len(operator.returns)}); }} @@ -200,7 +222,8 @@ def raw_emit_op(operator: JitOperator, printDefaultTorchOp(printer, *this, {len(operator.arguments)}, {len(operator.returns)}); }} }}]; -""") +""" + ) if has_folder: p_td("let hasFolder = 1;") if has_canonicalizer: @@ -211,13 +234,15 @@ def raw_emit_op(operator: JitOperator, p_td("\n") -def emit_op(operator: JitOperator, - emitter_td: TextEmitter, - *, - traits: Optional[List[str]] = None, - has_folder: bool = False, - has_canonicalizer: bool = False, - has_verifier: bool = False): +def emit_op( + operator: JitOperator, + emitter_td: TextEmitter, + *, + traits: Optional[List[str]] = None, + has_folder: bool = False, + has_canonicalizer: bool = False, + has_verifier: bool = False, +): """Main entry point for op emission. Besides emitting the op, it deduces / adds traits based on the operator @@ -233,12 +258,14 @@ def emit_op(operator: JitOperator, if operator.is_readonly(): traits += ["ReadOnly"] - raw_emit_op(operator, - emitter_td, - traits=traits, - has_folder=has_folder, - has_canonicalizer=has_canonicalizer, - has_verifier=has_verifier) + raw_emit_op( + operator, + emitter_td, + traits=traits, + has_folder=has_folder, + has_canonicalizer=has_canonicalizer, + has_verifier=has_verifier, + ) def emit_ops(emitter_td: TextEmitter, registry: Registry): @@ -253,9 +280,15 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): ns, unqual, overload = operator.triple # Underscore variant of functional ops should have "functional" part removed. is_functional_op = overload == "functional" - emit_op(registry.get_by_triple((ns, unqual + "_", overload if not is_functional_op else "")), - emitter_td, - traits=["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []) + emit_op( + registry.get_by_triple( + (ns, unqual + "_", overload if not is_functional_op else "") + ), + emitter_td, + traits=["IsTrailingUnderscoreInplaceVariant"] + if not is_functional_op + else [], + ) # ========================================================================== # `aten::` namespace. @@ -263,114 +296,174 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): # Elementwise tensor compute ops for key in [ - "aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)", - "aten::elu : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)", - "aten::relu : (Tensor) -> (Tensor)", - "aten::relu6 : (Tensor) -> (Tensor)", - "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", - "aten::selu : (Tensor) -> (Tensor)", - "aten::sigmoid : (Tensor) -> (Tensor)", - "aten::sinh : (Tensor) -> (Tensor)", - "aten::sgn : (Tensor) -> (Tensor)", - "aten::hardsigmoid : (Tensor) -> (Tensor)", - "aten::hardswish : (Tensor) -> (Tensor)", - "aten::erf : (Tensor) -> (Tensor)", - "aten::erfinv : (Tensor) -> (Tensor)", - "aten::silu : (Tensor) -> (Tensor)", - "aten::sin : (Tensor) -> (Tensor)", - "aten::asin : (Tensor) -> (Tensor)", - "aten::asinh : (Tensor) -> (Tensor)", - "aten::exp : (Tensor) -> (Tensor)", - "aten::expm1 : (Tensor) -> (Tensor)", - "aten::cos : (Tensor) -> (Tensor)", - "aten::cosh : (Tensor) -> (Tensor)", - "aten::acos : (Tensor) -> (Tensor)", - "aten::acosh : (Tensor) -> (Tensor)", - "aten::tan : (Tensor) -> (Tensor)", - "aten::tanh : (Tensor) -> (Tensor)", - "aten::atan : (Tensor) -> (Tensor)", - "aten::atanh : (Tensor) -> (Tensor)", - "aten::atan2 : (Tensor, Tensor) -> (Tensor)", - "aten::neg : (Tensor) -> (Tensor)", - "aten::bitwise_not : (Tensor) -> (Tensor)", - "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::logical_or : (Tensor, Tensor) -> (Tensor)", - "aten::logical_and : (Tensor, Tensor) -> (Tensor)", - "aten::logical_xor : (Tensor, Tensor) -> (Tensor)", - "aten::logical_not : (Tensor) -> (Tensor)", - "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", - "aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", - "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::ge.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::le.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::div.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", - "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)", - "aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)", - "aten::clamp_min : (Tensor, Scalar) -> (Tensor)", - "aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::clamp_max : (Tensor, Scalar) -> (Tensor)", - "aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::log2 : (Tensor) -> (Tensor)", - "aten::log10 : (Tensor) -> (Tensor)", - "aten::sqrt : (Tensor) -> (Tensor)", - "aten::log1p : (Tensor) -> (Tensor)", - "aten::logit : (Tensor, float?) -> (Tensor)", - "aten::rsqrt : (Tensor) -> (Tensor)", - "aten::abs : (Tensor) -> (Tensor)", - "aten::reciprocal : (Tensor) -> (Tensor)", - "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)", - "aten::square : (Tensor) -> (Tensor)", - "aten::zero : (Tensor) -> (Tensor)", - "aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)" + "aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)", + "aten::elu : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)", + "aten::relu : (Tensor) -> (Tensor)", + "aten::relu6 : (Tensor) -> (Tensor)", + "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", + "aten::selu : (Tensor) -> (Tensor)", + "aten::sigmoid : (Tensor) -> (Tensor)", + "aten::sinh : (Tensor) -> (Tensor)", + "aten::sgn : (Tensor) -> (Tensor)", + "aten::hardsigmoid : (Tensor) -> (Tensor)", + "aten::hardswish : (Tensor) -> (Tensor)", + "aten::erf : (Tensor) -> (Tensor)", + "aten::erfinv : (Tensor) -> (Tensor)", + "aten::silu : (Tensor) -> (Tensor)", + "aten::sin : (Tensor) -> (Tensor)", + "aten::asin : (Tensor) -> (Tensor)", + "aten::asinh : (Tensor) -> (Tensor)", + "aten::exp : (Tensor) -> (Tensor)", + "aten::expm1 : (Tensor) -> (Tensor)", + "aten::cos : (Tensor) -> (Tensor)", + "aten::cosh : (Tensor) -> (Tensor)", + "aten::acos : (Tensor) -> (Tensor)", + "aten::acosh : (Tensor) -> (Tensor)", + "aten::tan : (Tensor) -> (Tensor)", + "aten::tanh : (Tensor) -> (Tensor)", + "aten::atan : (Tensor) -> (Tensor)", + "aten::atanh : (Tensor) -> (Tensor)", + "aten::atan2 : (Tensor, Tensor) -> (Tensor)", + "aten::neg : (Tensor) -> (Tensor)", + "aten::bitwise_not : (Tensor) -> (Tensor)", + "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::logical_or : (Tensor, Tensor) -> (Tensor)", + "aten::logical_and : (Tensor, Tensor) -> (Tensor)", + "aten::logical_xor : (Tensor, Tensor) -> (Tensor)", + "aten::logical_not : (Tensor) -> (Tensor)", + "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", + "aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", + "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::ge.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::le.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::div.Scalar : (Tensor, Scalar) -> (Tensor)", + "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)", + "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", + "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)", + "aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)", + "aten::clamp_min : (Tensor, Scalar) -> (Tensor)", + "aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::clamp_max : (Tensor, Scalar) -> (Tensor)", + "aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::log2 : (Tensor) -> (Tensor)", + "aten::log10 : (Tensor) -> (Tensor)", + "aten::sqrt : (Tensor) -> (Tensor)", + "aten::log1p : (Tensor) -> (Tensor)", + "aten::logit : (Tensor, float?) -> (Tensor)", + "aten::rsqrt : (Tensor) -> (Tensor)", + "aten::abs : (Tensor) -> (Tensor)", + "aten::reciprocal : (Tensor) -> (Tensor)", + "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)", + "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)", + "aten::square : (Tensor) -> (Tensor)", + "aten::zero : (Tensor) -> (Tensor)", + "aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)", + "aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)", ]: emit_with_mutating_variants(key) # Shape manipulations: - emit_with_mutating_variants("aten::unsqueeze : (Tensor, int) -> (Tensor)", has_folder=True) + emit_with_mutating_variants( + "aten::unsqueeze : (Tensor, int) -> (Tensor)", has_folder=True + ) # Elementwise tensor compute ops that don't have the standard mutating # variants. - emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True) - emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) - emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) - emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) - emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) - emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) - emit_with_mutating_variants("aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True) - emit_with_mutating_variants("aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) - emit_with_mutating_variants("aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) - emit_with_mutating_variants("aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) - emit_with_mutating_variants("aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) - emit_with_mutating_variants("aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) - emit_with_mutating_variants("aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants( + "aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", + has_canonicalizer=True, + ) + emit_with_mutating_variants( + "aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)", + has_canonicalizer=True, + ) + emit_with_mutating_variants( + "aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", + has_canonicalizer=True, + has_folder=True, + ) + emit_with_mutating_variants( + "aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", + has_canonicalizer=True, + has_folder=True, + ) + emit_with_mutating_variants( + "aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", + has_canonicalizer=True, + has_folder=True, + ) + emit_with_mutating_variants( + "aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", + has_canonicalizer=True, + has_folder=True, + ) + emit_with_mutating_variants( + "aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", + has_canonicalizer=True, + has_folder=True, + ) + emit_with_mutating_variants( + "aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", + has_canonicalizer=True, + has_folder=True, + ) + emit_with_mutating_variants( + "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True + ) + emit_with_mutating_variants( + "aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True + ) + emit_with_mutating_variants( + "aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True + ) + emit_with_mutating_variants( + "aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True + ) + emit_with_mutating_variants( + "aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True + ) + emit_with_mutating_variants( + "aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True + ) + emit_with_mutating_variants( + "aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True + ) emit_with_mutating_variants("aten::log : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True) - emit_with_mutating_variants("aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants( + "aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True + ) + emit_with_mutating_variants( + "aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", + has_canonicalizer=True, + ) - emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") - emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") - emit("aten::fake_quantize_per_tensor_affine : (Tensor, float, int, int, int) -> (Tensor)") + emit_with_mutating_variants( + "aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)" + ) + emit_with_mutating_variants( + "aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)" + ) + emit( + "aten::fake_quantize_per_tensor_affine : (Tensor, float, int, int, int) -> (Tensor)" + ) emit("aten::maximum : (Tensor, Tensor) -> (Tensor)") emit("aten::minimum : (Tensor, Tensor) -> (Tensor)") emit("aten::mish : (Tensor) -> (Tensor)") emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)") - emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) + emit( + "aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", + has_canonicalizer=True, + ) emit("aten::gelu : (Tensor, str) -> (Tensor)") emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)") @@ -392,7 +485,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::split_with_sizes_copy : (Tensor, int[], int) -> (Tensor[])") # Random number generation - emit_with_mutating_variants("aten::uniform : (Tensor, float, float, Generator?) -> (Tensor)") + emit_with_mutating_variants( + "aten::uniform : (Tensor, float, float, Generator?) -> (Tensor)" + ) emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::rand : (int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)") @@ -400,11 +495,17 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)") emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)") emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)") - emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)") + emit( + "aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)" + ) emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)") - emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)") + emit_with_mutating_variants( + "aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)" + ) emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)") + emit( + "aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)" + ) emit("aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::random : (Tensor, Generator?) -> (Tensor)") emit("aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)") @@ -412,10 +513,14 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)") emit_with_mutating_variants( - "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)") + "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)" + ) emit_with_mutating_variants( - "aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)") - emit("aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)") + "aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)" + ) + emit( + "aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)" + ) # Non-elementwise tensor compute ops emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)") @@ -433,16 +538,32 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit( "aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) - emit("aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") - emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") - emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") + emit( + "aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)" + ) + emit( + "aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)" + ) + emit( + "aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)" + ) emit("aten::conv_tbc : (Tensor, Tensor, Tensor, int) -> (Tensor)") - emit("aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)") - emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") - emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)") - emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)") + emit( + "aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)" + ) + emit( + "aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)" + ) + emit( + "aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)" + ) + emit( + "aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)" + ) emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"), - emit("aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)") + emit( + "aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)" + ) emit("aten::flip : (Tensor, int[]) -> (Tensor)") emit( "aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)" @@ -456,43 +577,33 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit( "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)" ) - emit( - 'aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)' - ) + emit("aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)") emit( "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)" ) emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True) - emit( - "aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)" - ) + emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)") emit( "aten::normal_functional : (Tensor, float, float, Generator?) -> (Tensor)", ) emit( "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)" ) - emit( - "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)" - ) + emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit( "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" ) emit( "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" ) - emit( - "aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)" - ) + emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit( "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" ) emit( "aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" ) - emit( - "aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)" - ) + emit("aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)") emit( "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" ) @@ -505,18 +616,18 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit( "aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" ) - emit( - "aten::softmax.int : (Tensor, int, int?) -> (Tensor)" + emit("aten::softmax.int : (Tensor, int, int?) -> (Tensor)") + emit("aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)") + emit("aten::_log_softmax : (Tensor, int, bool) -> (Tensor)") + emit_with_mutating_variants( + "aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)" ) - emit( - "aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)" + emit_with_mutating_variants( + "aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)" ) - emit( - "aten::_log_softmax : (Tensor, int, bool) -> (Tensor)" + emit_with_mutating_variants( + "aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)" ) - emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)") - emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)") - emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)") emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)") emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") @@ -540,7 +651,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)") emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::__and__.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) - emit("aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) + emit("aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)") emit("aten::mean : (Tensor, int?) -> (Tensor)") emit("aten::std : (Tensor, bool) -> (Tensor)") @@ -549,13 +660,23 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::var : (Tensor, bool) -> (Tensor)") emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)") emit("aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)") - emit("aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)") + emit( + "aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)" + ) emit("aten::var_mean : (Tensor, bool) -> (Tensor, Tensor)") emit("aten::var_mean.dim : (Tensor, int[]?, bool, bool) -> (Tensor, Tensor)") - emit("aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") - emit("aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") - emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") - emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") + emit( + "aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)" + ) + emit( + "aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)" + ) + emit( + "aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)" + ) + emit( + "aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)" + ) emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)") emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)") @@ -563,17 +684,25 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)") emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)") emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)") - emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)") - emit("aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)") + emit( + "aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)" + ) + emit( + "aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)" + ) emit("aten::nonzero : (Tensor) -> (Tensor)") emit("aten::nonzero_numpy : (Tensor) -> (Tensor[])") emit("aten::nonzero_static : (Tensor, int, int) -> (Tensor)") emit("aten::binary_cross_entropy : (Tensor, Tensor, Tensor?, int) -> (Tensor)") - emit("aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)") + emit( + "aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)" + ) emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)") emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)") emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)") - emit("aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)") + emit( + "aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)" + ) emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)") # Misc tensor ops. @@ -590,9 +719,13 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True) emit("aten::Bool.Tensor : (Tensor) -> (bool)") emit("aten::is_floating_point : (Tensor) -> (bool)", has_folder=True) - emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True) + emit( + "aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True + ) emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True) + emit( + "aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True + ) emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)") @@ -611,8 +744,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::any : (Tensor) -> (Tensor)") emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)") emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)") + emit( + "aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)" + ) + emit( + "aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)" + ) emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)") emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)") emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)") @@ -624,19 +761,29 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)") emit("aten::contiguous : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)") - emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)") + emit( + "aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)" + ) emit("aten::detach : (Tensor) -> (Tensor)", has_folder=True) emit("aten::device.with_index : (str, int) -> (Device)", has_canonicalizer=True) emit("aten::cuda : (Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)") - emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)") - emit("aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)") + emit( + "aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)" + ) + emit( + "aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)" + ) emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)") + emit( + "aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)" + ) emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") - emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)") + emit( + "aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)" + ) emit("aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::expand : (Tensor, int[], bool) -> (Tensor)") emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)") @@ -644,7 +791,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)") emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)") emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True) - emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)") + emit_with_mutating_variants( + "aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)" + ) emit("aten::item : (Tensor) -> (Scalar)", has_folder=True) emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)") emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True) @@ -670,9 +819,18 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)") emit("aten::amin : (Tensor, int[], bool) -> (Tensor)") - emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True) - emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True, has_canonicalizer = True) - emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True) + emit( + "aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True + ) + emit( + "aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", + has_folder=True, + has_canonicalizer=True, + ) + emit( + "aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", + has_canonicalizer=True, + ) emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)") emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)") emit("aten::_cast_Float : (Tensor, bool) -> (Tensor)", has_canonicalizer=True) @@ -681,33 +839,64 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True) emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)") emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True) - emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) - emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)", has_folder=True) - emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True) + emit( + "aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", + has_canonicalizer=True, + has_folder=True, + ) + emit( + "aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)", + has_folder=True, + ) + emit( + "aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True + ) emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)") - emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)", has_folder=True) + emit( + "aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)", + has_folder=True, + ) emit("aten::len.Tensor : (Tensor) -> (int)") emit("aten::cpu : (Tensor) -> (Tensor)") emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)") - emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)") - emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)") + emit_with_mutating_variants( + "aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)" + ) + emit_with_mutating_variants( + "aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)" + ) emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True) emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True) - emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True) + emit( + "aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True + ) emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True) emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True) emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)") emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)") emit("aten::t : (Tensor) -> (Tensor)") emit("aten::numpy_T : (Tensor) -> (Tensor)") - emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)", has_folder=True) - emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)") - emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)") - emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") + emit( + "aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)", + has_folder=True, + ) + emit( + "aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)" + ) + emit( + "aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)" + ) + emit_with_mutating_variants( + "aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)" + ) emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)") emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") - emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)") - emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)") + emit( + "aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)" + ) + emit( + "aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)" + ) emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True) # Functionalization ops @@ -737,7 +926,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)") emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)") emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)") - emit("aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)") + emit( + "aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)" + ) emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)") # Dict ops. @@ -750,7 +941,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::Delete.Dict_str : (Dict(str, t), str) -> ()") # List ops. - emit("aten::cat : (Tensor[], int) -> (Tensor)", has_canonicalizer=True, has_folder=True) + emit( + "aten::cat : (Tensor[], int) -> (Tensor)", + has_canonicalizer=True, + has_folder=True, + ) emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::append.t : (t[], t) -> (t[])") emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True) @@ -823,9 +1018,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__not__ : (bool) -> (bool)", has_folder=True) - emit("aten::len.t : (t[]) -> (int)", - has_folder=True, - has_canonicalizer=True) + emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True) emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True) emit("aten::_set_item.t : (t[], int, t) -> (t[])") emit("aten::mul : (Scalar, Scalar) -> (Scalar)", has_folder=True) @@ -849,12 +1042,22 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::hardtanh_backward : (Tensor, Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)") emit("aten::_log_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)") - emit("aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)") - emit("aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)") - emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)") - emit("aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)") + emit( + "aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)" + ) + emit( + "aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)" + ) + emit( + "aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)" + ) + emit( + "aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)" + ) emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)") - emit("aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)") + emit( + "aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)" + ) emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)") # quantized ops @@ -863,7 +1066,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::dequantize.self : (Tensor) -> (Tensor)") emit("aten::dequantize.tensor : (Tensor) -> (Tensor)") emit("aten::int_repr : (Tensor) -> (Tensor)") - emit("aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)") + emit( + "aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)" + ) emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)") # ========================================================================== @@ -881,10 +1086,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("prim::max.self_int : (int[]) -> (int)") emit("prim::max.int : (int, int) -> (int)", has_folder=True) emit("prim::RaiseException : (str, str?) -> ()") - emit("prim::Uninitialized : () -> (Any)", - has_canonicalizer=True, traits=["Pure"]) - emit("prim::unchecked_cast : (t) -> (t)", has_folder=True, - traits=["DeclareOpInterfaceMethods"]) + emit("prim::Uninitialized : () -> (Any)", has_canonicalizer=True, traits=["Pure"]) + emit( + "prim::unchecked_cast : (t) -> (t)", + has_folder=True, + traits=["DeclareOpInterfaceMethods"], + ) emit("prim::Print : (...) -> ()") emit("prim::tolist : (...) -> (...)") emit("prim::abs.Scalar : (Scalar) -> (Scalar)") @@ -908,13 +1115,15 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit( "quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)", - traits=["HasValueSemantics"]) + traits=["HasValueSemantics"], + ) def dump_registered_ops(outfile: TextIO, registry: Registry): for _, v in sorted(registry.by_unique_key.items()): outfile.write(repr(v)) + def _maybe_import_op_extensions(args: argparse.Namespace): extension_string = str.strip(args.pytorch_op_extensions) if len(extension_string) > 0: @@ -924,6 +1133,7 @@ def _maybe_import_op_extensions(args: argparse.Namespace): # importing these modules, so we don't need the return value. importlib.import_module(name) + def main(args: argparse.Namespace): _maybe_import_op_extensions(args) registry = Registry.load() @@ -942,15 +1152,18 @@ def _create_argparse() -> argparse.ArgumentParser: parser.add_argument( "--torch_ir_include_dir", required=True, - help="Directory in include/ containing the Torch dialect") + help="Directory in include/ containing the Torch dialect", + ) parser.add_argument( "--debug_registry_dump", - help="File to dump the the PyTorch JIT operator registry into") + help="File to dump the the PyTorch JIT operator registry into", + ) parser.add_argument( "--pytorch_op_extensions", type=str, default="", - help="An optional, comma-separated list of Python modules which register additional PyTorch operators upon being imported. These modules can be used to build a torch-mlir which supports PyTorch extensions.") + help="An optional, comma-separated list of Python modules which register additional PyTorch operators upon being imported. These modules can be used to build a torch-mlir which supports PyTorch extensions.", + ) return parser diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/utils.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/utils.py index 0e315d705..be841b198 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/utils.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/utils.py @@ -8,8 +8,10 @@ from typing import TextIO from contextlib import contextmanager import textwrap + class TextEmitter: """Helper for emitting text files""" + _INDENT = " " def __init__(self, out: TextIO): diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/torchscript_annotations.py b/projects/pt1/python/torch_mlir/jit_ir_importer/torchscript_annotations.py index a6541b650..c8f714324 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/torchscript_annotations.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/torchscript_annotations.py @@ -24,34 +24,38 @@ from torch_mlir.jit_ir_importer import ClassAnnotator # Utilities for extracting decorated information into ClassAnnotator. + def _recursively_extract_annotations( - module: torch.nn.Module, scripted: torch.jit.ScriptModule, - class_annotator: ClassAnnotator): + module: torch.nn.Module, + scripted: torch.jit.ScriptModule, + class_annotator: ClassAnnotator, +): assert module.__class__.__name__ == scripted.original_name or ( - isinstance(module, torch.jit.RecursiveScriptModule) and module is - scripted), "script module does not come from specified module" + isinstance(module, torch.jit.RecursiveScriptModule) and module is scripted + ), "script module does not come from specified module" # Extract information on methods. for method_name, scripted_method in scripted.__dict__.items(): if not isinstance(scripted_method, torch.ScriptMethod): continue method = getattr(module, method_name) - if hasattr(method, '_torch_mlir_export'): + if hasattr(method, "_torch_mlir_export"): class_annotator.exportPath(scripted._c._type(), [method_name]) - if hasattr(method, '_torch_mlir_arg_annotations'): + if hasattr(method, "_torch_mlir_arg_annotations"): class_annotator.annotateArgs( - scripted._c._type(), [method_name], - method._torch_mlir_arg_annotations) + scripted._c._type(), [method_name], method._torch_mlir_arg_annotations + ) # Recurse. for name, child in module.named_children(): scripted_child = getattr(scripted, name) - _recursively_extract_annotations(child, scripted_child, - class_annotator) + _recursively_extract_annotations(child, scripted_child, class_annotator) -def extract_annotations(program: torch.nn.Module, - scripted: torch.jit.ScriptModule, - class_annotator: ClassAnnotator): +def extract_annotations( + program: torch.nn.Module, + scripted: torch.jit.ScriptModule, + class_annotator: ClassAnnotator, +): """Populate the ClassAnnotator with annotations extracted from `program`.""" class_annotator.exportNone(scripted._c._type()) _recursively_extract_annotations(program, scripted, class_annotator) diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py index 508297cfe..ef224776f 100644 --- a/projects/pt1/python/torch_mlir/torchscript.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -20,7 +20,7 @@ from torch.fx.experimental.proxy_tensor import make_fx from torch_mlir.compiler_utils import ( run_pipeline_with_repro_report, OutputType, - lower_mlir_module + lower_mlir_module, ) from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library @@ -105,8 +105,7 @@ class ExampleArgs: self, for chaining. """ assert method_name not in self._example_args - self._example_args[method_name] = ExampleArgs._canonicalize_args( - example_args) + self._example_args[method_name] = ExampleArgs._canonicalize_args(example_args) return self @staticmethod @@ -129,10 +128,12 @@ class ExampleArgs: example_args = [example_args] for arg in example_args: if not isinstance(arg, (TensorPlaceholder, torch.Tensor)): - raise Exception(f"Only Tensor's, TensorPlaceholder's, or sequences of " - f"Tensor's and TensorPlaceholder's are supported as " - f"example args for method inputs. " - f"Got '{arg}'.") + raise Exception( + f"Only Tensor's, TensorPlaceholder's, or sequences of " + f"Tensor's and TensorPlaceholder's are supported as " + f"example args for method inputs. " + f"Got '{arg}'." + ) return tuple(example_args) def _get_methods(self): @@ -171,7 +172,8 @@ class ExampleArgs: # "hopefully the trace works for different inputs" # bucket of concerns. raise Exception( - "TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`") + "TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`" + ) # For any dynamic dimensions, replace them with "7" # arbitrarily. If a user is using dynamic dimensions with # tracing, they are walking on thin ice already -- assume @@ -182,7 +184,8 @@ class ExampleArgs: example_args_for_trace.append(torch.tensor(1)) else: example_args_for_trace.append( - torch.ones(*shape, dtype=arg.dtype)) + torch.ones(*shape, dtype=arg.dtype) + ) else: assert isinstance(arg, torch.Tensor) example_args_for_trace.append(arg) @@ -198,21 +201,33 @@ class ExampleArgs: # ops in the backend contract, and move these lists somewhere deeper in the # compiler where each backend can "own" its set of legal ops. BACKEND_LEGAL_OPS = { - OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'], - OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d','aten.adaptive_avg_pool2d', 'aten.unflatten.int'], + OutputType.TOSA: [ + "aten.flatten.using_ints", + "aten.native_layer_norm", + "aten.linear", + ], + OutputType.LINALG_ON_TENSORS: [ + "aten.flatten.using_ints", + "aten.adaptive_avg_pool1d", + "aten.adaptive_avg_pool2d", + "aten.unflatten.int", + ], OutputType.STABLEHLO: [], } -def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra_library.mlir"): +def _canon_extra_library( + extra_library, extra_library_file_name="custom_op_extra_library.mlir" +): if len(extra_library) != 0: extra_library_dict = {} for library_func in extra_library: extra_library_dict[library_func.__name__] = library_func mlir_library = generate_library(extra_library_dict) - extra_library_file = \ - os.path.join(tempfile.gettempdir(), extra_library_file_name) + extra_library_file = os.path.join( + tempfile.gettempdir(), extra_library_file_name + ) with open(extra_library_file, "w") as f: f.write(mlir_library) return extra_library_file @@ -220,16 +235,18 @@ def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra return "" -def compile(model: torch.nn.Module, - example_args: _example_args, - output_type: Union[str, "OutputType"] = OutputType.TORCH, - use_tracing: bool = False, - ignore_traced_shapes=False, - backend_legal_ops: Optional[Sequence[str]] = None, - extra_library: Iterable[Callable] = [], - verbose: bool = False, - use_make_fx: bool = False, - enable_ir_printing: bool = False): +def compile( + model: torch.nn.Module, + example_args: _example_args, + output_type: Union[str, "OutputType"] = OutputType.TORCH, + use_tracing: bool = False, + ignore_traced_shapes=False, + backend_legal_ops: Optional[Sequence[str]] = None, + extra_library: Iterable[Callable] = [], + verbose: bool = False, + use_make_fx: bool = False, + enable_ir_printing: bool = False, +): """Convert a PyTorch model to MLIR. Args: @@ -283,18 +300,18 @@ def compile(model: torch.nn.Module, # See `BACKEND_LEGAL_OPS` for more details. if backend_legal_ops is not None: if output_type != OutputType.TORCH: - raise Exception("`backend_legal_ops` is only valid with the " - "`torch` output type") + raise Exception( + "`backend_legal_ops` is only valid with the " "`torch` output type" + ) backend_legal_ops = list(sorted(set(backend_legal_ops))) else: backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, []) if use_make_fx: - args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)["forward"] - model = make_fx( - model, - decomposition_table=_get_decomposition_table())(*args) - + args = example_args._get_for_tracing( + use_tracing=True, ignore_traced_shapes=True + )["forward"] + model = make_fx(model, decomposition_table=_get_decomposition_table())(*args) # For FX-based models, automatically strip overloads. if isinstance(model, torch.fx.GraphModule): @@ -317,12 +334,12 @@ def compile(model: torch.nn.Module, raise Exception( f"Model does not have exported method '{method_name}', " f"requested in `example_args`. Consider adding " - f"`@torch.jit.export` to the method definition.") + f"`@torch.jit.export` to the method definition." + ) scripted = model elif use_tracing: scripted = torch.jit.trace_module( - model, - example_args._get_for_tracing(use_tracing, ignore_traced_shapes) + model, example_args._get_for_tracing(use_tracing, ignore_traced_shapes) ) else: # Make sure that all the methods that the user requested get scripted. @@ -338,8 +355,7 @@ def compile(model: torch.nn.Module, annotation = [None] # `None` is always the annotation for "self". for arg in example_args: annotation.append((arg.shape, arg.dtype, True)) - class_annotator.annotateArgs( - scripted._c._type(), [method_name], annotation) + class_annotator.annotateArgs(scripted._c._type(), [method_name], annotation) mb = ModuleBuilder() import_options = ImportOptions() @@ -350,20 +366,27 @@ def compile(model: torch.nn.Module, # Import the TorchScript module to MLIR mb.import_module(scripted._c, class_annotator, import_options) except Exception as e: - raise Exception(f""" + raise Exception( + f""" PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: ### Importer C++ Exception: {e} ### Importer Diagnostics: {sys.stderr.getvalue()} -""") from None +""" + ) from None finally: sys.stderr = original_stderr if output_type == OutputType.RAW: return mb.module - option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \ - " extra-library=" + extra_library_file_name + "}" + option_string = ( + "{backend-legal-ops=" + + ",".join(backend_legal_ops) + + " extra-library=" + + extra_library_file_name + + "}" + ) run_pipeline_with_repro_report( mb.module, f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})", diff --git a/projects/pt1/python/torch_mlir_e2e_test/annotations.py b/projects/pt1/python/torch_mlir_e2e_test/annotations.py index e34b0f85d..34eed4e03 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/annotations.py +++ b/projects/pt1/python/torch_mlir_e2e_test/annotations.py @@ -22,8 +22,8 @@ import torch # Attribute names used for annotations. # These should be kept in sync with their use in # `torch_mlir/torchscript_annotations.py`. -TORCH_MLIR_EXPORT_ATTR_NAME = '_torch_mlir_export' -TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME = '_torch_mlir_arg_annotations' +TORCH_MLIR_EXPORT_ATTR_NAME = "_torch_mlir_export" +TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME = "_torch_mlir_arg_annotations" def export(fn): diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py index e45c7b18b..204ddf616 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py @@ -40,12 +40,12 @@ def refine_result_type(_result): def jit( - prog: ExportedProgram, - func_name: str, - output_type: Union[str, "OutputType"] = OutputType.TORCH, - backend_legal_ops: Optional[Sequence[str]] = None, - extra_library=None, - verbose: bool = False, + prog: ExportedProgram, + func_name: str, + output_type: Union[str, "OutputType"] = OutputType.TORCH, + backend_legal_ops: Optional[Sequence[str]] = None, + extra_library=None, + verbose: bool = False, ): if extra_library is None: extra_library = [] @@ -55,14 +55,20 @@ def jit( output_type = OutputType.get(output_type) if backend_legal_ops is not None: if output_type != OutputType.TORCH: - raise Exception("`backend_legal_ops` is only valid with the " - "`torch` output type") + raise Exception( + "`backend_legal_ops` is only valid with the " "`torch` output type" + ) backend_legal_ops = list(sorted(set(backend_legal_ops))) else: backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, []) - option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) + - " extra-library=" + extra_library_file_name + "}") + option_string = ( + "{backend-legal-ops=" + + ",".join(backend_legal_ops) + + " extra-library=" + + extra_library_file_name + + "}" + ) mlir_module = fx.export_and_import(prog, func_name=func_name) assert mlir_module is not None @@ -95,9 +101,11 @@ class FxImporterTestConfig(TestConfig): result: Trace = [] for item in trace: prog = torch.export.export(artifact, tuple(item.inputs)) - module = jit(prog, - func_name=artifact.__class__.__name__, - output_type=self._output_type) + module = jit( + prog, + func_name=artifact.__class__.__name__, + output_type=self._output_type, + ) module = self._backend.compile(module) backend_module = self._backend.load(module) params = { @@ -107,10 +115,10 @@ class FxImporterTestConfig(TestConfig): params_flat, params_spec = pytree.tree_flatten(params) params_flat = list(params_flat) with torch.no_grad(): - numpy_inputs = recursively_convert_to_numpy(params_flat + - item.inputs) - outputs = getattr(backend_module, - artifact.__class__.__name__)(*numpy_inputs) + numpy_inputs = recursively_convert_to_numpy(params_flat + item.inputs) + outputs = getattr(backend_module, artifact.__class__.__name__)( + *numpy_inputs + ) output = refine_result_type(outputs) if isinstance(output, (tuple, list)): user_output = [] @@ -120,7 +128,6 @@ class FxImporterTestConfig(TestConfig): user_output.append(val) output = tuple(user_output) result.append( - TraceItem(symbol=item.symbol, - inputs=item.inputs, - output=output)) + TraceItem(symbol=item.symbol, inputs=item.inputs, output=output) + ) return result diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/lazy_tensor_core.py b/projects/pt1/python/torch_mlir_e2e_test/configs/lazy_tensor_core.py index 29842ccfc..4f2d9ec90 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/lazy_tensor_core.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/lazy_tensor_core.py @@ -23,20 +23,19 @@ class LazyTensorCoreTestConfig(TestConfig): lazy_backend._initialize() def compile(self, program: torch.nn.Module) -> torch.nn.Module: - return program.to('lazy') + return program.to("lazy") def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: result: Trace = [] for item in trace: # We need to move all the inputs to the lazy device before running in LTC. - lazy_inputs = tree_map(to_device('lazy'), item.inputs) + lazy_inputs = tree_map(to_device("lazy"), item.inputs) output = getattr(artifact, item.symbol)(*lazy_inputs) - cpu_outputs = tree_map(to_device('cpu'), output) + cpu_outputs = tree_map(to_device("cpu"), output) result.append( - TraceItem(symbol=item.symbol, - inputs=item.inputs, - output=cpu_outputs)) + TraceItem(symbol=item.symbol, inputs=item.inputs, output=cpu_outputs) + ) return result diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py index 8c99278b0..bbc6e73ee 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py @@ -24,6 +24,7 @@ class LinalgOnTensorsBackendTestConfig(TestConfig): This class handles all the common lowering that torch-mlir does before reaching the linalg-on-tensors abstraction level. """ + def __init__(self, backend: LinalgOnTensorsBackend): super().__init__() self.backend = backend @@ -31,12 +32,11 @@ class LinalgOnTensorsBackendTestConfig(TestConfig): def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) module = torchscript.compile( - program, example_args, output_type="linalg-on-tensors") + program, example_args, output_type="linalg-on-tensors" + ) return self.backend.compile(module) - - def run(self, artifact: Any, trace: Trace) -> Trace: backend_module = self.backend.load(artifact) result: Trace = [] @@ -45,7 +45,6 @@ class LinalgOnTensorsBackendTestConfig(TestConfig): outputs = getattr(backend_module, item.symbol)(*numpy_inputs) output = recursively_convert_from_numpy(outputs) result.append( - TraceItem(symbol=item.symbol, - inputs=item.inputs, - output=output)) + TraceItem(symbol=item.symbol, inputs=item.inputs, output=output) + ) return result diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/native_torch.py b/projects/pt1/python/torch_mlir_e2e_test/configs/native_torch.py index 76c349bc3..e7907cd14 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/native_torch.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/native_torch.py @@ -10,6 +10,7 @@ from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem class NativeTorchTestConfig(TestConfig): """TestConfig that just runs the torch.nn.Module without compiling""" + def __init__(self): super().__init__() @@ -23,7 +24,6 @@ class NativeTorchTestConfig(TestConfig): for item in trace: output = getattr(artifact, item.symbol)(*item.inputs) result.append( - TraceItem(symbol=item.symbol, - inputs=item.inputs, - output=output)) + TraceItem(symbol=item.symbol, inputs=item.inputs, output=output) + ) return result diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py index e411a7cbb..6fa845ab3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -47,7 +47,7 @@ def convert_onnx(model, inputs): input_names = [] dynamic_tensors = {} for (index, arg) in enumerate(inputs): - shape = map(lambda d : d if d >= 0 else 1, arg.shape) + shape = map(lambda d: d if d >= 0 else 1, arg.shape) shape = tuple(shape) examples.append(torch.zeros(size=shape, dtype=arg.dtype)) @@ -56,24 +56,27 @@ def convert_onnx(model, inputs): dynamic_dims = {} for (dimindex, dim) in enumerate(arg.shape): - if (dim < 0): + if dim < 0: dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex) - if (dynamic_dims): + if dynamic_dims: dynamic_tensors[input_name] = dynamic_dims - - examples=tuple(examples) - torch.onnx.export(model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors) + examples = tuple(examples) + torch.onnx.export( + model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors + ) buffer = buffer.getvalue() return import_onnx(buffer) + class OnnxBackendTestConfig(TestConfig): """Base class for TestConfig's that are implemented with ONNX. This class handles all the common lowering that torch-mlir does before reaching the ONNX abstraction level. """ + def __init__(self, backend: OnnxBackend, use_make_fx: bool = False): super().__init__() self.backend = backend @@ -85,8 +88,6 @@ class OnnxBackendTestConfig(TestConfig): compiled_module = self.backend.compile(onnx_module) return compiled_module - - def run(self, artifact: Any, trace: Trace) -> Trace: backend_module = self.backend.load(artifact) result: Trace = [] @@ -95,7 +96,6 @@ class OnnxBackendTestConfig(TestConfig): outputs = getattr(backend_module, "main_graph")(*numpy_inputs) output = recursively_convert_from_numpy(outputs) result.append( - TraceItem(symbol=item.symbol, - inputs=item.inputs, - output=output)) + TraceItem(symbol=item.symbol, inputs=item.inputs, output=output) + ) return result diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py index 13f4d3df8..54dc7d3f9 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -48,17 +48,17 @@ def refine_result_type(_result): def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool: for node in fx_graph.graph.nodes: if node.op == "output": - assert len( - node.args) == 1, "Output node must have a single argument" + assert len(node.args) == 1, "Output node must have a single argument" node_arg = node.args[0] if node_arg != (): return False return True -# Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor to + +# Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor to # torch.aten.add.Scalar/torch.aten.mul.Scalar in case of Scalar argument -# Cannot be done on earlier stage, e.g. in _FXGraphImporter as it -# needs to check argument types, which are not yet determined. +# Cannot be done on earlier stage, e.g. in _FXGraphImporter as it +# needs to check argument types, which are not yet determined. # Maybe schema or target should be changed, but it decided in # _dynamo eval_frame on pytorch side. Also Python schema not matches # with mlir Schema - check include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -69,7 +69,7 @@ def scalarize_tensor_ops_on_scalars(gm: torch.fx.GraphModule): for node in gm.graph.nodes: # Checks if we're calling a function (i.e: # torch.add) - if node.op == 'call_function': + if node.op == "call_function": # The target attribute is the function # that call_function calls. # call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {}) @@ -84,7 +84,7 @@ def scalarize_tensor_ops_on_scalars(gm: torch.fx.GraphModule): elif not isinstance(node.args[1], torch.fx.node.Node): node.target = torch.ops.aten.mul.Scalar - gm.graph.lint() # Does some checks to make sure the + gm.graph.lint() # Does some checks to make sure the # Recompile the forward() method of `gm` from its Graph gm.recompile() @@ -108,20 +108,24 @@ def jit( output_type = OutputType.get(output_type) if backend_legal_ops is not None: if output_type != OutputType.TORCH: - raise Exception("`backend_legal_ops` is only valid with the " - "`torch` output type") + raise Exception( + "`backend_legal_ops` is only valid with the " "`torch` output type" + ) backend_legal_ops = list(sorted(set(backend_legal_ops))) else: backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, []) @make_boxed_compiler - def my_aot_autograd_backend(gm: torch.fx.GraphModule, - _example_inputs: List[torch.Tensor]): + def my_aot_autograd_backend( + gm: torch.fx.GraphModule, _example_inputs: List[torch.Tensor] + ): # Torch-MLIR does not support returning an empty tuple. The reason is # that both returning an empty tuple and returning `None` results in MLIR # functions that have as a return type `()`. In other words, there is no # way of differentiating between the two. - assert not _returns_empty_tuple(gm), "encountered graph that does not return anything" + assert not _returns_empty_tuple( + gm + ), "encountered graph that does not return anything" scalarize_tensor_ops_on_scalars(gm) @@ -130,18 +134,24 @@ def jit( mlir_module = import_fx_graph_as_func(gm.graph, model_name) return gm - my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend, - decompositions=_get_decomposition_table) + my_backend = aot_autograd( + fw_compiler=my_aot_autograd_backend, decompositions=_get_decomposition_table + ) with torch.no_grad(): set_model_name(model.__class__.__name__) torch._dynamo.reset() dynamo_f = dynamo.optimize(my_backend, nopython=True)( - lambda method, *inputs: method(*inputs)) - dynamo_f(lambda *inputs: model(*[x.clone() for x in inputs]), - *example_args) - option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) + - " extra-library=" + extra_library_file_name + "}") + lambda method, *inputs: method(*inputs) + ) + dynamo_f(lambda *inputs: model(*[x.clone() for x in inputs]), *example_args) + option_string = ( + "{backend-legal-ops=" + + ",".join(backend_legal_ops) + + " extra-library=" + + extra_library_file_name + + "}" + ) assert mlir_module is not None run_pipeline_with_repro_report( mlir_module, @@ -166,9 +176,7 @@ class TorchDynamoTestConfig(TestConfig): def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: result: Trace = [] for item in trace: - module = jit(artifact, - item.inputs, - output_type="linalg-on-tensors") + module = jit(artifact, item.inputs, output_type="linalg-on-tensors") module = self.backend.compile(module) backend_module = self.backend.load(module) params = { @@ -178,13 +186,12 @@ class TorchDynamoTestConfig(TestConfig): params_flat, params_spec = pytree.tree_flatten(params) params_flat = list(params_flat) with torch.no_grad(): - numpy_inputs = recursively_convert_to_numpy(params_flat + - item.inputs) - outputs = getattr(backend_module, - artifact.__class__.__name__)(*numpy_inputs) + numpy_inputs = recursively_convert_to_numpy(params_flat + item.inputs) + outputs = getattr(backend_module, artifact.__class__.__name__)( + *numpy_inputs + ) output = refine_result_type(outputs) result.append( - TraceItem(symbol=item.symbol, - inputs=item.inputs, - output=output)) + TraceItem(symbol=item.symbol, inputs=item.inputs, output=output) + ) return result diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchscript.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchscript.py index 9d105557c..a40e06f01 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchscript.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchscript.py @@ -13,6 +13,7 @@ from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem class TorchScriptTestConfig(TestConfig): """TestConfig that runs the torch.nn.Module through TorchScript""" + def __init__(self): super().__init__() @@ -26,11 +27,10 @@ class TorchScriptTestConfig(TestConfig): result: Trace = [] for item in trace: attr = artifact - for part in item.symbol.split('.'): + for part in item.symbol.split("."): attr = getattr(attr, part) output = attr(*item.inputs) result.append( - TraceItem(symbol=item.symbol, - inputs=item.inputs, - output=output)) + TraceItem(symbol=item.symbol, inputs=item.inputs, output=output) + ) return result diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py index 8aa2d0e63..1b5c86bb6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py @@ -23,6 +23,7 @@ class TosaBackendTestConfig(TestConfig): This class handles all the common lowering that torch-mlir does before reaching the TOSA abstraction level. """ + def __init__(self, backend: TosaBackend, use_make_fx: bool = False): super().__init__() self.backend = backend @@ -31,12 +32,11 @@ class TosaBackendTestConfig(TestConfig): def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) module = torchscript.compile( - program, example_args, output_type="tosa", use_make_fx=self.use_make_fx) + program, example_args, output_type="tosa", use_make_fx=self.use_make_fx + ) return self.backend.compile(module) - - def run(self, artifact: Any, trace: Trace) -> Trace: backend_module = self.backend.load(artifact) result: Trace = [] @@ -45,7 +45,6 @@ class TosaBackendTestConfig(TestConfig): outputs = getattr(backend_module, item.symbol)(*numpy_inputs) output = recursively_convert_from_numpy(outputs) result.append( - TraceItem(symbol=item.symbol, - inputs=item.inputs, - output=output)) + TraceItem(symbol=item.symbol, inputs=item.inputs, output=output) + ) return result diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/utils.py b/projects/pt1/python/torch_mlir_e2e_test/configs/utils.py index c8f912f43..02cf5cca3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/utils.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/utils.py @@ -27,6 +27,7 @@ def recursively_convert_to_numpy(o: Any): return o raise Exception(f"Unexpected Python function input: {o}") + def recursively_convert_from_numpy(o: Any): if isinstance(o, np.ndarray): return torch.from_numpy(o) diff --git a/projects/pt1/python/torch_mlir_e2e_test/debug/lockstep.py b/projects/pt1/python/torch_mlir_e2e_test/debug/lockstep.py index ab42727d4..f560ceb59 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/debug/lockstep.py +++ b/projects/pt1/python/torch_mlir_e2e_test/debug/lockstep.py @@ -24,8 +24,7 @@ def _make_single_op_gm(node) -> torch.fx.GraphModule: return torch.fx.GraphModule(torch.nn.Module(), g) -def _identity_backend(gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor]): +def _identity_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): """A backend that just runs the given GraphModule as-is.""" return gm @@ -51,6 +50,7 @@ def _make_last_use_map(g: torch.fx.Graph) -> Dict[torch.fx.Node, List[torch.fx.N # Lifetime just ended, so this is the last use. seen.add(use) last_use_map[user].append(use) + for node in reversed(g.nodes): assert not node.kwargs, "kwargs not supported yet" torch.fx.map_arg(node.args, lambda n: process_use(node, n)) @@ -84,9 +84,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend): Returns: A backend that compares the wrapped backend to `golden_backend`. """ + def wrapper(user_backend): - def backend(gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor]): + def backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): # We can ignore the example_inputs since we recompile in lockstep # anyway. TorchDynamo should already have appropriate guards in # place so that this doesn't change the compilation result. @@ -96,7 +96,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend): def compiled(*args): env = {} - for placeholder, arg in zip([n for n in g.nodes if n.op == "placeholder"], args): + for placeholder, arg in zip( + [n for n in g.nodes if n.op == "placeholder"], args + ): env[placeholder] = arg # Evaluate the graph one node at a time, comparing the user and # golden backends. This code currently does not support @@ -111,7 +113,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend): continue if node.op == "output": return torch.fx.map_arg(node.args[0], lambda n: env[n]) - assert node.op == "call_function", f"call_module/call_method not supported for {node} -- perhaps call make_simple_dynamo_backend first" + assert ( + node.op == "call_function" + ), f"call_module/call_method not supported for {node} -- perhaps call make_simple_dynamo_backend first" assert not node.kwargs, "kwargs not supported yet" actual_args = torch.fx.map_arg(node.args, lambda n: env[n]) if node not in backend_artifacts: @@ -128,7 +132,8 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend): assert torch.allclose(user_result, golden_result), ( f"User result {user_result} is not close to " f"golden result {golden_result} for " - f"node {node} at {node.stack_trace}") + f"node {node} at {node.stack_trace}" + ) # Clean up any tensors that are no longer needed. # TODO: Find a way to test this. # This was tested manually by printing the number of entries @@ -136,6 +141,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend): for dead_node in last_use_map[node]: env.pop(dead_node) assert False, "not reached -- missing 'output' node" + return compiled + return backend + return wrapper diff --git a/projects/pt1/python/torch_mlir_e2e_test/framework.py b/projects/pt1/python/torch_mlir_e2e_test/framework.py index d3fecf54d..ee438cbbb 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/framework.py +++ b/projects/pt1/python/torch_mlir_e2e_test/framework.py @@ -30,6 +30,7 @@ import traceback import multiprocess as mp from multiprocess import set_start_method + try: set_start_method("spawn") except RuntimeError: @@ -38,9 +39,13 @@ except RuntimeError: import torch -TorchScriptValue = Union[int, float, List['TorchScriptValue'], - Dict['TorchScriptValue', - 'TorchScriptValue'], torch.Tensor] +TorchScriptValue = Union[ + int, + float, + List["TorchScriptValue"], + Dict["TorchScriptValue", "TorchScriptValue"], + torch.Tensor, +] class TraceItem(NamedTuple): @@ -91,9 +96,11 @@ def clone_torch_script_value(v: TorchScriptValue): # TODO: Figure out the root cause of the failure and fix properly. def clone_trace(trace: Trace) -> Trace: return [ - TraceItem(symbol=item.symbol, - inputs=clone_torch_script_value(item.inputs), - output=clone_torch_script_value(item.output)) + TraceItem( + symbol=item.symbol, + inputs=clone_torch_script_value(item.inputs), + output=clone_torch_script_value(item.output), + ) for item in trace ] @@ -101,7 +108,8 @@ def clone_trace(trace: Trace) -> Trace: # A type shared between the result of `TestConfig.compile` and the input # to `TestConfig.run`. Each backend will likely have a different definition of # this type. -CompiledArtifact = TypeVar('CompiledArtifact') +CompiledArtifact = TypeVar("CompiledArtifact") + class TestConfig(abc.ABC): """The interface implemented by backends to run tests. @@ -136,6 +144,7 @@ class TestConfig(abc.ABC): backend (compiler backend and runtime target) will have an arbitrarily wild and wonderful set of possible configurations that we cannot predict. """ + # This is not a frontend-lowered module, to allow various testing at the PyTorch level. # We can have a helper class LinalgOnTensorsBackendTestConfig which does that. @abc.abstractmethod @@ -202,8 +211,8 @@ class TestUtils: class Test(NamedTuple): - """A description of a test as produced by the test frontend. - """ + """A description of a test as produced by the test frontend.""" + # Stable name for error reporting. # # This name's stability is also useful for backend, which want to @@ -268,14 +277,20 @@ class _Tracer: inputs = [clone_torch_script_value(arg) for arg in args] output = self.__wrapped__(*args, **kwargs) self.__trace__.append( - TraceItem(symbol=".".join(self.__property_base_path__), - inputs=inputs, - output=output)) + TraceItem( + symbol=".".join(self.__property_base_path__), + inputs=inputs, + output=output, + ) + ) return output def __getattr__(self, name): - return _Tracer(getattr(self.__wrapped__, name), - self.__property_base_path__ + [name], self.__trace__) + return _Tracer( + getattr(self.__wrapped__, name), + self.__property_base_path__ + [name], + self.__trace__, + ) def generate_golden_trace(test: Test) -> Trace: @@ -297,40 +312,49 @@ def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any: print(f"Compiling {test.unique_name}...", file=sys.stderr) compiled = config.compile(test.program_factory()) except Exception as e: - return TestResult(unique_name=test.unique_name, - compilation_error="".join( - traceback.format_exception( - type(e), e, e.__traceback__)), - runtime_error=None, - trace=None, - golden_trace=None) + return TestResult( + unique_name=test.unique_name, + compilation_error="".join( + traceback.format_exception(type(e), e, e.__traceback__) + ), + runtime_error=None, + trace=None, + golden_trace=None, + ) try: if verbose: print(f"Running {test.unique_name}...", file=sys.stderr) trace = config.run(compiled, golden_trace) except Exception as e: - return TestResult(unique_name=test.unique_name, - compilation_error=None, - runtime_error="".join( - traceback.format_exception( - type(e), e, e.__traceback__)), - trace=None, - golden_trace=None) - return TestResult(unique_name=test.unique_name, - compilation_error=None, - runtime_error=None, - trace=clone_trace(trace), - golden_trace=clone_trace(golden_trace)) + return TestResult( + unique_name=test.unique_name, + compilation_error=None, + runtime_error="".join( + traceback.format_exception(type(e), e, e.__traceback__) + ), + trace=None, + golden_trace=None, + ) + return TestResult( + unique_name=test.unique_name, + compilation_error=None, + runtime_error=None, + trace=clone_trace(trace), + golden_trace=clone_trace(golden_trace), + ) -def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=False) -> List[TestResult]: +def run_tests( + tests: List[Test], config: TestConfig, sequential=False, verbose=False +) -> List[TestResult]: """Invoke the given `Test`'s with the provided `TestConfig`.""" num_processes = min(int(mp.cpu_count() * 0.8) + 1, len(tests)) try: env_concurrency = int(os.getenv("TORCH_MLIR_TEST_CONCURRENCY", "0")) except ValueError as e: - raise ValueError("Bad value for TORCH_MLIR_TEST_CONCURRENCY env var: " - "Expected integer.") from e + raise ValueError( + "Bad value for TORCH_MLIR_TEST_CONCURRENCY env var: " "Expected integer." + ) from e if env_concurrency > 0: num_processes = min(num_processes, env_concurrency) @@ -374,10 +398,11 @@ def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=F TestResult( unique_name=aborted_test_name, compilation_error=None, - runtime_error= - "Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n", + runtime_error="Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n", trace=None, - golden_trace=None) for aborted_test_name in aborted_tests + golden_trace=None, + ) + for aborted_test_name in aborted_tests ] results.extend(aborted_tests_results) results.sort(key=lambda result: result.unique_name) diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/abc.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/abc.py index fb1120283..a58b437c6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/abc.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/abc.py @@ -13,12 +13,12 @@ from torch_mlir.ir import Module # A type shared between the result of `LinalgOnTensorsBackend.compile` and the # input to `LinalgOnTensorsBackend.load`. Each backend will likely have a # different definition of this type. -CompiledArtifact = TypeVar('CompiledArtifact') +CompiledArtifact = TypeVar("CompiledArtifact") # A wrapper around a backend-specific loaded program representation # that uniformly translates the `x.method(...)` interface expected of # Torch modules into appropriate lower-level operations. -Invoker = TypeVar('Invoker') +Invoker = TypeVar("Invoker") class LinalgOnTensorsBackend(abc.ABC): @@ -27,6 +27,7 @@ class LinalgOnTensorsBackend(abc.ABC): Backends are recommended to raise meaningful exceptions in case of error, ideally with easy reproduction instructions. """ + @abc.abstractmethod def compile(self, module: Module) -> CompiledArtifact: """Compile the provided MLIR module into a compiled artifact. diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index cf5a276cd..a1611a1e5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -22,10 +22,20 @@ __all__ = [ def assert_arg_type_is_supported(ty): SUPPORTED = [ - np.float16, np.float32, np.float64, np.uint8, np.int8, np.int32, - np.int64, np.bool_, np.complex64, np.complex128 + np.float16, + np.float32, + np.float64, + np.uint8, + np.int8, + np.int32, + np.int64, + np.bool_, + np.complex64, + np.complex128, ] - assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}" + assert ( + ty in SUPPORTED + ), f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}" memref_type_to_np_dtype = { @@ -37,14 +47,14 @@ memref_type_to_np_dtype = { "mri32": np.int32, "mri64": np.int64, "mrc32": np.complex64, - "mrc64": np.complex128 + "mrc64": np.complex128, } elemental_type_to_ctype = { "i1": ctypes.c_bool, "i8": ctypes.c_byte, "i64": ctypes.c_int, "f32": ctypes.c_float, - "f64": ctypes.c_double + "f64": ctypes.c_double, } CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_" @@ -56,7 +66,7 @@ def get_return_funcs(module): with module.context: for func in module.body: # Returns strings of the form `"refbackend.."` so `"` is deleted. - func_name = str(func.attributes["sym_name"]).replace('"', '') + func_name = str(func.attributes["sym_name"]).replace('"', "") if func_name[:return_prefix_len] == CONSUME_RETURN_FUNC_PREFIX: return_funcs.append(func_name) @@ -79,7 +89,6 @@ def get_ctype_func(func_name): class RefBackendInvoker: - def __init__(self, module): self.ee = ExecutionEngine(module) self.result = None @@ -90,27 +99,29 @@ class RefBackendInvoker: ctype_wrapper, ret_types = get_ctype_func(ret_func) def consume_return_funcs(*args): - self.result = tuple([ - arg if type in elemental_type_to_ctype - else unranked_memref_to_numpy( - arg, memref_type_to_np_dtype[type]) - for arg, type in zip(args, ret_types) - ]) + self.result = tuple( + [ + arg + if type in elemental_type_to_ctype + else unranked_memref_to_numpy( + arg, memref_type_to_np_dtype[type] + ) + for arg, type in zip(args, ret_types) + ] + ) if len(self.result) == 1: self.result = self.result[0] - self.ee.register_runtime(ret_func, - ctype_wrapper(consume_return_funcs)) + self.ee.register_runtime(ret_func, ctype_wrapper(consume_return_funcs)) def __getattr__(self, function_name: str): - def invoke(*args): ffi_args = [] for arg in args: assert_arg_type_is_supported(arg.dtype) ffi_args.append( - ctypes.pointer( - ctypes.pointer(get_unranked_memref_descriptor(arg)))) + ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(arg))) + ) self.ee.invoke(function_name, *ffi_args) result = self.result @@ -121,67 +132,73 @@ class RefBackendInvoker: return invoke -LOWERING_PIPELINE = "builtin.module(" + ",".join([ - "func.func(refback-generalize-tensor-pad)", - "func.func(refback-generalize-tensor-concat)", - # Apply some optimizations. It would be great if MLIR had more useful - # optimizations that worked out of the box here. - # Note: When measured, this doesn't seem to actually help that much - # for the linalg-on-tensors backend. - # This is likely because if things are naturally fusable we usually already - # emit things in that form from the high level (e.g. single linalg-generic). - # Other backends are likely to benefit more. - "func.func(linalg-generalize-named-ops)", - "func.func(linalg-fuse-elementwise-ops)", - "convert-shape-to-std", - # MLIR Sparsifier mini-pipeline. Note that this is the bare minimum - # to ensure operations on sparse tensors are lowered to loops. - "sparse-assembler{direct-out}", - "sparsification-and-bufferization", - "sparse-storage-specifier-to-llvm", - "inline", # inline sparse helper methods where useful - # Bufferize. - "func.func(scf-bufferize)", - "func.func(tm-tensor-bufferize)", - "func.func(empty-tensor-to-alloc-tensor)", - "func.func(linalg-bufferize)", - "func-bufferize", - "arith-bufferize", - "refback-mlprogram-bufferize", - "func.func(tensor-bufferize)", - "func.func(finalizing-bufferize)", - "func.func(buffer-deallocation)", - # Munge to make it ExecutionEngine compatible. - # Specifically, we rewrite calling convention boundaries to be in terms - # of unranked memref, and we rewrite the return to actually be a - # callback that consumes the return (the final munged function always - # returns void at the C level -- we get the return value by providing the - # callback). - "refback-munge-calling-conventions", - # Insert global variable and instruction sequence for getting the next - # global seed used in stateful rng. - # Lower to LLVM - "func.func(tm-tensor-to-loops)", - "func.func(refback-munge-memref-copy)", - "func.func(convert-linalg-to-loops)", - "func.func(lower-affine)", - "convert-scf-to-cf", - "func.func(refback-expand-ops-for-llvm)", - "func.func(arith-expand)", - "func.func(convert-math-to-llvm)", - # Handle some complex mlir::math ops (e.g. atan2) - "convert-math-to-libm", - "expand-strided-metadata", - "finalize-memref-to-llvm", - "lower-affine", - "convert-bufferization-to-memref", - "finalize-memref-to-llvm", - "func.func(convert-arith-to-llvm)", - "convert-func-to-llvm", - "convert-cf-to-llvm", - "convert-complex-to-llvm", - "reconcile-unrealized-casts", -]) + ")" +LOWERING_PIPELINE = ( + "builtin.module(" + + ",".join( + [ + "func.func(refback-generalize-tensor-pad)", + "func.func(refback-generalize-tensor-concat)", + # Apply some optimizations. It would be great if MLIR had more useful + # optimizations that worked out of the box here. + # Note: When measured, this doesn't seem to actually help that much + # for the linalg-on-tensors backend. + # This is likely because if things are naturally fusable we usually already + # emit things in that form from the high level (e.g. single linalg-generic). + # Other backends are likely to benefit more. + "func.func(linalg-generalize-named-ops)", + "func.func(linalg-fuse-elementwise-ops)", + "convert-shape-to-std", + # MLIR Sparsifier mini-pipeline. Note that this is the bare minimum + # to ensure operations on sparse tensors are lowered to loops. + "sparse-assembler{direct-out}", + "sparsification-and-bufferization", + "sparse-storage-specifier-to-llvm", + "inline", # inline sparse helper methods where useful + # Bufferize. + "func.func(scf-bufferize)", + "func.func(tm-tensor-bufferize)", + "func.func(empty-tensor-to-alloc-tensor)", + "func.func(linalg-bufferize)", + "func-bufferize", + "arith-bufferize", + "refback-mlprogram-bufferize", + "func.func(tensor-bufferize)", + "func.func(finalizing-bufferize)", + "func.func(buffer-deallocation)", + # Munge to make it ExecutionEngine compatible. + # Specifically, we rewrite calling convention boundaries to be in terms + # of unranked memref, and we rewrite the return to actually be a + # callback that consumes the return (the final munged function always + # returns void at the C level -- we get the return value by providing the + # callback). + "refback-munge-calling-conventions", + # Insert global variable and instruction sequence for getting the next + # global seed used in stateful rng. + # Lower to LLVM + "func.func(tm-tensor-to-loops)", + "func.func(refback-munge-memref-copy)", + "func.func(convert-linalg-to-loops)", + "func.func(lower-affine)", + "convert-scf-to-cf", + "func.func(refback-expand-ops-for-llvm)", + "func.func(arith-expand)", + "func.func(convert-math-to-llvm)", + # Handle some complex mlir::math ops (e.g. atan2) + "convert-math-to-libm", + "expand-strided-metadata", + "finalize-memref-to-llvm", + "lower-affine", + "convert-bufferization-to-memref", + "finalize-memref-to-llvm", + "func.func(convert-arith-to-llvm)", + "convert-func-to-llvm", + "convert-cf-to-llvm", + "convert-complex-to-llvm", + "reconcile-unrealized-casts", + ] + ) + + ")" +) class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend): @@ -204,7 +221,8 @@ class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend): passed to `load`. """ run_pipeline_with_repro_report( - imported_module, LOWERING_PIPELINE, + imported_module, + LOWERING_PIPELINE, "Lowering Linalg-on-Tensors IR to LLVM with RefBackend", enable_ir_printing=False, ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py index 684c08df4..7e12f8b15 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py +++ b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py @@ -13,12 +13,12 @@ from torch_mlir.ir import Module # A type shared between the result of `OnnxBackend.compile` and the # input to `OnnxBackend.load`. Each backend will likely have a # different definition of this type. -CompiledArtifact = TypeVar('CompiledArtifact') +CompiledArtifact = TypeVar("CompiledArtifact") # A wrapper around a backend-specific loaded program representation # that uniformly translates the `x.method(...)` interface expected of # Torch modules into appropriate lower-level operations. -Invoker = TypeVar('Invoker') +Invoker = TypeVar("Invoker") class OnnxBackend(abc.ABC): @@ -27,6 +27,7 @@ class OnnxBackend(abc.ABC): Backends are recommended to raise meaningful exceptions in case of error, ideally with easy reproduction instructions. """ + @abc.abstractmethod def compile(self, module: Module) -> CompiledArtifact: """Compile the provided MLIR module into a compiled artifact. diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py index fcd1efb3f..30129c751 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py @@ -12,7 +12,9 @@ from torch_mlir.compiler_utils import ( from torch_mlir.ir import * from torch_mlir.passmanager import * -from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( + RefBackendLinalgOnTensorsBackend, +) from .abc import OnnxBackend @@ -22,9 +24,11 @@ __all__ = [ # The pipeline of func.func passes that lower the ONNX backend contract to the # Linalg-on-Tensors backend contract accepted by RefBackend. -ONNX_TO_TORCH_FUNC_PIPELINE = ",".join([ - "convert-torch-onnx-to-torch", -]) +ONNX_TO_TORCH_FUNC_PIPELINE = ",".join( + [ + "convert-torch-onnx-to-torch", + ] +) class LinalgOnTensorsOnnxBackend(OnnxBackend): @@ -50,9 +54,14 @@ class LinalgOnTensorsOnnxBackend(OnnxBackend): run_pipeline_with_repro_report( imported_module, f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))", - "Lowering Onnx backend contract to Linalg-on-Tensors backend contract") + "Lowering Onnx backend contract to Linalg-on-Tensors backend contract", + ) - backend_legal_ops = ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int'] + backend_legal_ops = [ + "aten.flatten.using_ints", + "aten.adaptive_avg_pool1d", + "aten.unflatten.int", + ] option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}" run_pipeline_with_repro_report( imported_module, @@ -60,7 +69,9 @@ class LinalgOnTensorsOnnxBackend(OnnxBackend): "Lowering TorchFX IR -> Torch Backend IR", ) - imported_module = lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module) + imported_module = lower_mlir_module( + False, OutputType.LINALG_ON_TENSORS, imported_module + ) compiled_module = self.refbackend.compile(imported_module) return compiled_module diff --git a/projects/pt1/python/torch_mlir_e2e_test/registry.py b/projects/pt1/python/torch_mlir_e2e_test/registry.py index 2f6cab581..d2116bafe 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/registry.py +++ b/projects/pt1/python/torch_mlir_e2e_test/registry.py @@ -23,18 +23,23 @@ def register_test_case(module_factory: Callable[[], torch.nn.Module]): test's `program_factory` is taken from `module_factory`, and the `program_invoker` is the decorated function. """ + def decorator(f): # Ensure that there are no duplicate names in the global test registry. if f.__name__ in _SEEN_UNIQUE_NAMES: raise Exception( - f"Duplicate test name: '{f.__name__}'. Please make sure that the function wrapped by `register_test_case` has a unique name.") + f"Duplicate test name: '{f.__name__}'. Please make sure that the function wrapped by `register_test_case` has a unique name." + ) _SEEN_UNIQUE_NAMES.add(f.__name__) # Store the test in the registry. GLOBAL_TEST_REGISTRY.append( - Test(unique_name=f.__name__, - program_factory=module_factory, - program_invoker=f)) + Test( + unique_name=f.__name__, + program_factory=module_factory, + program_invoker=f, + ) + ) return f return decorator diff --git a/projects/pt1/python/torch_mlir_e2e_test/reporting.py b/projects/pt1/python/torch_mlir_e2e_test/reporting.py index 61fc7caf9..a59fc7ee8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/reporting.py +++ b/projects/pt1/python/torch_mlir_e2e_test/reporting.py @@ -19,6 +19,7 @@ from .framework import TestResult, TraceItem class TensorSummary: """A summary of a tensor's contents.""" + def __init__(self, tensor): self.min = torch.min(tensor.type(torch.float64)) self.max = torch.max(tensor.type(torch.float64)) @@ -27,7 +28,7 @@ class TensorSummary: self.dtype = tensor.dtype def __str__(self): - return f'Tensor with shape={self.shape}, dtype={self.dtype}, min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}' + return f"Tensor with shape={self.shape}, dtype={self.dtype}, min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}" class ErrorContext: @@ -35,6 +36,7 @@ class ErrorContext: This is useful for tracking errors across multiple levels of detail. """ + def __init__(self, contexts: List[str]): self.contexts = contexts @@ -47,17 +49,16 @@ class ErrorContext: return ErrorContext([]) def chain(self, additional_context: str): - """Chain an additional context onto the current error context. - """ + """Chain an additional context onto the current error context.""" return ErrorContext(self.contexts + [additional_context]) def format_error(self, s: str): - return '@ ' + '\n@ '.join(self.contexts) + '\n' + 'ERROR: ' + s + return "@ " + "\n@ ".join(self.contexts) + "\n" + "ERROR: " + s class ValueReport: - """A report for a single value processed by the program. - """ + """A report for a single value processed by the program.""" + def __init__(self, value, golden_value, context: ErrorContext): self.value = value self.golden_value = golden_value @@ -70,7 +71,7 @@ class ValueReport: return len(self.failure_reasons) != 0 def error_str(self): - return '\n'.join(self.failure_reasons) + return "\n".join(self.failure_reasons) def _evaluate_outcome(self): value, golden = self.value, self.golden_value @@ -80,37 +81,37 @@ class ValueReport: golden = golden[0] if isinstance(golden, float): if not isinstance(value, float): - return self._record_mismatch_type_failure('float', value) + return self._record_mismatch_type_failure("float", value) if abs(value - golden) / golden > 1e-4: return self._record_failure( - f'value ({value!r}) is not close to golden value ({golden!r})' + f"value ({value!r}) is not close to golden value ({golden!r})" ) return if isinstance(golden, int): if not isinstance(value, int): - return self._record_mismatch_type_failure('int', value) + return self._record_mismatch_type_failure("int", value) if value != golden: return self._record_failure( - f'value ({value!r}) is not equal to golden value ({golden!r})' + f"value ({value!r}) is not equal to golden value ({golden!r})" ) return if isinstance(golden, str): if not isinstance(value, str): - return self._record_mismatch_type_failure('str', value) + return self._record_mismatch_type_failure("str", value) if value != golden: return self._record_failure( - f'value ({value!r}) is not equal to golden value ({golden!r})' + f"value ({value!r}) is not equal to golden value ({golden!r})" ) return if isinstance(golden, tuple): if not isinstance(value, tuple): - return self._record_mismatch_type_failure('tuple', value) + return self._record_mismatch_type_failure("tuple", value) if len(value) != len(golden): return self._record_failure( - f'value ({len(value)!r}) is not equal to golden value ({len(golden)!r})' + f"value ({len(value)!r}) is not equal to golden value ({len(golden)!r})" ) reports = [ - ValueReport(v, g, self.context.chain(f'tuple element {i}')) + ValueReport(v, g, self.context.chain(f"tuple element {i}")) for i, (v, g) in enumerate(zip(value, golden)) ] for report in reports: @@ -119,13 +120,13 @@ class ValueReport: return if isinstance(golden, list): if not isinstance(value, list): - return self._record_mismatch_type_failure('list', value) + return self._record_mismatch_type_failure("list", value) if len(value) != len(golden): return self._record_failure( - f'value ({len(value)!r}) is not equal to golden value ({len(golden)!r})' + f"value ({len(value)!r}) is not equal to golden value ({len(golden)!r})" ) reports = [ - ValueReport(v, g, self.context.chain(f'list element {i}')) + ValueReport(v, g, self.context.chain(f"list element {i}")) for i, (v, g) in enumerate(zip(value, golden)) ] for report in reports: @@ -134,16 +135,19 @@ class ValueReport: return if isinstance(golden, dict): if not isinstance(value, dict): - return self._record_mismatch_type_failure('dict', value) + return self._record_mismatch_type_failure("dict", value) gkeys = list(sorted(golden.keys())) vkeys = list(sorted(value.keys())) if gkeys != vkeys: return self._record_failure( - f'dict keys ({vkeys!r}) are not equal to golden keys ({gkeys!r})' + f"dict keys ({vkeys!r}) are not equal to golden keys ({gkeys!r})" ) reports = [ - ValueReport(value[k], golden[k], - self.context.chain(f'dict element at key {k!r}')) + ValueReport( + value[k], + golden[k], + self.context.chain(f"dict element at key {k!r}"), + ) for k in gkeys ] for report in reports: @@ -152,40 +156,42 @@ class ValueReport: return if isinstance(golden, torch.Tensor): if not isinstance(value, torch.Tensor): - return self._record_mismatch_type_failure('torch.Tensor', value) + return self._record_mismatch_type_failure("torch.Tensor", value) if value.shape != golden.shape: return self._record_failure( - f'shape ({value.shape}) is not equal to golden shape ({golden.shape})' + f"shape ({value.shape}) is not equal to golden shape ({golden.shape})" ) if value.dtype != golden.dtype: return self._record_failure( - f'dtype ({value.dtype}) is not equal to golden dtype ({golden.dtype})' + f"dtype ({value.dtype}) is not equal to golden dtype ({golden.dtype})" ) - if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True): + if not torch.allclose( + value, golden, rtol=1e-03, atol=1e-07, equal_nan=True + ): return self._record_failure( - f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})' + f"value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})" ) return return self._record_failure( - f'unexpected golden value of type `{golden.__class__.__name__}`') + f"unexpected golden value of type `{golden.__class__.__name__}`" + ) def _record_failure(self, s: str): self.failure_reasons.append(self.context.format_error(s)) def _record_mismatch_type_failure(self, expected: str, actual: Any): self._record_failure( - f'expected a value of type `{expected}` but got `{actual.__class__.__name__}`' + f"expected a value of type `{expected}` but got `{actual.__class__.__name__}`" ) - class TraceItemReport: """A report for a single trace item.""" + failure_reasons: List[str] - def __init__(self, item: TraceItem, golden_item: TraceItem, - context: ErrorContext): + def __init__(self, item: TraceItem, golden_item: TraceItem, context: ErrorContext): self.item = item self.golden_item = golden_item self.context = context @@ -197,36 +203,43 @@ class TraceItemReport: return len(self.failure_reasons) != 0 def error_str(self): - return '\n'.join(self.failure_reasons) + return "\n".join(self.failure_reasons) def _evaluate_outcome(self): if self.item.symbol != self.golden_item.symbol: self.failure_reasons.append( self.context.format_error( f'not invoking the same symbol: got "{self.item.symbol}", expected "{self.golden_item.symbol}"' - )) + ) + ) if len(self.item.inputs) != len(self.golden_item.inputs): self.failure_reasons.append( self.context.format_error( f'different number of inputs: got "{len(self.item.inputs)}", expected "{len(self.golden_item.inputs)}"' - )) + ) + ) for i, (input, golden_input) in enumerate( - zip(self.item.inputs, self.golden_item.inputs)): + zip(self.item.inputs, self.golden_item.inputs) + ): value_report = ValueReport( - input, golden_input, - self.context.chain( - f'input #{i} of call to "{self.item.symbol}"')) + input, + golden_input, + self.context.chain(f'input #{i} of call to "{self.item.symbol}"'), + ) if value_report.failed: self.failure_reasons.append(value_report.error_str()) value_report = ValueReport( - self.item.output, self.golden_item.output, - self.context.chain(f'output of call to "{self.item.symbol}"')) + self.item.output, + self.golden_item.output, + self.context.chain(f'output of call to "{self.item.symbol}"'), + ) if value_report.failed: self.failure_reasons.append(value_report.error_str()) class SingleTestReport: """A report for a single test.""" + item_reports: Optional[List[TraceItemReport]] def __init__(self, result: TestResult, context: ErrorContext): @@ -236,12 +249,15 @@ class SingleTestReport: if result.compilation_error is None and result.runtime_error is None: self.item_reports = [] for i, (item, golden_item) in enumerate( - zip(result.trace, result.golden_trace)): + zip(result.trace, result.golden_trace) + ): self.item_reports.append( TraceItemReport( - item, golden_item, - context.chain( - f'trace item #{i} - call to "{item.symbol}"'))) + item, + golden_item, + context.chain(f'trace item #{i} - call to "{item.symbol}"'), + ) + ) @property def failed(self): @@ -256,19 +272,21 @@ class SingleTestReport: f = io.StringIO() p = lambda *x: print(*x, file=f) if self.result.compilation_error is not None: - return 'Compilation error: ' + self.result.compilation_error + return "Compilation error: " + self.result.compilation_error elif self.result.runtime_error is not None: - return 'Runtime error: ' + self.result.runtime_error + return "Runtime error: " + self.result.runtime_error for report in self.item_reports: if report.failed: p(report.error_str()) return f.getvalue() -def report_results(results: List[TestResult], - expected_failures: Set[str], - verbose: bool = False, - config: str = ""): +def report_results( + results: List[TestResult], + expected_failures: Set[str], + verbose: bool = False, + config: str = "", +): """Print a basic error report summarizing various TestResult's. This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's @@ -293,49 +311,50 @@ def report_results(results: List[TestResult], if expected_failure: if report.failed: print(f'XFAIL - "{result.unique_name}"') - results_by_outcome['XFAIL'].append((result, report)) + results_by_outcome["XFAIL"].append((result, report)) else: print(f'XPASS - "{result.unique_name}"') - results_by_outcome['XPASS'].append((result, report)) + results_by_outcome["XPASS"].append((result, report)) else: if not report.failed: print(f'PASS - "{result.unique_name}"') - results_by_outcome['PASS'].append((result, report)) + results_by_outcome["PASS"].append((result, report)) else: print(f'FAIL - "{result.unique_name}"') - results_by_outcome['FAIL'].append((result, report)) + results_by_outcome["FAIL"].append((result, report)) OUTCOME_MEANINGS = collections.OrderedDict() - OUTCOME_MEANINGS['PASS'] = 'Passed' - OUTCOME_MEANINGS['FAIL'] = 'Failed' - OUTCOME_MEANINGS['XFAIL'] = 'Expectedly Failed' - OUTCOME_MEANINGS['XPASS'] = 'Unexpectedly Passed' + OUTCOME_MEANINGS["PASS"] = "Passed" + OUTCOME_MEANINGS["FAIL"] = "Failed" + OUTCOME_MEANINGS["XFAIL"] = "Expectedly Failed" + OUTCOME_MEANINGS["XPASS"] = "Unexpectedly Passed" - had_unexpected_results = len(results_by_outcome['FAIL']) != 0 or len( - results_by_outcome['XPASS']) != 0 + had_unexpected_results = ( + len(results_by_outcome["FAIL"]) != 0 or len(results_by_outcome["XPASS"]) != 0 + ) if had_unexpected_results: - print(f'\nUnexpected outcome summary: ({config})') + print(f"\nUnexpected outcome summary: ({config})") # For FAIL and XPASS (unexpected outcomes), print a summary. for outcome, results in results_by_outcome.items(): # PASS and XFAIL are "good"/"successful" outcomes. - if outcome == 'PASS' or outcome == 'XFAIL': + if outcome == "PASS" or outcome == "XFAIL": continue # If there is nothing to report, be quiet. if len(results) == 0: continue - print(f'\n****** {OUTCOME_MEANINGS[outcome]} tests - {len(results)} tests') + print(f"\n****** {OUTCOME_MEANINGS[outcome]} tests - {len(results)} tests") for result, report in results: print(f' {outcome} - "{result.unique_name}"') # If the test failed, print the error message. - if outcome == 'FAIL' and verbose: - print(textwrap.indent(report.error_str(), ' ' * 8)) + if outcome == "FAIL" and verbose: + print(textwrap.indent(report.error_str(), " " * 8)) # Print a summary for easy scanning. - print('\nSummary:') + print("\nSummary:") - for key in ['PASS', 'FAIL', 'XFAIL', 'XPASS']: + for key in ["PASS", "FAIL", "XFAIL", "XPASS"]: if results_by_outcome[key]: - print(f' {OUTCOME_MEANINGS[key]}: {len(results_by_outcome[key])}') + print(f" {OUTCOME_MEANINGS[key]}: {len(results_by_outcome[key])}") return had_unexpected_results diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py index d9627a352..61050de8f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -7,7 +7,9 @@ from torch_mlir.ir import * from torch_mlir.passmanager import * from torch_mlir.compiler_utils import run_pipeline_with_repro_report -from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( + RefBackendLinalgOnTensorsBackend, +) from .abc import StablehloBackend @@ -17,11 +19,13 @@ __all__ = [ # The pipeline of func.func passes that lower the STABLEHLO backend contract to the # Linalg-on-Tensors backend contract accepted by RefBackend. -STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([ - "func.func(stablehlo-aggressive-simplification)", - "stablehlo-legalize-to-linalg", - "canonicalize" -]) +STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join( + [ + "func.func(stablehlo-aggressive-simplification)", + "stablehlo-legalize-to-linalg", + "canonicalize", + ] +) class LinalgOnTensorsStablehloBackend(StablehloBackend): @@ -47,7 +51,8 @@ class LinalgOnTensorsStablehloBackend(StablehloBackend): run_pipeline_with_repro_report( imported_module, f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})", - "Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract") + "Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract", + ) return self.refbackend.compile(imported_module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index 9bdece3fc..dca86870f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -19,6 +19,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = { "ElementwiseToDtypeI64ToUI8Module_basic", } + def register_all_tests(): """Registers all the built-in E2E tests that Torch-MLIR provides.""" # Side-effecting import statements. diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py index 948901307..b9dfc491f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py @@ -17,13 +17,15 @@ class ArangeIntModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(5) + @register_test_case(module_factory=lambda: ArangeIntModule()) def ArangeIntModule_basic(module, tu: TestUtils): module.forward() @@ -34,13 +36,15 @@ class ArangeFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(5.0) + @register_test_case(module_factory=lambda: ArangeFloatModule()) def ArangeFloatModule_basic(module, tu: TestUtils): module.forward() @@ -51,31 +55,37 @@ class ArangeZeroElementOutputModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(0) + @register_test_case(module_factory=lambda: ArangeZeroElementOutputModule()) def ArangeZeroElementOutputModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class ArangeStartIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(0, 5) + @register_test_case(module_factory=lambda: ArangeStartIntModule()) def ArangeStartIntModule_basic(module, tu: TestUtils): module.forward() @@ -86,13 +96,15 @@ class ArangeStartFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(0.0, 5.0) + @register_test_case(module_factory=lambda: ArangeStartFloatModule()) def ArangeStartFloatModule_basic(module, tu: TestUtils): module.forward() @@ -103,13 +115,15 @@ class ArangeNegativeStartIntModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(-10, 5) + @register_test_case(module_factory=lambda: ArangeNegativeStartIntModule()) def ArangeNegativeStartIntModule_basic(module, tu: TestUtils): module.forward() @@ -120,31 +134,37 @@ class ArangeNegativeStartFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(-1.4, 5.7) + @register_test_case(module_factory=lambda: ArangeNegativeStartFloatModule()) def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class ArangeStartStepIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(0, 5, 1) + @register_test_case(module_factory=lambda: ArangeStartStepIntModule()) def ArangeStartStepIntModule_basic(module, tu: TestUtils): module.forward() @@ -155,13 +175,15 @@ class ArangeStartStepFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(-1, 5, 1.3) + @register_test_case(module_factory=lambda: ArangeStartStepFloatModule()) def ArangeStartStepFloatModule_basic(module, tu: TestUtils): module.forward() @@ -172,13 +194,15 @@ class ArangeStartNegativeStepIntModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(10, 1, -2) + @register_test_case(module_factory=lambda: ArangeStartNegativeStepIntModule()) def ArangeStartNegativeStepIntModule_basic(module, tu: TestUtils): module.forward() @@ -189,31 +213,37 @@ class ArangeStartNegativeStepFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(-1, -15, -3.4) + @register_test_case(module_factory=lambda: ArangeStartNegativeStepFloatModule()) def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class ArangeDtypeFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(-1, 15, dtype=torch.float32) + @register_test_case(module_factory=lambda: ArangeDtypeFloatModule()) def ArangeDtypeFloatModule_basic(module, tu: TestUtils): module.forward() @@ -224,110 +254,137 @@ class ArangeDtypeIntModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(0.2, 5.0, dtype=torch.int64) + @register_test_case(module_factory=lambda: ArangeDtypeIntModule()) def ArangeDtypeIntModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class ArangeFalsePinMemoryModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) - + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.arange(5.0, dtype=torch.int64, pin_memory=False) + @register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule()) def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class ArangeStartOutModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([12], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([12], torch.int64, True), + ] + ) def forward(self, x): return torch.arange(start=0, end=12, out=x) + @register_test_case(module_factory=lambda: ArangeStartOutModule()) def ArangeStartOutModule_basic(module, tu: TestUtils): module.forward(torch.zeros(12).to(torch.int64)) + class ArangeStartOutViewModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int64, True), + ] + ) def forward(self, x): return torch.arange(start=1, end=13, out=x) + @register_test_case(module_factory=lambda: ArangeStartOutViewModule()) def ArangeStartOutViewModule_basic(module, tu: TestUtils): module.forward(torch.zeros(3, 4).to(torch.int64)) + class ArangeStartOutDtypeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([12], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([12], torch.int64, True), + ] + ) def forward(self, x): return torch.arange(start=1.1, end=13.1, out=x) + @register_test_case(module_factory=lambda: ArangeStartOutDtypeModule()) def ArangeStartOutDtypeModule_basic(module, tu: TestUtils): module.forward(torch.zeros(12).to(torch.int64)) + # ============================================================================== - + + class LinspaceModule(torch.nn.Module): def __init__(self): super().__init__() - + @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.linspace(-10.1, 10.1, 10) + @register_test_case(module_factory=lambda: LinspaceModule()) def LinspaceModule_basic(module, tu: TestUtils): module.forward() + class LinspaceDtypeModule(torch.nn.Module): def __init__(self): super().__init__() - + @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.linspace(-10.1, 10.1, 10, dtype=torch.int64) @@ -336,47 +393,59 @@ class LinspaceDtypeModule(torch.nn.Module): def LinspaceDtypeModule_basic(module, tu: TestUtils): module.forward() + class LinspaceEmptyModule(torch.nn.Module): def __init__(self): super().__init__() - + @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.linspace(-10.1, 10.1, 0) + @register_test_case(module_factory=lambda: LinspaceEmptyModule()) def LinspaceEmptyModule_basic(module, tu: TestUtils): module.forward() + class LinspaceOneSizeModule(torch.nn.Module): def __init__(self): super().__init__() - + @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.linspace(-10.1, 10.1, 1) + @register_test_case(module_factory=lambda: LinspaceOneSizeModule()) def LinspaceOneSizeModule_basic(module, tu: TestUtils): module.forward() + class LinspaceTwoSizeModule(torch.nn.Module): def __init__(self): super().__init__() - + @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.linspace(-10.1, 10.1, 2) + @register_test_case(module_factory=lambda: LinspaceTwoSizeModule()) def LinspaceTwoSizeModule_basic(module, tu: TestUtils): module.forward() @@ -387,12 +456,16 @@ class PrimsIotaModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): - return torch.ops.prims.iota(77, start=0, step=1, dtype=torch.int64, device='cpu', - requires_grad=False) + return torch.ops.prims.iota( + 77, start=0, step=1, dtype=torch.int64, device="cpu", requires_grad=False + ) + @register_test_case(module_factory=lambda: PrimsIotaModule()) def PrimsIotaModule_basic(module, tu: TestUtils): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py index 7caa8a4c1..e209d15b2 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -13,21 +13,21 @@ from torch_mlir_e2e_test.annotations import annotate_args, export class SoftmaxBackwardModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, grad_output, output): - return torch.ops.aten._softmax_backward_data(grad_output, - output, - dim=1, - input_dtype=6) + return torch.ops.aten._softmax_backward_data( + grad_output, output, dim=1, input_dtype=6 + ) @register_test_case(module_factory=lambda: SoftmaxBackwardModule()) @@ -37,16 +37,17 @@ def SoftmaxBackwardModule_basic(module, tu: TestUtils): # ============================================================================== class TanhBackwardModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, grad_out, output): return torch.ops.aten.tanh_backward(grad_out, output) @@ -58,40 +59,46 @@ def TanhBackward_basic(module, tu: TestUtils): # ============================================================================== -class HardtanhBackwardModule(torch.nn.Module): +class HardtanhBackwardModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, grad_out, input): - return torch.ops.aten.hardtanh_backward(grad_out, input, min_val=0.2, max_val=0.5) + return torch.ops.aten.hardtanh_backward( + grad_out, input, min_val=0.2, max_val=0.5 + ) @register_test_case(module_factory=lambda: HardtanhBackwardModule()) def HardtanhBackward_basic(module, tu: TestUtils): module.forward(tu.rand(10, 20), tu.rand(10, 20)) + # ============================================================================== class ConvolutionBackwardModule2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, grad_out, input_vec, weight): return torch.ops.aten.convolution_backward( grad_out, @@ -104,27 +111,29 @@ class ConvolutionBackwardModule2D(torch.nn.Module): transposed=False, output_padding=[0], groups=1, - output_mask=[True, True, True]) + output_mask=[True, True, True], + ) @register_test_case(module_factory=lambda: ConvolutionBackwardModule2D()) def ConvolutionBackwardModule2D_basic(module, tu: TestUtils): with torch.backends.mkldnn.flags(enabled=False): - module.forward(tu.rand(2, 2, 5, 5), tu.rand(2, 2, 6, 6), - tu.rand(2, 2, 2, 2)) + module.forward(tu.rand(2, 2, 5, 5), tu.rand(2, 2, 6, 6), tu.rand(2, 2, 2, 2)) + class ConvolutionBackwardModule2DStatic(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 4, 64, 64], torch.float32, True), - ([1, 320, 64, 64], torch.float32, True), - ([4, 320, 3, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 4, 64, 64], torch.float32, True), + ([1, 320, 64, 64], torch.float32, True), + ([4, 320, 3, 3], torch.float32, True), + ] + ) def forward(self, grad_out, input_vec, weight): return torch.ops.aten.convolution_backward( grad_out, @@ -137,28 +146,31 @@ class ConvolutionBackwardModule2DStatic(torch.nn.Module): transposed=False, output_padding=[0, 0], groups=1, - output_mask=[True, True, True]) + output_mask=[True, True, True], + ) @register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStatic()) def ConvolutionBackwardModule2DStatic_basic(module, tu: TestUtils): with torch.backends.mkldnn.flags(enabled=False): - module.forward(tu.rand(1, 4, 64, 64), tu.rand(1, 320, 64, 64), - tu.rand(4, 320, 3, 3)) + module.forward( + tu.rand(1, 4, 64, 64), tu.rand(1, 320, 64, 64), tu.rand(4, 320, 3, 3) + ) class ConvolutionBackwardModule2DPadded(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, grad_out, input_vec, weight): return torch.ops.aten.convolution_backward( grad_out, @@ -171,28 +183,29 @@ class ConvolutionBackwardModule2DPadded(torch.nn.Module): transposed=False, output_padding=[0], groups=1, - output_mask=[True, True, True]) + output_mask=[True, True, True], + ) @register_test_case(module_factory=lambda: ConvolutionBackwardModule2DPadded()) def ConvolutionBackwardModule2DPadded_basic(module, tu: TestUtils): with torch.backends.mkldnn.flags(enabled=False): - module.forward(tu.rand(2, 2, 8, 8), tu.rand(2, 2, 6, 6), - tu.rand(2, 2, 3, 3)) + module.forward(tu.rand(2, 2, 8, 8), tu.rand(2, 2, 6, 6), tu.rand(2, 2, 3, 3)) class ConvolutionBackwardModule2DStrided(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 2, 4, 4], torch.float32, True), - ([1, 2, 8, 8], torch.float32, True), - ([2, 2, 3, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 2, 4, 4], torch.float32, True), + ([1, 2, 8, 8], torch.float32, True), + ([2, 2, 3, 3], torch.float32, True), + ] + ) def forward(self, grad_out, input_vec, weight): return torch.ops.aten.convolution_backward( grad_out, @@ -205,30 +218,31 @@ class ConvolutionBackwardModule2DStrided(torch.nn.Module): transposed=False, output_padding=[0, 0], groups=1, - output_mask=[True, True, True]) + output_mask=[True, True, True], + ) @register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStrided()) def ConvolutionBackwardModule2DStrided_basic(module, tu: TestUtils): with torch.backends.mkldnn.flags(enabled=False): - module.forward(tu.rand(1, 2, 4, 4), tu.rand(1, 2, 8, 8), - tu.rand(2, 2, 3, 3)) + module.forward(tu.rand(1, 2, 4, 4), tu.rand(1, 2, 8, 8), tu.rand(2, 2, 3, 3)) # ============================================================================== class GeluBackwardModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, grad, input): return torch.ops.aten.gelu_backward(grad, input) @@ -239,21 +253,21 @@ def GeluBackwardModule_basic(module, tu: TestUtils): class LogSoftmaxBackwardModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, grad_output, output): - return torch.ops.aten._log_softmax_backward_data(grad_output, - output, - dim=1, - input_dtype=6) + return torch.ops.aten._log_softmax_backward_data( + grad_output, output, dim=1, input_dtype=6 + ) @register_test_case(module_factory=lambda: LogSoftmaxBackwardModule()) @@ -265,18 +279,21 @@ def LogSoftmaxBackwardModule_basic(module, tu: TestUtils): class LeakyReluBackwardModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, grad, input): - return torch.ops.aten.leaky_relu_backward(grad, input, negative_slope=0.1, self_is_result=False) + return torch.ops.aten.leaky_relu_backward( + grad, input, negative_slope=0.1, self_is_result=False + ) @register_test_case(module_factory=lambda: LeakyReluBackwardModule()) @@ -285,18 +302,21 @@ def LeakyReluBackwardModule_basic(module, tu: TestUtils): class LeakyReluBackwardStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 5], torch.float32, True), - ([3, 4, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ] + ) def forward(self, grad, input): - return torch.ops.aten.leaky_relu_backward(grad, input, negative_slope=0.1, self_is_result=False) + return torch.ops.aten.leaky_relu_backward( + grad, input, negative_slope=0.1, self_is_result=False + ) @register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule()) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index fa99522a8..b483f9d3c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -12,36 +12,42 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== -class ScalarConstantTupleModule(torch.nn.Module): +class ScalarConstantTupleModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return (1, 2) + @register_test_case(module_factory=lambda: ScalarConstantTupleModule()) def ScalarConstantTupleModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 4)) + # ============================================================================== -class MmModule(torch.nn.Module): +class MmModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.mm(lhs, rhs) @@ -61,16 +67,17 @@ def MmModule_chained(module, tu: TestUtils): class BmmFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.bmm(lhs, rhs) @@ -81,16 +88,17 @@ def BmmFloatModule_basic(module, tu: TestUtils): class BmmIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return torch.bmm(lhs, rhs) @@ -104,15 +112,16 @@ def BmmIntModule_basic(module, tu: TestUtils): class IsFloatingPointInt(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.is_floating_point(x) @@ -126,15 +135,16 @@ def IsFloatingPointInt_False(module, tu: TestUtils): class IsFloatingPointFloat(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, x): return torch.is_floating_point(x) @@ -150,12 +160,13 @@ def IsFloatingPointFloat_True(module, tu: TestUtils): class ContainsIntList(torch.nn.Module): def __init__(self): super().__init__() + @export - @annotate_args([ - None - ]) + @annotate_args([None]) def forward(self): - return torch.ops.aten.__contains__([1,2,3], 3) + return torch.ops.aten.__contains__([1, 2, 3], 3) + + @register_test_case(module_factory=lambda: ContainsIntList()) def ContainsIntList_True(module, tu: TestUtils): module.forward() @@ -167,12 +178,13 @@ def ContainsIntList_True(module, tu: TestUtils): class ContainsIntListFalse(torch.nn.Module): def __init__(self): super().__init__() + @export - @annotate_args([ - None - ]) + @annotate_args([None]) def forward(self): - return torch.ops.aten.__contains__([1,2,3], 4) + return torch.ops.aten.__contains__([1, 2, 3], 4) + + @register_test_case(module_factory=lambda: ContainsIntListFalse()) def ContainsIntList_False(module, tu: TestUtils): module.forward() @@ -183,16 +195,17 @@ def ContainsIntList_False(module, tu: TestUtils): # A subgraph with multiple mm ops. class MmDagModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 4], torch.float32, True), - ([4, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4, 4], torch.float32, True), + ([4, 4], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.mm(lhs, torch.mm(lhs, rhs)) @@ -206,16 +219,17 @@ def MmDagModule_basic(module, tu: TestUtils): class MmTanhModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.tanh(self.matmul(lhs, rhs)) @@ -232,17 +246,18 @@ def MmTanhModule_basic(module, tu: TestUtils): class AddmmModuleFloat(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, M, mat1, mat2): return torch.addmm(M, mat1, mat2, beta=3.0, alpha=7.0) @@ -256,17 +271,18 @@ def AddmmModuleFloat_basic(module, tu: TestUtils): class AddmmModuleBroadcastable(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, M, mat1, mat2): return torch.addmm(M, mat1, mat2, beta=2.0, alpha=7.0) @@ -280,23 +296,23 @@ def AddmmModule_broadcastable(module, tu: TestUtils): class AddmmModuleDifferentRankBroadcastable(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, M, mat1, mat2): return torch.addmm(M, mat1, mat2, beta=11.0, alpha=7.0) -@register_test_case( - module_factory=lambda: AddmmModuleDifferentRankBroadcastable()) +@register_test_case(module_factory=lambda: AddmmModuleDifferentRankBroadcastable()) def AddmmModule_differentRankBroadcastable(module, tu: TestUtils): module.forward(tu.rand(3), tu.rand(3, 2), tu.rand(2, 3)) @@ -305,15 +321,16 @@ def AddmmModule_differentRankBroadcastable(module, tu: TestUtils): class UnflattenStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 6, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 6, 4], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.unflatten(x, 1, (2, 3)) @@ -327,16 +344,17 @@ def UnflattenStaticModule_basic(module, tu: TestUtils): class FlattenStaticModule(torch.nn.Module): - def __init__(self): super().__init__() self.flat = torch.nn.Flatten(2, 4) @export - @annotate_args([ - None, - ([10, 3, 8, 9, 3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([10, 3, 8, 9, 3, 4], torch.float32, True), + ] + ) def forward(self, x): return self.flat(x) @@ -350,16 +368,17 @@ def FlattenStaticModule_basic(module, tu: TestUtils): class FlattenRank0Module(torch.nn.Module): - def __init__(self): super().__init__() self.flat = torch.nn.Flatten(-1, -1) @export - @annotate_args([ - None, - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([], torch.float32, True), + ] + ) def forward(self, x): return self.flat(x) @@ -373,16 +392,17 @@ def FlattenRank0Module_basic(module, tu: TestUtils): class FlattenDynamicModule(torch.nn.Module): - def __init__(self): super().__init__() self.flat = torch.nn.Flatten(2, 4) @export - @annotate_args([ - None, - ([-1, -1, -1, 9, 3, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, 9, 3, -1], torch.float32, True), + ] + ) def forward(self, x): return self.flat(x) @@ -391,17 +411,19 @@ class FlattenDynamicModule(torch.nn.Module): def FlattenDynamicModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 3, 8, 9, 3, 4)) -class FlattenDynamicModuleCollapseAll(torch.nn.Module): +class FlattenDynamicModuleCollapseAll(torch.nn.Module): def __init__(self): super().__init__() self.flat = torch.nn.Flatten(0) @export - @annotate_args([ - None, - ([-1, -1, -1, 9, 3, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, 9, 3, -1], torch.float32, True), + ] + ) def forward(self, x): return self.flat(x) @@ -415,15 +437,16 @@ def FlattenDynamicModuleCollapseAll_basic(module, tu: TestUtils): class AliasModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inp_tensor): return torch.ops.aten.alias(inp_tensor) @@ -437,16 +460,17 @@ def AliasModule_basic(module, tu: TestUtils): class ConstantPad2dStaticModule(torch.nn.Module): - def __init__(self): super().__init__() - self.pad2d = torch.nn.ConstantPad2d((0, 1, 2, 3), -float('inf')) + self.pad2d = torch.nn.ConstantPad2d((0, 1, 2, 3), -float("inf")) @export - @annotate_args([ - None, - ([1, 1, 20, 20], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 20, 20], torch.float32, True), + ] + ) def forward(self, x): return self.pad2d(x) @@ -460,15 +484,16 @@ def ConstantPad2dStaticModule_basic(module, tu: TestUtils): class PadModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): pad = [0, 1, 2, 3] mode = "constant" @@ -484,15 +509,16 @@ def PadModule_basic(module, tu: TestUtils): class PadWithNoneValModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): pad = [0, 1, 2, 3] mode = "constant" @@ -508,17 +534,18 @@ def PadWithNoneValModule_basic(module, tu: TestUtils): class ConstantPadNdModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf')) + return torch.ops.aten.constant_pad_nd(x, (0, 1), -float("inf")) @register_test_case(module_factory=lambda: ConstantPadNdModule()) @@ -530,17 +557,18 @@ def ConstantPadNdModule_basic(module, tu: TestUtils): class ConstantPadNdStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 20, 20, 4, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 20, 20, 4, 4], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf')) + return torch.ops.aten.constant_pad_nd(x, (0, 1), -float("inf")) @register_test_case(module_factory=lambda: ConstantPadNdStaticModule()) @@ -552,17 +580,18 @@ def ConstantPadNdStaticModule_basic(module, tu: TestUtils): class ConstantPadNdPartialStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 20, 20, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 20, 20, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.constant_pad_nd(x, (0, 1, 2, 3), -float('inf')) + return torch.ops.aten.constant_pad_nd(x, (0, 1, 2, 3), -float("inf")) @register_test_case(module_factory=lambda: ConstantPadNdPartialStaticModule()) @@ -572,28 +601,31 @@ def ConstantPadNdPartialStaticModule_basic(module, tu: TestUtils): # ============================================================================== class ReflectionPad1dModule3dInput(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 2, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 2, 4], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.reflection_pad1d(x, (3,1)) + return torch.ops.aten.reflection_pad1d(x, (3, 1)) + class ReplicationPad2dModule_basic_module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 3, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 3, 3], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.replication_pad2d(x, (1, 2, 3, 4)) @@ -602,18 +634,21 @@ class ReplicationPad2dModule_basic_module(torch.nn.Module): def ReplicationPad2dModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 3, 3, low=-1)) + # ============================================================================== -class ReplicationPad2dModule_left0_module(torch.nn.Module): +class ReplicationPad2dModule_left0_module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 3, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 3, 3], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.replication_pad2d(x, (0, 2, 3, 4)) @@ -622,18 +657,21 @@ class ReplicationPad2dModule_left0_module(torch.nn.Module): def ReplicationPad2dModule_left0(module, tu: TestUtils): module.forward(tu.rand(1, 1, 3, 3, low=-1)) + # ============================================================================== -class ReplicationPad2dModule_right0_module(torch.nn.Module): +class ReplicationPad2dModule_right0_module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 3, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 3, 3], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.replication_pad2d(x, (1, 0, 3, 4)) @@ -642,18 +680,21 @@ class ReplicationPad2dModule_right0_module(torch.nn.Module): def ReplicationPad2dModule_right0(module, tu: TestUtils): module.forward(tu.rand(1, 1, 3, 3, low=-1)) + # ============================================================================== -class ReplicationPad2dModule_top0_module(torch.nn.Module): +class ReplicationPad2dModule_top0_module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 3, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 3, 3], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.replication_pad2d(x, (1, 2, 0, 4)) @@ -662,18 +703,21 @@ class ReplicationPad2dModule_top0_module(torch.nn.Module): def ReplicationPad2dModule_top0(module, tu: TestUtils): module.forward(tu.rand(1, 1, 3, 3, low=-1)) + # ============================================================================== -class ReplicationPad2dModule_bottom0_module(torch.nn.Module): +class ReplicationPad2dModule_bottom0_module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 3, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 3, 3], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.replication_pad2d(x, (1, 2, 3, 0)) @@ -682,78 +726,87 @@ class ReplicationPad2dModule_bottom0_module(torch.nn.Module): def ReplicationPad2dModule_bottom0(module, tu: TestUtils): module.forward(tu.rand(1, 1, 3, 3, low=-1)) + # ============================================================================== + @register_test_case(module_factory=lambda: ReflectionPad1dModule3dInput()) def ReflectionPad1dModule3dInput_basic(module, tu: TestUtils): - module.forward(tu.rand(1,2,4)) + module.forward(tu.rand(1, 2, 4)) class ReflectionPad1dModule2dInput(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 4], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.reflection_pad1d(x, (3,2)) + return torch.ops.aten.reflection_pad1d(x, (3, 2)) @register_test_case(module_factory=lambda: ReflectionPad1dModule2dInput()) def ReflectionPad1dModule2dInput_basic(module, tu: TestUtils): - module.forward(tu.rand(2,4)) + module.forward(tu.rand(2, 4)) + class ReflectionPad1dModule3dInputLeft(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 4, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 4, 5], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.reflection_pad1d(x, (2,0)) + return torch.ops.aten.reflection_pad1d(x, (2, 0)) @register_test_case(module_factory=lambda: ReflectionPad1dModule3dInputLeft()) def ReflectionPad1dModule3dInput_Left(module, tu: TestUtils): - module.forward(tu.rand(1,4,5)) + module.forward(tu.rand(1, 4, 5)) + class ReflectionPad1dModule2dInputRight(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 6], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.reflection_pad1d(x, (0,3)) + return torch.ops.aten.reflection_pad1d(x, (0, 3)) @register_test_case(module_factory=lambda: ReflectionPad1dModule2dInputRight()) def ReflectionPad1dModule2dInput_Right(module, tu: TestUtils): - module.forward(tu.rand(3,6)) + module.forward(tu.rand(3, 6)) + # ============================================================================== class TransposeIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 2], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 2], torch.float32, True), + ] + ) def forward(self, x): return torch.transpose(x, 0, 1) @@ -767,7 +820,6 @@ def TransposeIntModule_basic(module, tu: TestUtils): class PermuteModule(torch.nn.Module): - def __init__(self): super().__init__() @@ -782,12 +834,10 @@ def PermuteModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 2)) - # ============================================================================== class PermuteNegativeIndexModule(torch.nn.Module): - def __init__(self): super().__init__() @@ -806,7 +856,6 @@ def PermuteNegativeIndexModule_basic(module, tu: TestUtils): class Permute0RankModule(torch.nn.Module): - def __init__(self): super().__init__() @@ -825,15 +874,16 @@ def Permute0RankModule_basic(module, tu: TestUtils): class TransposeIntNegDimsModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 2], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 2], torch.float32, True), + ] + ) def forward(self, x): return torch.transpose(x, -1, -2) @@ -855,9 +905,10 @@ class PixelShuffleModuleStaticRank4Float32(torch.nn.Module): def forward(self, x): return torch.ops.aten.pixel_shuffle(x, 3) + @register_test_case(module_factory=lambda: PixelShuffleModuleStaticRank4Float32()) def PixelShuffleModuleStaticRank4Float32_basic(module, tu: TestUtils): - module.forward(tu.rand(3,18,2,2)) + module.forward(tu.rand(3, 18, 2, 2)) # ============================================================================== @@ -872,9 +923,11 @@ class PixelShuffleModuleStaticRank3Int64(torch.nn.Module): def forward(self, x): return torch.ops.aten.pixel_shuffle(x, 2) + @register_test_case(module_factory=lambda: PixelShuffleModuleStaticRank3Int64()) def PixelShuffleModuleStaticRank3Int64_basic(module, tu: TestUtils): - module.forward(tu.randint(12, 2, 3, low = 0, high = 100)) + module.forward(tu.randint(12, 2, 3, low=0, high=100)) + # ============================================================================== @@ -884,13 +937,15 @@ class PixelShuffleModuleFullDynamic(torch.nn.Module): super().__init__() @export - @annotate_args([None, ([-1,-1,-1,-1], torch.int64, True)]) + @annotate_args([None, ([-1, -1, -1, -1], torch.int64, True)]) def forward(self, x): return torch.ops.aten.pixel_shuffle(x, 2) + @register_test_case(module_factory=lambda: PixelShuffleModuleFullDynamic()) def PixelShuffleModuleFullDynamic_basic(module, tu: TestUtils): - module.forward(tu.randint(1,8,3,3, low = 0, high = 100)) + module.forward(tu.randint(1, 8, 3, 3, low=0, high=100)) + # ============================================================================== @@ -900,46 +955,50 @@ class PixelShuffleModuleSpatiallyDynamic(torch.nn.Module): super().__init__() @export - @annotate_args([None, ([2,1,8,-1,-1], torch.int64, True)]) + @annotate_args([None, ([2, 1, 8, -1, -1], torch.int64, True)]) def forward(self, x): return torch.ops.aten.pixel_shuffle(x, 2) + @register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyDynamic()) def PixelShuffleModuleSpatiallyDynamic_basic(module, tu: TestUtils): - module.forward(tu.randint(2,1,8,2,3, low = 0, high = 100)) + module.forward(tu.randint(2, 1, 8, 2, 3, low=0, high=100)) # ============================================================================== + class PixelShuffleModuleSpatiallyStatic(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([None, ([-1,-1,-1,3,1], torch.int64, True)]) + @annotate_args([None, ([-1, -1, -1, 3, 1], torch.int64, True)]) def forward(self, x): return torch.ops.aten.pixel_shuffle(x, 2) + @register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyStatic()) def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils): - module.forward(tu.randint(1,2,12,3,1, low = 0, high = 100)) + module.forward(tu.randint(1, 2, 12, 3, 1, low=0, high=100)) # ============================================================================== class TensorsConcatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y, z): return torch.cat([x, y, z], 1) @@ -953,17 +1012,18 @@ def TensorsConcatModule_basic(module, tu: TestUtils): class TensorsConcatNegativeDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y, z): return torch.cat([x, y, z], dim=-2) @@ -977,43 +1037,47 @@ def TensorsConcatNegativeDimModule_basic(module, tu: TestUtils): class TensorsConcatPromoteDTypeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.bool, True), - ([-1, -1, -1], torch.int32, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.bool, True), + ([-1, -1, -1], torch.int32, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, x, y, z): return torch.cat([x, y, z], dim=-2) @register_test_case(module_factory=lambda: TensorsConcatPromoteDTypeModule()) def TensorsConcatPromoteDTypeModule_basic(module, tu: TestUtils): - module.forward(tu.randint(2, 2, 4, low=0, high=2).bool(), - tu.randint(2, 1, 4, low=0, high=100).int(), - tu.randint(2, 3, 4, low=0, high=100).long()) + module.forward( + tu.randint(2, 2, 4, low=0, high=2).bool(), + tu.randint(2, 1, 4, low=0, high=100).int(), + tu.randint(2, 3, 4, low=0, high=100).long(), + ) # ============================================================================== class TensorsConcatStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 2, 4], torch.float32, True), - ([2, 1, 4], torch.float32, True), - ([2, 3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 2, 4], torch.float32, True), + ([2, 1, 4], torch.float32, True), + ([2, 3, 4], torch.float32, True), + ] + ) def forward(self, x, y, z): return torch.cat([x, y, z], dim=1) @@ -1027,17 +1091,18 @@ def TensorsConcatStaticModule_basic(module, tu: TestUtils): class TensorsConcatNegativeDimStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 2, 4], torch.float32, True), - ([2, 1, 4], torch.float32, True), - ([2, 3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 2, 4], torch.float32, True), + ([2, 1, 4], torch.float32, True), + ([2, 3, 4], torch.float32, True), + ] + ) def forward(self, x, y, z): return torch.cat([x, y, z], dim=-2) @@ -1051,17 +1116,18 @@ def TensorsConcatNegativeDimStaticModule_basic(module, tu: TestUtils): class TensorsStackModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y, z): return torch.stack([x, y, z], dim=1) @@ -1075,15 +1141,16 @@ def TensorsStackModule_basic(module, tu: TestUtils): class TensorsStackSingleElementListModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.stack([x], dim=1) @@ -1097,17 +1164,18 @@ def TensorsStackSingleElementListModule_basic(module, tu: TestUtils): class TensorsStackNegativeDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y, z): return torch.stack([x, y, z], dim=-2) @@ -1121,42 +1189,46 @@ def TensorsStackNegativeDimModule_basic(module, tu: TestUtils): class TensorsStackPromoteDTypeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.bool, True), - ([-1, -1, -1], torch.int32, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.bool, True), + ([-1, -1, -1], torch.int32, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, x, y, z): return torch.stack([x, y, z], dim=-2) @register_test_case(module_factory=lambda: TensorsStackPromoteDTypeModule()) def TensorsStackPromoteDTypeModule_basic(module, tu: TestUtils): - module.forward(tu.randint(2, 3, 4, low=0, high=2).bool(), - tu.randint(2, 3, 4, low=0, high=100).int(), - tu.randint(2, 3, 4, low=0, high=100).long()) + module.forward( + tu.randint(2, 3, 4, low=0, high=2).bool(), + tu.randint(2, 3, 4, low=0, high=100).int(), + tu.randint(2, 3, 4, low=0, high=100).long(), + ) # ============================================================================== class GatherModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, tensor, indices): return torch.gather(tensor, 2, indices) @@ -1170,16 +1242,17 @@ def GatherModule_basic(module, tu: TestUtils): class GatherNegativeDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, tensor, indices): return torch.gather(tensor, -1, indices) @@ -1197,14 +1270,17 @@ class GatherRandomIndexModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, tensor, indices): return torch.gather(tensor, 1, indices) + @register_test_case(module_factory=lambda: GatherRandomIndexModule()) def GatherRandomIndexModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4), tu.randint(2, 3, 4, high=3)) @@ -1218,14 +1294,17 @@ class Gather2DInputModdule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, tensor, indices): return torch.gather(tensor, 1, indices) + @register_test_case(module_factory=lambda: Gather2DInputModdule()) def Gather2DInputModdule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5), torch.tensor([[1, 2, 3], [4, 3, 2]])) @@ -1235,16 +1314,17 @@ def Gather2DInputModdule_basic(module, tu: TestUtils): class GatherStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 4], torch.float32, True), - ([1, 2, 3], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([2, 3, 4], torch.float32, True), + ([1, 2, 3], torch.int64, True), + ] + ) def forward(self, tensor, indices): return torch.gather(tensor, 2, indices) @@ -1258,15 +1338,16 @@ def GatherStaticModule_basic(module, tu: TestUtils): class AddSizeIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, tensor): # This is a workaround for not supporting scalar arguments. # TODO: pass in dim as an argument to the forward method when scalar @@ -1283,15 +1364,16 @@ def AddSizeIntModule_basic(module, tu: TestUtils): class AddSizeIntNegDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, tensor): # This is a workaround for not supporting scalar arguments. # TODO: pass in dim as an argument to the forward method when scalar @@ -1308,16 +1390,17 @@ def AddSizeIntNegDimModule_basic(module, tu: TestUtils): class Add_MixPModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float64, True), + ] + ) def forward(self, a, b): a += b return a @@ -1332,19 +1415,20 @@ def Add_MixPModule_basic(module, tu: TestUtils): class EmbeddingModuleI64(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) - self.embed = torch.nn.Embedding(num_embeddings=100, - embedding_dim=50, - padding_idx=4) + self.embed = torch.nn.Embedding( + num_embeddings=100, embedding_dim=50, padding_idx=4 + ) @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, indices): return self.embed.forward(indices) @@ -1358,19 +1442,20 @@ def EmbeddingModuleI64_basic(module, tu: TestUtils): class EmbeddingModuleI32(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) - self.embed = torch.nn.Embedding(num_embeddings=100, - embedding_dim=50, - padding_idx=4) + self.embed = torch.nn.Embedding( + num_embeddings=100, embedding_dim=50, padding_idx=4 + ) @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, indices): return self.embed.forward(indices) @@ -1379,23 +1464,25 @@ class EmbeddingModuleI32(torch.nn.Module): def EmbeddingModuleI32_basic(module, tu: TestUtils): module.forward(tu.randint(3, 3, high=100).to(torch.int32)) + # ============================================================================== class EmbeddingModuleF16(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) - self.embed = torch.nn.Embedding(num_embeddings=100, - embedding_dim=50, - padding_idx=4).to(torch.half) + self.embed = torch.nn.Embedding( + num_embeddings=100, embedding_dim=50, padding_idx=4 + ).to(torch.half) @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, indices): return self.embed.forward(indices) @@ -1407,20 +1494,22 @@ def EmbeddingModuleF16_basic(module, tu: TestUtils): # ============================================================================== -class EmbeddingModuleI32Static(torch.nn.Module): +class EmbeddingModuleI32Static(torch.nn.Module): def __init__(self): super().__init__() torch.manual_seed(0) - self.embed = torch.nn.Embedding(num_embeddings=100, - embedding_dim=50, - padding_idx=4) + self.embed = torch.nn.Embedding( + num_embeddings=100, embedding_dim=50, padding_idx=4 + ) @export - @annotate_args([ - None, - ([3, 3], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([3, 3], torch.int32, True), + ] + ) def forward(self, indices): return self.embed.forward(indices) @@ -1434,19 +1523,20 @@ def EmbeddingModuleI32Static_basic(module, tu: TestUtils): class EmbeddingModule1DIndices(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) - self.embed = torch.nn.Embedding(num_embeddings=100, - embedding_dim=50, - padding_idx=4) + self.embed = torch.nn.Embedding( + num_embeddings=100, embedding_dim=50, padding_idx=4 + ) @export - @annotate_args([ - None, - ([-1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int32, True), + ] + ) def forward(self, indices): return self.embed.forward(indices) @@ -1460,16 +1550,17 @@ def EmbeddingModule1DIndices_basic(module, tu: TestUtils): class SoftmaxIntModule(torch.nn.Module): - def __init__(self): super().__init__() self.softmax = torch.nn.Softmax(2) @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, tensor): return self.softmax.forward(tensor) @@ -1480,15 +1571,16 @@ def SoftmaxIntModule_basic(module, tu: TestUtils): class SoftmaxIntNonNoneDtypeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, tensor): return torch.ops.aten.softmax(tensor, dim=2, dtype=torch.float64) @@ -1502,15 +1594,16 @@ def SoftmaxIntNonNoneDtypeModule_basic(module, tu: TestUtils): class _SoftmaxModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, tensor): return torch.ops.aten._softmax(tensor, 0, False) @@ -1524,17 +1617,18 @@ def _SoftmaxModule_basic(module, tu: TestUtils): class SoftmaxIntNegDimModule(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) self.softmax = torch.nn.Softmax(-2) @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, tensor): return self.softmax.forward(tensor) @@ -1548,17 +1642,18 @@ def SoftmaxIntNegDimModule_basic(module, tu: TestUtils): class SoftmaxIntArgTypeF64Module(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) self.softmax = torch.nn.Softmax(2) @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, tensor): return self.softmax.forward(tensor) @@ -1572,15 +1667,16 @@ def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils): class _LogSoftmaxModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, tensor): return torch.ops.aten._log_softmax(tensor, dim=0, half_to_float=False) @@ -1594,15 +1690,16 @@ def _LogSoftmaxModule_basic(module, tu: TestUtils): class _LogSoftmaxModuleStable(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, tensor): return torch.ops.aten._log_softmax(tensor, dim=0, half_to_float=False) @@ -1619,15 +1716,16 @@ def _LogSoftmaxModuleStable_basic(module, tu: TestUtils): class SoftplusModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.softplus(x) @@ -1641,15 +1739,16 @@ def SoftplusModule_basic(module, tu: TestUtils): class HardsigmoidModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.hardsigmoid(x) @@ -1663,15 +1762,16 @@ def HardsigmoidModule_basic(module, tu: TestUtils): class HardsigmoidRandomModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.hardsigmoid(x) @@ -1685,15 +1785,16 @@ def HardsigmoidRandomModule_basic(module, tu: TestUtils): class BroadcastToModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, 1], torch.float32, True), + ] + ) def forward(self, x): return torch.broadcast_to(x, [1, -1, -1, 4]) @@ -1707,16 +1808,17 @@ def BroadcastToModule_basic(module, tu: TestUtils): class BroadcastToSameRankStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 1, 8], torch.float32, True), - ([3, 1, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 1, 8], torch.float32, True), + ([3, 1, 1], torch.float32, True), + ] + ) def forward(self, x, y): y = torch.broadcast_to(y, [3, 1, 8]) return torch.ops.aten.sub(x, y) @@ -1731,16 +1833,17 @@ def BroadcastToSameRankStaticModule_basic(module, tu: TestUtils): class BroadcastZeroRankInputStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 1, 8], torch.float32, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 1, 8], torch.float32, True), + ([], torch.float32, True), + ] + ) def forward(self, x, y): y = torch.broadcast_to(y, [3, 1, 8]) return torch.ops.aten.sub(x, y) @@ -1750,19 +1853,22 @@ class BroadcastZeroRankInputStaticModule(torch.nn.Module): def BroadcastZeroRankInputStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 8), tu.rand()) + # ============================================================================== -class BroadcastListConstructWithMinusOneModule(torch.nn.Module): +class BroadcastListConstructWithMinusOneModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 1, 8], torch.float32, True), - ([3, 1, 8], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 1, 8], torch.float32, True), + ([3, 1, 8], torch.float32, True), + ] + ) def forward(self, x, y): y = torch.broadcast_to(y, [-1, -1, -1]) return torch.ops.aten.sub(x, y) @@ -1772,19 +1878,22 @@ class BroadcastListConstructWithMinusOneModule(torch.nn.Module): def BroadcastListConstructWithMinusOneModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 8), tu.rand(3, 1, 8)) + # ============================================================================== -class BroadcastDynamicDimModule(torch.nn.Module): +class BroadcastDynamicDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, -1, 1, -1], torch.float32, True), - ([1, -1, 1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, -1, 1, -1], torch.float32, True), + ([1, -1, 1, -1], torch.float32, True), + ] + ) def forward(self, x, y): dim_at_index_1 = torch.ops.aten.size(x, 1) dim_at_index_3 = torch.ops.aten.size(x, 3) @@ -1801,15 +1910,16 @@ def BroadcastDynamicDimModule_basic(module, tu: TestUtils): class RollModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, -1, 2], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, -1, 2], torch.float32, True), + ] + ) def forward(self, x): return x.roll([2, -1], [0, 2]) @@ -1818,19 +1928,21 @@ class RollModule(torch.nn.Module): def RollModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 2)) + # ============================================================================== class RepeatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 1, 2], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 1, 2], torch.float32, True), + ] + ) def forward(self, x): return x.repeat([2, 1, 3, 4]) @@ -1839,19 +1951,21 @@ class RepeatModule(torch.nn.Module): def RepeatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 2)) + # ============================================================================== class RepeatInterleaveSelfIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 5], torch.float32, True), + ] + ) def forward(self, x): return x.repeat_interleave(2, 1) @@ -1860,19 +1974,21 @@ class RepeatInterleaveSelfIntModule(torch.nn.Module): def RepeatInterleaveSelfIntModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 5], torch.float32, True), + ] + ) def forward(self, x): return x.repeat_interleave(2) @@ -1881,18 +1997,21 @@ class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module): def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== -class TileSmallDimsSizeModule(torch.nn.Module): +class TileSmallDimsSizeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 1, 2], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 1, 2], torch.float32, True), + ] + ) def forward(self, x): return x.tile([3, 4]) @@ -1901,18 +2020,21 @@ class TileSmallDimsSizeModule(torch.nn.Module): def TileSmallDimsSizeModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 2)) + # ============================================================================== -class TileBigDimsSizeModule(torch.nn.Module): +class TileBigDimsSizeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 1, 2], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 1, 2], torch.float32, True), + ] + ) def forward(self, x): return x.tile([3, 4, 5, 6]) @@ -1921,19 +2043,21 @@ class TileBigDimsSizeModule(torch.nn.Module): def TileBigDimsSizeModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 2)) + # ============================================================================== class ExpandModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, 1], torch.float32, True), + ] + ) def forward(self, x): return x.expand([1, -1, -1, 4]) @@ -1947,15 +2071,16 @@ def ExpandModule_basic(module, tu: TestUtils): class ContiguousModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return x.contiguous() @@ -1969,16 +2094,17 @@ def ContiguousModule_basic(module, tu: TestUtils): class LogSoftmaxIntModule(torch.nn.Module): - def __init__(self): super().__init__() self.log_softmax = torch.nn.LogSoftmax(2) @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, tensor): return self.log_softmax.forward(tensor) @@ -1990,15 +2116,17 @@ def LogSoftmaxIntModule_basic(module, tu: TestUtils): # ============================================================================== -class PrimMinIntModule(torch.nn.Module): +class PrimMinIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.prim.min(1, -1) @@ -2010,16 +2138,18 @@ def PrimMinIntModule_basic(module, tu: TestUtils): # ============================================================================== -class PrimMinIntDynamicModule(torch.nn.Module): +class PrimMinIntDynamicModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.prim.min(a.size(0), a.size(1)) @@ -2031,16 +2161,18 @@ def PrimMinIntDynamicModule_basic(module, tu: TestUtils): # ============================================================================== -class PrimMaxIntModule(torch.nn.Module): +class PrimMaxIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.prim.max(a.size(0), a.size(1)) @@ -2052,15 +2184,17 @@ def PrimMaxIntModule_basic(module, tu: TestUtils): # ============================================================================== -class NumToTensorIntModule(torch.nn.Module): +class NumToTensorIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.prim.NumToTensor(1) @@ -2074,14 +2208,15 @@ def NumToTensorIntModule_basic(module, tu: TestUtils): class NumToTensorFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.prim.NumToTensor(1.0) @@ -2096,17 +2231,18 @@ def NumToTensorFloatModule_basic(module, tu: TestUtils): # This test can be removed once we have one real op returning 3 float32 tensors class ReturnThreeTensorFloat32(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a, b, c): return a, b, c @@ -2120,17 +2256,18 @@ def ReturnThreeTensorFloat32_basic(module, tu: TestUtils): class AddCMulModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, input, tensor1, tensor2): return torch.addcmul(input, tensor1, tensor2, value=1.0) @@ -2144,17 +2281,18 @@ def AddCMulModule_basic(module, tu: TestUtils): class AddCDivModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, input, tensor1, tensor2): return torch.addcdiv(input, tensor1, tensor2, value=1.0) @@ -2168,14 +2306,15 @@ def AddCDivModule_basic(module, tu: TestUtils): class tensorIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): a = 1 return torch.tensor(a) @@ -2190,14 +2329,15 @@ def TensorIntModule_basic(module, tu: TestUtils): class tensorFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): a = 1.0 return torch.tensor(a) @@ -2212,15 +2352,16 @@ def TensorFloatModule_basic(module, tu: TestUtils): class DropoutEvalIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.dropout(x, 0.2, train=False) @@ -2234,15 +2375,16 @@ def DropoutEvalIntModule_basic(module, tu: TestUtils): class DropoutEvalFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.dropout(x, 0.1, train=False) @@ -2256,15 +2398,16 @@ def DropoutEvalFloatModule_basic(module, tu: TestUtils): class DropoutTrainModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): res = torch.dropout(x, 0.3, train=True) return torch.mean(res), torch.std(res) @@ -2274,19 +2417,21 @@ class DropoutTrainModule(torch.nn.Module): def DropoutTrainModule_basic(module, tu: TestUtils): module.forward(tu.rand(1024, 1536)) + # ============================================================================== class DropoutTrainStaticShapeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1024, 1536], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1024, 1536], torch.float32, True), + ] + ) def forward(self, x): res = torch.dropout(x, 0.3, train=True) return torch.mean(res), torch.std(res) @@ -2296,19 +2441,21 @@ class DropoutTrainStaticShapeModule(torch.nn.Module): def DropoutTrainStaticShapeModule_basic(module, tu: TestUtils): module.forward(tu.rand(1024, 1536)) + # ============================================================================== class NativeDropoutEvalFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.native_dropout(x, 0.1, train=False) @@ -2322,18 +2469,24 @@ def NativeDropoutEvalFloatModule_basic(module, tu: TestUtils): class NativeDropoutTrainModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): res = torch.native_dropout(x, 0.3, train=True) - return torch.mean(res[0]), torch.std(res[0]), torch.mean(res[1].to(torch.float32)), torch.std(res[1].to(torch.float32)) + return ( + torch.mean(res[0]), + torch.std(res[0]), + torch.mean(res[1].to(torch.float32)), + torch.std(res[1].to(torch.float32)), + ) @register_test_case(module_factory=lambda: NativeDropoutTrainModule()) @@ -2345,37 +2498,45 @@ def NativeDropoutTrainModule_basic(module, tu: TestUtils): class NativeDropoutTrainStaticShapeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1024, 1536], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1024, 1536], torch.float32, True), + ] + ) def forward(self, x): res = torch.native_dropout(x, 0.3, train=True) - return torch.mean(res[0]), torch.std(res[0]), torch.mean(res[1].to(torch.float32)), torch.std(res[1].to(torch.float32)) + return ( + torch.mean(res[0]), + torch.std(res[0]), + torch.mean(res[1].to(torch.float32)), + torch.std(res[1].to(torch.float32)), + ) @register_test_case(module_factory=lambda: NativeDropoutTrainStaticShapeModule()) def NativeDropoutTrainStaticShapeModule_basic(module, tu: TestUtils): module.forward(tu.rand(1024, 1536)) + # ============================================================================== class NumelModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input): return torch.ops.aten.numel(input) @@ -2389,15 +2550,16 @@ def NumelModule_basic(module, tu: TestUtils): class NumelZeroRankModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ] + ) def forward(self, input): return torch.ops.aten.numel(input) @@ -2411,15 +2573,16 @@ def NumelZeroRankModule_basic(module, tu: TestUtils): class BoolTensorReturnFalseModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.bool, True), + ] + ) def forward(self, a): return a @@ -2433,15 +2596,16 @@ def BoolTensorReturnFalseModule_basic(module, tu: TestUtils): class BoolTensorReturnTrueModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.bool, True), + ] + ) def forward(self, a): return a @@ -2455,15 +2619,16 @@ def BoolTensorReturnTrueModule_basic(module, tu: TestUtils): class BoolTensorReturnMixedModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ] + ) def forward(self, a): return a @@ -2477,16 +2642,17 @@ def BoolTensorReturnMixedModule_basic(module, tu: TestUtils): class BoolTensorHandleSignless(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.bool, True), - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ([-1, -1], torch.bool, True), + ] + ) def forward(self, a, b): return a * b @@ -2502,15 +2668,16 @@ def BoolTensorHandleSignless_basic(module, tu: TestUtils): class TModuleRank2(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, lhs): return torch.t(lhs) @@ -2524,15 +2691,16 @@ def TModuleRank2_basic(module, tu: TestUtils): class TModuleRank1(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, lhs): return torch.t(lhs) @@ -2546,15 +2714,16 @@ def TModuleRank1_basic(module, tu: TestUtils): class TModuleRank0(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([], torch.float32, True), + ] + ) def forward(self, lhs): return torch.t(lhs) @@ -2568,16 +2737,17 @@ def TModuleRank0_basic(module, tu: TestUtils): class TensorLiteralModule(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) self.register_buffer("t", torch.randint(-5, 5, (2, 3))) @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.add(self.t, self.t) @@ -2591,16 +2761,17 @@ def TensorLiteralModule_basic(module, tu: TestUtils): class TensorOpaqueLiteralModule(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) self.register_buffer("t", torch.randint(-5, 5, (256, 1024))) @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.add(self.t, self.t) @@ -2614,16 +2785,17 @@ def TensorOpaqueLiteralModule_basic(module, tu: TestUtils): class ReturnTwoTensorF32I64(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a, b): return a, b @@ -2637,18 +2809,19 @@ def ReturnTwoTensorF32I64_basic(module, tu: TestUtils): class IndexTensorModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, index): - return torch.ops.aten.index(x, (index, )) + return torch.ops.aten.index(x, (index,)) @register_test_case(module_factory=lambda: IndexTensorModule()) @@ -2658,39 +2831,42 @@ def IndexTensorModule_basic(module, tu: TestUtils): # ============================================================================== class IndexTensorStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 5], torch.float32, True), - ([2, 3], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([4, 5], torch.float32, True), + ([2, 3], torch.int64, True), + ] + ) def forward(self, x, index): - return torch.ops.aten.index(x, (index, )) + return torch.ops.aten.index(x, (index,)) @register_test_case(module_factory=lambda: IndexTensorStaticModule()) def IndexTensorStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4)) + # ============================================================================== class IndexTensorMultiIndexStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 5], torch.float32, True), - ([2, 3], torch.int64, True), - ([2, 3], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([4, 5], torch.float32, True), + ([2, 3], torch.int64, True), + ([2, 3], torch.int64, True), + ] + ) def forward(self, x, index1, index2): return torch.ops.aten.index(x, (index1, index2)) @@ -2704,16 +2880,17 @@ def IndexTensorMultiIndexStaticModule_basic(module, tu: TestUtils): class IndexTensorModule3dInput(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, index): return torch.ops.aten.index(x, (index,)) @@ -2725,18 +2902,20 @@ def IndexTensorModule3dInput_basic(module, tu: TestUtils): # ============================================================================== -class IndexTensorStaticContiguousWithNoneModule(torch.nn.Module): +class IndexTensorStaticContiguousWithNoneModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 4, 5, 32], torch.float32, True), - ([1, 2, 1], torch.int64, True), - ([2, 1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([2, 3, 4, 5, 32], torch.float32, True), + ([1, 2, 1], torch.int64, True), + ([2, 1], torch.int64, True), + ] + ) def forward(self, x, index, index1): return torch.ops.aten.index(x, (None, index, index1, None)) @@ -2744,95 +2923,124 @@ class IndexTensorStaticContiguousWithNoneModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexTensorStaticContiguousWithNoneModule()) def IndexTensorStaticContiguousWithNoneModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4, 5, 32), torch.tensor([[[0],[1]]]), torch.tensor([[0],[1]])) + module.forward( + tu.rand(2, 3, 4, 5, 32), torch.tensor([[[0], [1]]]), torch.tensor([[0], [1]]) + ) + # ============================================================================== class IndexTensorDyanmicInputContiguousWithNoneModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ([1, 2, 1], torch.int64, True), - ([2, 1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([1, 2, 1], torch.int64, True), + ([2, 1], torch.int64, True), + ] + ) def forward(self, x, index, index1): return torch.ops.aten.index(x, (None, index, index1, None)) -@register_test_case(module_factory=lambda: IndexTensorDyanmicInputContiguousWithNoneModule()) +@register_test_case( + module_factory=lambda: IndexTensorDyanmicInputContiguousWithNoneModule() +) def IndexTensorDyanmicInputContiguousWithNoneModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4, 5, 32), torch.tensor([[[0],[1]]]), torch.tensor([[0],[1]])) + module.forward( + tu.rand(2, 3, 4, 5, 32), torch.tensor([[[0], [1]]]), torch.tensor([[0], [1]]) + ) + # ============================================================================== class IndexTensorStaticNonContiguousWithNoneModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 4, 5, 32], torch.float32, True), - ([1, 2, 1], torch.int64, True), - ([2, 1], torch.int64, True), - ([2, 1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([2, 3, 4, 5, 32], torch.float32, True), + ([1, 2, 1], torch.int64, True), + ([2, 1], torch.int64, True), + ([2, 1], torch.int64, True), + ] + ) def forward(self, x, index, index1, index2): return torch.ops.aten.index(x, (None, index, index1, None, index2)) -@register_test_case(module_factory=lambda: IndexTensorStaticNonContiguousWithNoneModule()) +@register_test_case( + module_factory=lambda: IndexTensorStaticNonContiguousWithNoneModule() +) def IndexTensorStaticNonContiguousWithNoneModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4, 5, 32), torch.tensor([[[0],[1]]]), torch.tensor([[0],[1]]), torch.tensor([[0],[1]])) + module.forward( + tu.rand(2, 3, 4, 5, 32), + torch.tensor([[[0], [1]]]), + torch.tensor([[0], [1]]), + torch.tensor([[0], [1]]), + ) + # ============================================================================== -class IndexTensorDyanmicInputNonContiguousWithNoneModule(torch.nn.Module): +class IndexTensorDyanmicInputNonContiguousWithNoneModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ([1, 2, 1], torch.int64, True), - ([2, 1], torch.int64, True), - ([2, 1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([1, 2, 1], torch.int64, True), + ([2, 1], torch.int64, True), + ([2, 1], torch.int64, True), + ] + ) def forward(self, x, index, index1, index2): return torch.ops.aten.index(x, (None, index, index1, None, index2)) -@register_test_case(module_factory=lambda: IndexTensorDyanmicInputNonContiguousWithNoneModule()) +@register_test_case( + module_factory=lambda: IndexTensorDyanmicInputNonContiguousWithNoneModule() +) def IndexTensorDyanmicInputNonContiguousWithNoneModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4, 5, 32), torch.tensor([[[0],[1]]]), torch.tensor([[0],[1]]), torch.tensor([[0],[1]])) + module.forward( + tu.rand(2, 3, 4, 5, 32), + torch.tensor([[[0], [1]]]), + torch.tensor([[0], [1]]), + torch.tensor([[0], [1]]), + ) + # ============================================================================== class IndexTensorSelectDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a, ind): return torch.ops.aten.index(a, (None, ind, None)) @@ -2841,23 +3049,31 @@ class IndexTensorSelectDimModule(torch.nn.Module): def IndexTensorSelectDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 6), tu.randint(2, 3, high=3)) + # ============================================================================== class IndexTensorMultiInput(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([3, 3], torch.int64, True), - ([3], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([3, 3], torch.int64, True), + ([3], torch.int64, True), + ] + ) def forward(self, x, index1, index2): - return torch.ops.aten.index(x, (index1, index2,)) + return torch.ops.aten.index( + x, + ( + index1, + index2, + ), + ) @register_test_case(module_factory=lambda: IndexTensorMultiInput()) @@ -2869,19 +3085,26 @@ def IndexTensorMultiInput_basic(module, tu: TestUtils): class IndexTensorMultiInputOneDim(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([6, 1], torch.int64, True), - ([3], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([6, 1], torch.int64, True), + ([3], torch.int64, True), + ] + ) def forward(self, x, index1, index2): - return torch.ops.aten.index(x, (index1, index2,)) + return torch.ops.aten.index( + x, + ( + index1, + index2, + ), + ) @register_test_case(module_factory=lambda: IndexTensorMultiInputOneDim()) @@ -2893,188 +3116,209 @@ def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils): class IndexTensorMultiInputContiguousOneDimDynamic(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, 1], torch.int64, True), - ([-1], torch.int64, True), - ]) - def forward(self, x, index1, index2): - return torch.ops.aten.index(x, ( + @annotate_args( + [ None, - index1, - index2, - )) + ([-1, -1, -1], torch.float32, True), + ([-1, 1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) + def forward(self, x, index1, index2): + return torch.ops.aten.index( + x, + ( + None, + index1, + index2, + ), + ) @register_test_case( - module_factory=lambda: IndexTensorMultiInputContiguousOneDimDynamic()) + module_factory=lambda: IndexTensorMultiInputContiguousOneDimDynamic() +) def IndexTensorMultiInputContiguousOneDimDynamic_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), - tu.randint(3, high=3)) + module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), tu.randint(3, high=3)) # ============================================================================== class IndexTensorMultiInputNonContiguousOneDimDynamic(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, 1], torch.int64, True), - ([-1], torch.int64, True), - ]) - def forward(self, x, index1, index2): - return torch.ops.aten.index(x, ( - index1, + @annotate_args( + [ None, - index2, - )) + ([-1, -1, -1], torch.float32, True), + ([-1, 1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) + def forward(self, x, index1, index2): + return torch.ops.aten.index( + x, + ( + index1, + None, + index2, + ), + ) @register_test_case( - module_factory=lambda: IndexTensorMultiInputNonContiguousOneDimDynamic()) -def IndexTensorMultiInputNonContiguousOneDimDynamic_basic( - module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), - tu.randint(3, high=3)) + module_factory=lambda: IndexTensorMultiInputNonContiguousOneDimDynamic() +) +def IndexTensorMultiInputNonContiguousOneDimDynamic_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), tu.randint(3, high=3)) # ============================================================================== class IndexTensorMultiInputNonContiguousDynamic(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, 2], torch.int64, True), - ([-1], torch.int64, True), - ]) - def forward(self, x, index1, index2): - return torch.ops.aten.index(x, ( - index2, + @annotate_args( + [ None, - index1, - )) + ([-1, -1, -1], torch.float32, True), + ([-1, 2], torch.int64, True), + ([-1], torch.int64, True), + ] + ) + def forward(self, x, index1, index2): + return torch.ops.aten.index( + x, + ( + index2, + None, + index1, + ), + ) -@register_test_case( - module_factory=lambda: IndexTensorMultiInputNonContiguousDynamic()) +@register_test_case(module_factory=lambda: IndexTensorMultiInputNonContiguousDynamic()) def IndexTensorMultiInputNonContiguousDynamic_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3), tu.randint(6, 2, high=2), - tu.randint(2, high=3)) + module.forward(tu.rand(5, 4, 3), tu.randint(6, 2, high=2), tu.randint(2, high=3)) # ============================================================================== class IndexTensorMultiInputNonContiguousMultipleStaticDims(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([4, 1], torch.int64, True), - ([1, 3], torch.int64, True), - ([-1, 3], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([4, 1], torch.int64, True), + ([1, 3], torch.int64, True), + ([-1, 3], torch.int64, True), + ] + ) def forward(self, x, index1, index2, index3): return torch.ops.aten.index(x, (index1, index2, index3)) -@register_test_case(module_factory=lambda: - IndexTensorMultiInputNonContiguousMultipleStaticDims()) -def IndexTensorMultiInputNonContiguousMultipleStaticDims_basic( - module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 1, high=3), - tu.randint(1, 3, high=1), tu.randint(4, 3, high=1)) +@register_test_case( + module_factory=lambda: IndexTensorMultiInputNonContiguousMultipleStaticDims() +) +def IndexTensorMultiInputNonContiguousMultipleStaticDims_basic(module, tu: TestUtils): + module.forward( + tu.rand(5, 4, 3, 2), + tu.randint(4, 1, high=3), + tu.randint(1, 3, high=1), + tu.randint(4, 3, high=1), + ) # ============================================================================== class IndexTensorMultiInputNonContiguous(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([4, 2], torch.int64, True), - ([4, 2], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([4, 2], torch.int64, True), + ([4, 2], torch.int64, True), + ] + ) def forward(self, x, index1, index2): return torch.ops.aten.index(x, (index1, None, index2)) @register_test_case(module_factory=lambda: IndexTensorMultiInputNonContiguous()) def IndexTensorMultiInputNonContiguous_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 2, high=3), tu.randint(4, 2, high=1)) + module.forward( + tu.rand(5, 4, 3, 2), tu.randint(4, 2, high=3), tu.randint(4, 2, high=1) + ) # ============================================================================== class IndexTensorMultiInputThreeIndexers(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1, -1], torch.float32, True), - ([8, 4, 2], torch.int64, True), - ([8, 1, 1], torch.int64, True), - ([4, 2], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1, -1], torch.float32, True), + ([8, 4, 2], torch.int64, True), + ([8, 1, 1], torch.int64, True), + ([4, 2], torch.int64, True), + ] + ) def forward(self, x, index1, index2, index3): return torch.ops.aten.index(x, (None, None, index1, None, index2, index3)) @register_test_case(module_factory=lambda: IndexTensorMultiInputThreeIndexers()) def IndexTensorMultiInputThreeIndexers_basic(module, tu: TestUtils): - module.forward(tu.rand(1, 2, 4, 4, 5, 3), - tu.randint(8, 4, 2, high=3), - tu.randint(8, 1, 1, high=4), - tu.randint(4, 2, high=2)) + module.forward( + tu.rand(1, 2, 4, 4, 5, 3), + tu.randint(8, 4, 2, high=3), + tu.randint(8, 1, 1, high=4), + tu.randint(4, 2, high=2), + ) # ============================================================================== class IndexTensorMultiInputContiguousCenter(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([2, 2], torch.int64, True), - ([2], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([2, 2], torch.int64, True), + ([2], torch.int64, True), + ] + ) def forward(self, x, index1, index2): return torch.ops.aten.index(x, (None, index1, index2, None)) @@ -3088,16 +3332,17 @@ def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils): class IndexTensorNegativeIndexModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 2, 3, 2], torch.float32, True), - ([1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([1, 2, 3, 2], torch.float32, True), + ([1], torch.int64, True), + ] + ) def forward(self, x, index): return torch.ops.aten.index(x, (None, None, index)) @@ -3111,16 +3356,17 @@ def IndexTensorNegativeIndexModule_basic(module, tu: TestUtils): class IndexTensorHackedTwinModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, index): return torch.ops.aten.index(x, [index]) @@ -3134,22 +3380,22 @@ def IndexTensorHackedTwinModule_basic(module, tu: TestUtils): class IndexTensorHackedTwinModule3dInput(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, index): return torch.ops.aten.index(x, [index]) -@register_test_case( - module_factory=lambda: IndexTensorHackedTwinModule3dInput()) +@register_test_case(module_factory=lambda: IndexTensorHackedTwinModule3dInput()) def IndexTensorHackedTwinModule3dInput_basic(module, tu: TestUtils): module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3)) @@ -3157,46 +3403,52 @@ def IndexTensorHackedTwinModule3dInput_basic(module, tu: TestUtils): # ============================================================================== -class IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims( - torch.nn.Module): - +class IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([4, 1], torch.int64, True), - ([1, 3], torch.int64, True), - ([-1, 3], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([4, 1], torch.int64, True), + ([1, 3], torch.int64, True), + ([-1, 3], torch.int64, True), + ] + ) def forward(self, x, index1, index2, index3): return torch.ops.aten.index(x, [index1, index2, index3]) @register_test_case( - module_factory=lambda: - IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims()) + module_factory=lambda: IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims() +) def IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic( - module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 1, high=3), - tu.randint(1, 3, high=1), tu.randint(4, 3, high=1)) + module, tu: TestUtils +): + module.forward( + tu.rand(5, 4, 3, 2), + tu.randint(4, 1, high=3), + tu.randint(1, 3, high=1), + tu.randint(4, 3, high=1), + ) # ============================================================================== class SquareModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.square(x) @@ -3210,15 +3462,16 @@ def SquareModule_basic(module, tu: TestUtils): class HardswishModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.hardswish(x) @@ -3232,15 +3485,16 @@ def HardswishModule_basic(module, tu: TestUtils): class HardswishRandomModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.hardswish(x) @@ -3254,15 +3508,16 @@ def HardswishRandomModule_basic(module, tu: TestUtils): class SiluModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.silu(x) @@ -3276,15 +3531,16 @@ def SiluModule_basic(module, tu: TestUtils): class HardTanhModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.hardtanh(x, min_val=-2, max_val=2) @@ -3298,15 +3554,16 @@ def HardTanhModule_basic(module, tu: TestUtils): class HardTanhIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.ops.aten.hardtanh(x, min_val=-2, max_val=2) @@ -3320,15 +3577,16 @@ def HardTanhIntModule_basic(module, tu: TestUtils): class BincountModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ] + ) def forward(self, x): return torch.ops.aten.bincount(x) @@ -3342,15 +3600,16 @@ def BincountModule_basic(module, tu: TestUtils): class BincountStaticSizeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([200], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([200], torch.int64, True), + ] + ) def forward(self, x): return torch.ops.aten.bincount(x) @@ -3364,15 +3623,16 @@ def BincountStaticSizeModule_basic(module, tu: TestUtils): class BincountMinlengthModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ] + ) def forward(self, x): return torch.ops.aten.bincount(x, minlength=600) @@ -3386,16 +3646,17 @@ def BincountMinlengthModule_basic(module, tu: TestUtils): class ExpandAsFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, 1, 1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, 1, 1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.expand_as(x, y) @@ -3406,40 +3667,41 @@ def ExpandAsFloatModule_basic(module, tu: TestUtils): class ExpandAsIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 1], torch.int64, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 1], torch.int64, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.expand_as(x, y) @register_test_case(module_factory=lambda: ExpandAsIntModule()) def ExpandAsIntModule_basic(module, tu: TestUtils): - module.forward(tu.randint(1, 1, 1, high=100), - tu.randint(4, 5, 6, high=200)) + module.forward(tu.randint(1, 1, 1, high=100), tu.randint(4, 5, 6, high=200)) # ============================================================================== class CopyModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.copy_(x, y) @@ -3450,16 +3712,17 @@ def CopyModule_basic(module, tu: TestUtils): class CopyWithDifferentSizesModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, 4], torch.float32, True), - ([-1, -1, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, 4], torch.float32, True), + ([-1, -1, 1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.copy_(x, y) @@ -3470,16 +3733,17 @@ def CopyWithDifferentSizesModule_basic(module, tu: TestUtils): class CopyWithDifferentDTypesModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.copy_(x, y) @@ -3490,22 +3754,22 @@ def CopyWithDifferentDTypesModule_basic(module, tu: TestUtils): class CopyWithDifferentDTypesAndSizesModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, 4], torch.float32, True), - ([-1, -1, 1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, 4], torch.float32, True), + ([-1, -1, 1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.copy_(x, y) -@register_test_case( - module_factory=lambda: CopyWithDifferentDTypesAndSizesModule()) +@register_test_case(module_factory=lambda: CopyWithDifferentDTypesAndSizesModule()) def CopyWithDifferentDTypesAndSizesModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2, 4), tu.randint(3, 2, 1, high=1000)) @@ -3514,15 +3778,16 @@ def CopyWithDifferentDTypesAndSizesModule_basic(module, tu: TestUtils): class ToCopyModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten._to_copy(x) @@ -3533,15 +3798,16 @@ def ToCopyModule_basic(module, tu: TestUtils): class ToCopyWithDTypeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten._to_copy(x, dtype=torch.int64) @@ -3552,35 +3818,36 @@ def ToCopyWithDTypeModule_basic(module, tu: TestUtils): class ToCopyWithDTypeFalsePinMemoryModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten._to_copy(x, dtype=torch.int64, pin_memory=False) -@register_test_case( - module_factory=lambda: ToCopyWithDTypeFalsePinMemoryModule()) +@register_test_case(module_factory=lambda: ToCopyWithDTypeFalsePinMemoryModule()) def ToCopyWithDTypeFalsePinMemoryModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2, 4)) class ToCopyBoolDTypeStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 5, 5], torch.uint8, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 5, 5], torch.uint8, True), + ] + ) def forward(self, x): return torch.ops.aten._to_copy(x, dtype=torch.bool) @@ -3594,15 +3861,16 @@ def ToCopyBoolDTypeStaticModule_basic(module, tu: TestUtils): class FlipModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.flip(x, [1, 2]) @@ -3611,19 +3879,21 @@ class FlipModule(torch.nn.Module): def FlipModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2, 4)) + # ============================================================================== class FlipModuleStaticShape(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 2, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 2, 4], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.flip(x, [1, 2]) @@ -3632,19 +3902,21 @@ class FlipModuleStaticShape(torch.nn.Module): def FlipModuleStaticShape_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2, 4)) + # ============================================================================== class FlipNegativeIndexModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 2, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 2, 4], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.flip(x, [-1]) @@ -3658,15 +3930,16 @@ def FlipNegativeIndexModule_basic(module, tu: TestUtils): class DetachModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.detach(x) @@ -3680,15 +3953,16 @@ def DetachModule_basic(module, tu: TestUtils): class LenStrModule(torch.nn.Module): - def __init__(self): super().__init__() self.str = "test" @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.aten.len(self.str) @@ -3700,19 +3974,22 @@ def LenStrModule_basic(module, tu: TestUtils): # ============================================================================== -class IntFloatModule(torch.nn.Module): +class IntFloatModule(torch.nn.Module): def __init__(self): super().__init__() self.value = 1.0 @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.aten.Int(self.value) + @register_test_case(module_factory=lambda: IntFloatModule()) def IntFloatModule_basic(module, tu: TestUtils): module.forward() @@ -3720,20 +3997,23 @@ def IntFloatModule_basic(module, tu: TestUtils): # ============================================================================== -class AtenSubFloatModule(torch.nn.Module): +class AtenSubFloatModule(torch.nn.Module): def __init__(self): super().__init__() self.value1 = 1.0 self.value2 = 2.0 @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return float(torch.ops.aten.sub(self.value1, self.value2)) + @register_test_case(module_factory=lambda: AtenSubFloatModule()) def AtenSubFloatModule_basic(module, tu: TestUtils): module.forward() @@ -3741,16 +4021,18 @@ def AtenSubFloatModule_basic(module, tu: TestUtils): # ============================================================================== -class ScalarImplicitFloatModule(torch.nn.Module): +class ScalarImplicitFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ] + ) def forward(self, x): return float(torch.ops.aten.ScalarImplicit(x)) @@ -3761,15 +4043,16 @@ def ScalarImplicitFloatModule_basic(module, tu: TestUtils): class ScalarImplicitIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ] + ) def forward(self, x): return int(torch.ops.aten.ScalarImplicit(x)) @@ -3783,15 +4066,16 @@ def ScalarImplicitIntModule_basic(module, tu: TestUtils): class FloatImplicitModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ] + ) def forward(self, x): return float(torch.ops.aten.FloatImplicit(x)) @@ -3805,15 +4089,16 @@ def FloatImplicitModule_basic(module, tu: TestUtils): class IntImplicitModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ] + ) def forward(self, x): return float(torch.ops.aten.IntImplicit(x)) @@ -3825,38 +4110,44 @@ def IntImplicitModule_basic(module, tu: TestUtils): # ============================================================================== -class PowIntFloat(torch.nn.Module): +class PowIntFloat(torch.nn.Module): def __init__(self): super().__init__() self.value = 2 self.power_value = 3.0 @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.aten.pow(self.value, self.power_value) + @register_test_case(module_factory=lambda: IntFloatModule()) def PowIntFloatModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== -class BaddbmmDynamicModule(torch.nn.Module): +class BaddbmmDynamicModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, batch1, batch2): return torch.ops.aten.baddbmm(input, batch1, batch2) @@ -3867,17 +4158,18 @@ def BaddbmmDynamicModule_basic(module, tu: TestUtils): class BaddbmmStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, 2, 7], torch.float32, True), - ([5, 2, 9], torch.float32, True), - ([5, 9, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 2, 7], torch.float32, True), + ([5, 2, 9], torch.float32, True), + ([5, 9, 7], torch.float32, True), + ] + ) def forward(self, input, batch1, batch2): return torch.ops.aten.baddbmm(input, batch1, batch2) @@ -3888,17 +4180,18 @@ def BaddbmmStaticModule_basic(module, tu: TestUtils): class BaddbmmWithAlphaModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, batch1, batch2): return torch.ops.aten.baddbmm(input, batch1, batch2, alpha=5) @@ -3909,17 +4202,18 @@ def BaddbmmWithAlphaModule_basic(module, tu: TestUtils): class BaddbmmWithBetaModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, batch1, batch2): return torch.ops.aten.baddbmm(input, batch1, batch2, beta=0.5) @@ -3930,17 +4224,18 @@ def BaddbmmWithBetaModule_basic(module, tu: TestUtils): class BaddbmmWithAlphaBetaModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, batch1, batch2): return torch.ops.aten.baddbmm(input, batch1, batch2, beta=6, alpha=2.4) @@ -3951,38 +4246,46 @@ def BaddbmmWithAlphaBetaModule_basic(module, tu: TestUtils): class BaddbmmBroadcast1DInputModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1], torch.float32, True), - ([5, 2, 9], torch.float32, True), - ([5, 9, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1], torch.float32, True), + ([5, 2, 9], torch.float32, True), + ([5, 9, 7], torch.float32, True), + ] + ) def forward(self, input, batch1, batch2): return torch.ops.aten.baddbmm(input, batch1, batch2) @register_test_case(module_factory=lambda: BaddbmmBroadcast1DInputModule()) def BaddbmmBroadcast1DInputModule_basic(module, tu: TestUtils): - module.forward(tu.rand(1,), tu.rand(5, 2, 9), tu.rand(5, 9, 7)) + module.forward( + tu.rand( + 1, + ), + tu.rand(5, 2, 9), + tu.rand(5, 9, 7), + ) class BaddbmmBroadcast2DInputModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 7], torch.float32, True), - ([5, 2, 9], torch.float32, True), - ([5, 9, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 7], torch.float32, True), + ([5, 2, 9], torch.float32, True), + ([5, 9, 7], torch.float32, True), + ] + ) def forward(self, input, batch1, batch2): return torch.ops.aten.baddbmm(input, batch1, batch2) @@ -3996,15 +4299,16 @@ def BaddbmmBroadcast2DInputModule_basic(module, tu: TestUtils): class NumpyTRankNStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 5, 6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 5, 6], torch.float32, True), + ] + ) def forward(self, lhs): return torch.ops.aten.numpy_T(lhs) @@ -4015,15 +4319,16 @@ def NumpyTRankNStaticModule_basic(module, tu: TestUtils): class NumpyTRankNDynamicModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, lhs): return torch.ops.aten.numpy_T(lhs) @@ -4034,15 +4339,16 @@ def NumpyTRankNDynamicModule_basic(module, tu: TestUtils): class NumpyTRank2Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, lhs): return torch.ops.aten.numpy_T(lhs) @@ -4053,15 +4359,16 @@ def NumpyTRank2Module_basic(module, tu: TestUtils): class NumpyTRank1Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, lhs): return torch.ops.aten.numpy_T(lhs) @@ -4072,15 +4379,16 @@ def NumpyTRank1Module_basic(module, tu: TestUtils): class NumpyTRank0Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([], torch.float32, True), + ] + ) def forward(self, lhs): return torch.ops.aten.numpy_T(lhs) @@ -4094,27 +4402,30 @@ def NumpyTRank0Module_basic(module, tu: TestUtils): class AtenEmbeddingBagStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 2], torch.float32, True), - ([3], torch.int64, True), - ([1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([4, 2], torch.float32, True), + ([3], torch.int64, True), + ([1], torch.int64, True), + ] + ) def forward(self, weight, indices, offsets): - return torch.ops.aten.embedding_bag(weight, - indices, - offsets, - scale_grad_by_freq=False, - mode=0, - sparse=False, - per_sample_weights=None, - include_last_offset=False, - padding_idx=None) + return torch.ops.aten.embedding_bag( + weight, + indices, + offsets, + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=None, + ) @register_test_case(module_factory=lambda: AtenEmbeddingBagStaticModule()) @@ -4126,50 +4437,55 @@ def AtenEmbeddingBagStaticModule_basic(module, tu: TestUtils): class AtenEmbeddingBagSumExample(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, weight, indices, offsets): - return torch.ops.aten.embedding_bag(weight, - indices, - offsets, - scale_grad_by_freq=False, - mode=0, - sparse=False, - per_sample_weights=None, - include_last_offset=False, - padding_idx=None) + return torch.ops.aten.embedding_bag( + weight, + indices, + offsets, + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=None, + ) @register_test_case(module_factory=lambda: AtenEmbeddingBagSumExample()) def AtenEmbeddingBagSumExample_basic(module, tu: TestUtils): weight = tu.rand(100, 10) indices = torch.LongTensor( - [0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54]) + [0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54] + ) offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15]) module.forward(weight, indices, offsets) class Aten_EmbeddingBagExample(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, weight, indices, offsets): return torch.ops.aten._embedding_bag(weight, indices, offsets) @@ -4178,23 +4494,26 @@ class Aten_EmbeddingBagExample(torch.nn.Module): def Aten_EmbeddingBagExample_basic(module, tu: TestUtils): weight = tu.rand(100, 10) indices = torch.LongTensor( - [0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54]) + [0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54] + ) offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15]) module.forward(weight, indices, offsets) # ============================================================================== -class CumsumModule(torch.nn.Module): +class CumsumModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, val): # the onnx cumsum op uses a constant 1d tensor # to specify the dimension along which to do cumsum @@ -4204,80 +4523,97 @@ class CumsumModule(torch.nn.Module): ones = torch.ones([1], dtype=torch.int32) return torch.ops.aten.cumsum(val, ones.item()) + @register_test_case(module_factory=lambda: CumsumModule()) def CumsumModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 7, 4)) -class CumsumStaticModule(torch.nn.Module): +class CumsumStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 7, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 7, 4], torch.float32, True), + ] + ) def forward(self, val): return torch.ops.aten.cumsum(val, 1) + @register_test_case(module_factory=lambda: CumsumStaticModule()) def CumsumStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 7, 4)) -class CumsumStaticNegativeDimModule(torch.nn.Module): +class CumsumStaticNegativeDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 7, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 7, 4], torch.float32, True), + ] + ) def forward(self, val): return torch.ops.aten.cumsum(val, dim=-1) + @register_test_case(module_factory=lambda: CumsumStaticNegativeDimModule()) def CumsumStaticNegativeDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 7, 4)) -class CumsumInputDtypeInt32Module(torch.nn.Module): +class CumsumInputDtypeInt32Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 7, 4], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([2, 7, 4], torch.int32, True), + ] + ) def forward(self, val): return torch.ops.aten.cumsum(val, 1) + @register_test_case(module_factory=lambda: CumsumInputDtypeInt32Module()) def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils): module.forward(tu.randint(2, 7, 4).to(torch.int32)) + # ============================================================================== + class AtenToDeviceModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1 , -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, val): - return torch.ops.aten.to(val, device='cpu', dtype=torch.float, non_blocking=False) + return torch.ops.aten.to( + val, device="cpu", dtype=torch.float, non_blocking=False + ) + @register_test_case(module_factory=lambda: AtenToDeviceModule()) def AtenToDeviceModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4)) + # ============================================================================== @@ -4286,14 +4622,16 @@ class Aten_CastFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([2, 4], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([2, 4], torch.int64, True), + ] + ) def forward(self, val): return torch.ops.aten._cast_Float(val) - + + @register_test_case(module_factory=lambda: Aten_CastFloatModule()) def Aten_CastFloatModule_basic(module, tu: TestUtils): module.forward(tu.randint(2, 4)) @@ -4304,14 +4642,16 @@ class Aten_CastLongModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([2, 4], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([2, 4], torch.float32, True), + ] + ) def forward(self, val): return torch.ops.aten._cast_Long(val) - + + @register_test_case(module_factory=lambda: Aten_CastLongModule()) def Aten_CastLongModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4)) @@ -4319,22 +4659,26 @@ def Aten_CastLongModule_basic(module, tu: TestUtils): # ============================================================================== -class UpSampleNearest2dBackward(torch.nn.Module): +class UpSampleNearest2dBackward(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float64, True), + ] + ) def forward(self, input): - return torch.ops.aten.upsample_nearest2d_backward(input, - output_size=[6, 12], - input_size=[1, 1, 2, 3], - scales_h=3.0, - scales_w=4.0) + return torch.ops.aten.upsample_nearest2d_backward( + input, + output_size=[6, 12], + input_size=[1, 1, 2, 3], + scales_h=3.0, + scales_w=4.0, + ) @register_test_case(module_factory=lambda: UpSampleNearest2dBackward()) @@ -4343,21 +4687,25 @@ def UpSampleNearest2dBackward_basic(module, tu: TestUtils): class UpSampleNearest2dBackwardScalesNone(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, input): - return torch.ops.aten.upsample_nearest2d_backward(input, - output_size=[4, 8], - input_size=[1, 1, 2, 3], - scales_h=None, - scales_w=None) + return torch.ops.aten.upsample_nearest2d_backward( + input, + output_size=[4, 8], + input_size=[1, 1, 2, 3], + scales_h=None, + scales_w=None, + ) + @register_test_case(module_factory=lambda: UpSampleNearest2dBackwardScalesNone()) def UpSampleNearest2dBackwardScalesNone_basic(module, tu: TestUtils): @@ -4368,14 +4716,15 @@ def UpSampleNearest2dBackwardScalesNone_basic(module, tu: TestUtils): class SortIntList(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): a = [1, 0, 3, 2] b = [0, 1, 2, 3] @@ -4389,14 +4738,15 @@ def SortIntList_basic(module, tu: TestUtils): class SortIntListReverse(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): a = [1, 0, 3, 2] b = [3, 2, 1, 0] @@ -4408,19 +4758,16 @@ class SortIntListReverse(torch.nn.Module): def SortIntListReverse_basic(module, tu: TestUtils): module.forward() + # ============================================================================== class SortTensor(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True) - ]) + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) def forward(self, input): return torch.sort(input) @@ -4429,16 +4776,13 @@ class SortTensor(torch.nn.Module): def SortTensor_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) -class SortTensorInteger(torch.nn.Module): +class SortTensorInteger(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True) - ]) + @annotate_args([None, ([-1, -1], torch.int64, True)]) def forward(self, input): return torch.sort(input) @@ -4449,15 +4793,11 @@ def SortTensorInteger_basic(module, tu: TestUtils): class SortTensorDescending(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True) - ]) + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) def forward(self, input): return torch.sort(input, descending=True) @@ -4466,16 +4806,13 @@ class SortTensorDescending(torch.nn.Module): def SortTensorDescending_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) -class SortTensorSpecificDimension(torch.nn.Module): +class SortTensorSpecificDimension(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True) - ]) + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) def forward(self, input): return torch.sort(input, dim=1) @@ -4484,16 +4821,13 @@ class SortTensorSpecificDimension(torch.nn.Module): def SortTensorSpecificDimension_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) -class SortTensorNegativeDimension(torch.nn.Module): +class SortTensorNegativeDimension(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True) - ]) + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) def forward(self, input): return torch.sort(input, dim=-1) @@ -4502,89 +4836,110 @@ class SortTensorNegativeDimension(torch.nn.Module): def SortTensorNegativeDimension_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class BucketizeTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, input, boundaries): return torch.bucketize(input, boundaries) + @register_test_case(module_factory=lambda: BucketizeTensorModule()) def BucketizeTensorModule_basic(module, tu: TestUtils): module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6])) + class BucketizeTensorOutInt32RightModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, input, boundaries): return torch.bucketize(input, boundaries, out_int32=True, right=True) + @register_test_case(module_factory=lambda: BucketizeTensorOutInt32RightModule()) def BucketizeTensorOutInt32RightModule_basic(module, tu: TestUtils): module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6])) + class BucketizeTensorFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, input, boundaries): return torch.bucketize(input, boundaries) + @register_test_case(module_factory=lambda: BucketizeTensorFloatModule()) def BucketizeTensorFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(15, 17), torch.sort(tu.rand(16)).values) + class BucketizeTensorStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 4], torch.int64, True), - ([3], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([2, 4], torch.int64, True), + ([3], torch.int64, True), + ] + ) def forward(self, input, boundaries): return torch.bucketize(input, boundaries) + @register_test_case(module_factory=lambda: BucketizeTensorStaticModule()) def BucketizeTensorStaticModule_basic(module, tu: TestUtils): module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6])) + class BucketizeTensorStaticFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([15, 17], torch.float32, True), - ([16], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([15, 17], torch.float32, True), + ([16], torch.float32, True), + ] + ) def forward(self, input, boundaries): return torch.bucketize(input, boundaries) + @register_test_case(module_factory=lambda: BucketizeTensorStaticFloatModule()) def BucketizeTensorStaticFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(15, 17), torch.sort(tu.rand(16)).values) @@ -4592,16 +4947,18 @@ def BucketizeTensorStaticFloatModule_basic(module, tu: TestUtils): # ============================================================================== -class AtenFloatScalarModule(torch.nn.Module): +class AtenFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ] + ) def forward(self, x): a = torch.ops.aten.ScalarImplicit(x) return torch.ops.aten.Float(a) @@ -4616,14 +4973,13 @@ def AtenFloatScalarModule_basic(module, tu: TestUtils): class MoveDimIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)]) def forward(self, x): - return torch.ops.aten.movedim(x, source=1, destination=2) #0, 2, 1 + return torch.ops.aten.movedim(x, source=1, destination=2) # 0, 2, 1 @register_test_case(module_factory=lambda: MoveDimIntModule()) @@ -4635,7 +4991,6 @@ def MoveDimIntModule_basic(module, tu: TestUtils): class MoveDimIntNegativeIndexModule(torch.nn.Module): - def __init__(self): super().__init__() @@ -4649,23 +5004,27 @@ class MoveDimIntNegativeIndexModule(torch.nn.Module): def MoveDimIntNegativeIndexModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 2)) + # ============================================================================== -class ScaledDotProductAttentionSameModule(torch.nn.Module): +class ScaledDotProductAttentionSameModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True) - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, query, key, value): return torch.ops.aten.scaled_dot_product_attention(query, key, value) + @register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule()) def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils): query = torch.randn(1, 5, 5, dtype=torch.float32) @@ -4673,21 +5032,24 @@ def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils): value = torch.randn(1, 5, 5, dtype=torch.float32) module.forward(query, key, value) -class ScaledDotProductAttentionDifferentModule(torch.nn.Module): +class ScaledDotProductAttentionDifferentModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 8, 4], torch.float32, True), - ([2, 3, 16, 4], torch.float32, True), - ([2, 3, 16, 4], torch.float32, True) - ]) + @annotate_args( + [ + None, + ([2, 3, 8, 4], torch.float32, True), + ([2, 3, 16, 4], torch.float32, True), + ([2, 3, 16, 4], torch.float32, True), + ] + ) def forward(self, query, key, value): return torch.ops.aten.scaled_dot_product_attention(query, key, value) + @register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule()) def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils): query = torch.randn(2, 3, 8, 4, dtype=torch.float32) @@ -4695,11 +5057,11 @@ def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils): value = torch.randn(2, 3, 16, 4, dtype=torch.float32) module.forward(query, key, value) + # ============================================================================== class PrimsViewOfModule(torch.nn.Module): - def __init__(self): super().__init__() @@ -4715,7 +5077,6 @@ def PrimsViewOfModule_basic(module, tu: TestUtils): class PrimsViewOfZeroRankModule(torch.nn.Module): - def __init__(self): super().__init__() @@ -4734,7 +5095,6 @@ def PrimsViewOfZeroRankModule_basic(module, tu: TestUtils): class OneHotModule(torch.nn.Module): - def __init__(self): super().__init__() @@ -4753,16 +5113,16 @@ def OneHotModule_basic(module, tu: TestUtils): class ConstantBoolParameterModule(torch.nn.Module): - def __init__(self): super().__init__() - self.bool_tensor = torch.tensor( - [True, False, True, False], dtype=torch.bool) + self.bool_tensor = torch.tensor([True, False, True, False], dtype=torch.bool) @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return self.bool_tensor @@ -4776,14 +5136,15 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils): class ScalarTensorFloat32Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): scalar = torch.ops.aten.scalar_tensor(1.0, dtype=torch.float32) return scalar @@ -4798,14 +5159,15 @@ def ScalarTensorFloat32Module_basic(module, tu: TestUtils): class ScalarTensorDefaultDtypeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): scalar = torch.ops.aten.scalar_tensor(1.0) return scalar @@ -4820,14 +5182,15 @@ def ScalarTensorDefaultDtypeModule_basic(module, tu: TestUtils): class ScalarTensorInt64Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): scalar = torch.ops.aten.scalar_tensor(1, dtype=torch.int64) return scalar @@ -4842,14 +5205,15 @@ def ScalarTensorInt64Module_basic(module, tu: TestUtils): class ScalarTensorInt32Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): scalar = torch.ops.aten.scalar_tensor(1, dtype=torch.int32) return scalar @@ -4864,7 +5228,6 @@ def ScalarTensorInt32Module_basic(module, tu: TestUtils): class AtenTopKModule(torch.nn.Module): - def __init__(self): super().__init__() @@ -4880,7 +5243,6 @@ def AtenTopKModule_basic(module, tu: TestUtils): class AtenTopKSmallestModule(torch.nn.Module): - def __init__(self): super().__init__() @@ -4897,54 +5259,58 @@ def AtenTopKSmallestModule_basic(module, tu: TestUtils): # ============================================================================== -class AtenComplexImagModule(torch.nn.Module): +class AtenComplexImagModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.complex64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.complex64, True), + ] + ) def forward(self, x): return torch.ops.aten.imag(x) @register_test_case(module_factory=lambda: AtenComplexImagModule()) def AtenComplexImagModule_basic(module, tu: TestUtils): - module.forward(torch.view_as_complex(tu.rand(5,2))) + module.forward(torch.view_as_complex(tu.rand(5, 2))) class AtenComplexRealModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.complex64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.complex64, True), + ] + ) def forward(self, x): return torch.ops.aten.real(x) @register_test_case(module_factory=lambda: AtenComplexRealModule()) def AtenComplexRealModule_basic(module, tu: TestUtils): - module.forward(torch.view_as_complex(tu.rand(5,2))) + module.forward(torch.view_as_complex(tu.rand(5, 2))) class AtenComplex64Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.complex64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.complex64, True), + ] + ) def forward(self, x): return x @@ -4955,22 +5321,24 @@ def AtenComplex64Module_basic(module, tu: TestUtils): class AtenComplexViewModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.view_as_complex(x) @register_test_case(module_factory=lambda: AtenComplexViewModule()) def AtenComplexViewModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5,2)) + module.forward(tu.rand(5, 2)) + # ============================================================================== class AtenRealView128Module(torch.nn.Module): @@ -4978,10 +5346,12 @@ class AtenRealView128Module(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.complex128, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.complex128, True), + ] + ) def forward(self, x): return torch.view_as_real(x) @@ -4990,16 +5360,19 @@ class AtenRealView128Module(torch.nn.Module): def AtenRealView128Module_basic(module, tu: TestUtils): module.forward(tu.rand(10, 6, 1).to(torch.complex128)) + # ============================================================================== class AtenRealView64Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.complex64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.complex64, True), + ] + ) def forward(self, x): return torch.view_as_real(x) @@ -5008,19 +5381,22 @@ class AtenRealView64Module(torch.nn.Module): def AtenRealView64Module_basic(module, tu: TestUtils): module.forward(tu.rand(10, 6, 1).to(torch.complex64)) + # ============================================================================== -class Add_Module(torch.nn.Module): +class Add_Module(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer('tensor', torch.ones(2, 3)) + self.register_buffer("tensor", torch.ones(2, 3)) @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.add_(x, self.tensor) @@ -5034,16 +5410,17 @@ def Add_Module_basic(module, tu: TestUtils): class CosineSimilarityStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3], torch.float32, True), - ([2, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3], torch.float32, True), + ([2, 3], torch.float32, True), + ] + ) def forward(self, x1, x2): return torch.ops.aten.cosine_similarity(x1, x2) @@ -5057,16 +5434,17 @@ def CosineSimilarityStaticModule_basic(module, tu: TestUtils): class CosineSimilarityStaticBroadcastModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, 2, 3], torch.float32, True), - ([4, 5, 1, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 2, 3], torch.float32, True), + ([4, 5, 1, 1], torch.float32, True), + ] + ) def forward(self, x1, x2): return torch.ops.aten.cosine_similarity(x1, x2) @@ -5080,16 +5458,17 @@ def CosineSimilarityStaticBroadcastModule_basic(module, tu: TestUtils): class CosineSimilarityModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x1, x2): return torch.ops.aten.cosine_similarity(x1, x2) @@ -5103,16 +5482,17 @@ def CosineSimilarityModule_basic(module, tu: TestUtils): class IscloseStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, 5], torch.float32, True), - ([5, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 5], torch.float32, True), + ([5, 5], torch.float32, True), + ] + ) def forward(self, x, y): return torch.isclose(x, y) @@ -5126,19 +5506,21 @@ def IscloseStaticModule_basic(module, tu: TestUtils): class IscloseStaticModuleTrue(torch.nn.Module): - def __init__(self): super().__init__() - self.register_buffer('tensor', torch.ones(1)) + self.register_buffer("tensor", torch.ones(1)) @export - @annotate_args([ - None, - ([5, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 5], torch.float32, True), + ] + ) def forward(self, x): return torch.isclose(x, self.tensor) + @register_test_case(module_factory=lambda: IscloseStaticModuleTrue()) def IscloseStaticModuleTrue_basic(module, tu: TestUtils): module.forward(torch.ones(5, 5)) @@ -5146,20 +5528,22 @@ def IscloseStaticModuleTrue_basic(module, tu: TestUtils): # ============================================================================== -class CloneModule(torch.nn.Module): +class CloneModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 5], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.clone(x) + @register_test_case(module_factory=lambda: CloneModule()) def CloneModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 5)) - diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/cast.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/cast.py index 613ba7e3b..207281161 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/cast.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/cast.py @@ -11,15 +11,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== + class TensorToIntZeroRank(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ] + ) def forward(self, x): return int(x) @@ -28,17 +31,21 @@ class TensorToIntZeroRank(torch.nn.Module): def TensorToIntZeroRank_basic(module, tu: TestUtils): module.forward(tu.randint(high=10)) + # ============================================================================== + class TensorToInt(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return int(x) @@ -47,17 +54,21 @@ class TensorToInt(torch.nn.Module): def TensorToInt_basic(module, tu: TestUtils): module.forward(tu.randint(1, 1, high=10)) + # ============================================================================== + class TensorToFloatZeroRank(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ] + ) def forward(self, x): return float(x) @@ -66,17 +77,21 @@ class TensorToFloatZeroRank(torch.nn.Module): def TensorToFloatZeroRank_basic(module, tu: TestUtils): module.forward(tu.rand().to(torch.float64)) + # ============================================================================== + class TensorToFloat(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float64, True), + ] + ) def forward(self, x): return float(x) @@ -85,17 +100,21 @@ class TensorToFloat(torch.nn.Module): def TensorToFloat_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1).to(torch.float64)) + # ============================================================================== + class TensorToBoolZeroRank(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([], torch.bool, True), + ] + ) def forward(self, x): return bool(x) @@ -104,17 +123,21 @@ class TensorToBoolZeroRank(torch.nn.Module): def TensorToBoolZeroRank_basic(module, tu: TestUtils): module.forward(torch.tensor(1, dtype=torch.bool)) + # ============================================================================== + class TensorToBool(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ] + ) def forward(self, x): return bool(x) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index 540fa2d22..8ce0a44d7 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -13,14 +13,15 @@ from torch_mlir_e2e_test.annotations import annotate_args, export class ZerosModuleDefaultDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.zeros(3, 4) @@ -31,14 +32,15 @@ def ZerosModuleDefaultDtype_basic(module, tu: TestUtils): class ZerosModuleInt2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.zeros(3, 4, dtype=torch.int64) @@ -49,14 +51,15 @@ def ZerosModuleInt2D_basic(module, tu: TestUtils): class ZerosModuleInt3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.zeros(3, 4, 5, dtype=torch.int64) @@ -67,14 +70,15 @@ def ZerosModuleInt3D_basic(module, tu: TestUtils): class ZerosModuleFloat2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.zeros(3, 4, dtype=torch.float32) @@ -85,14 +89,15 @@ def ZerosModuleFloat2D_basic(module, tu: TestUtils): class ZerosModuleFloat3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.zeros(3, 4, 5, dtype=torch.float32) @@ -103,14 +108,15 @@ def ZerosModuleFloat3D_basic(module, tu: TestUtils): class ZerosModuleFalsePinMemory(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False) @@ -124,14 +130,15 @@ def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils): class OnesModuleDefaultDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ones(3, 4) @@ -142,14 +149,15 @@ def OnesModuleDefaultDtype_basic(module, tu: TestUtils): class OnesModuleInt(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ones(3, 4, dtype=torch.int64) @@ -160,14 +168,15 @@ def OnesModuleInt_basic(module, tu: TestUtils): class OnesModuleFloat(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ones(3, 4, dtype=torch.float32) @@ -178,14 +187,15 @@ def OnesModuleFloat_basic(module, tu: TestUtils): class OnesModuleFalsePinMemory(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ones(3, 4, dtype=torch.float32, pin_memory=False) @@ -196,14 +206,15 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils): class OnesModuleCPUDevice(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ones(3, 4, device="cpu") @@ -215,15 +226,17 @@ def OnesModuleCPUDevice_basic(module, tu: TestUtils): # ============================================================================== -class AtenEyeModuleDefaultDtype(torch.nn.Module): +class AtenEyeModuleDefaultDtype(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.aten.eye(3) @@ -234,14 +247,15 @@ def AtenEyeModuleDefaultDtype_basic(module, tu: TestUtils): class AtenEyeModuleInt2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.eye(3, dtype=torch.int64) @@ -252,14 +266,15 @@ def AtenEyeModuleInt2D_basic(module, tu: TestUtils): class AtenEyeModuleFloat2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.eye(3, dtype=torch.float32) @@ -268,15 +283,17 @@ class AtenEyeModuleFloat2D(torch.nn.Module): def AtenEyeModuleFloat2D_basic(module, tu: TestUtils): module.forward() -class AtenEyeModuleFalsePinMemory(torch.nn.Module): +class AtenEyeModuleFalsePinMemory(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.eye(3, dtype=torch.float32, pin_memory=False) @@ -287,14 +304,15 @@ def AtenEyeModuleFalsePinMemory_basic(module, tu: TestUtils): class AtenEyeModuleCPUDevice(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.eye(3, device="cpu") @@ -303,17 +321,20 @@ class AtenEyeModuleCPUDevice(torch.nn.Module): def AtenEyeModuleCPUDevice_basic(module, tu: TestUtils): module.forward() + # ============================================================================== -class AtenEyeMModuleDefaultDtype(torch.nn.Module): +class AtenEyeMModuleDefaultDtype(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.aten.eye(3, 4) @@ -324,14 +345,15 @@ def AtenEyeMModuleDefaultDtype_basic(module, tu: TestUtils): class AtenEyeMModuleInt2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.eye(3, 4, dtype=torch.int64) @@ -342,14 +364,15 @@ def AtenEyeMModuleInt2D_basic(module, tu: TestUtils): class AtenEyeMModuleFloat2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.eye(3, 4, dtype=torch.float32) @@ -358,15 +381,17 @@ class AtenEyeMModuleFloat2D(torch.nn.Module): def AtenEyeMModuleFloat2D_basic(module, tu: TestUtils): module.forward() -class AtenEyeMModuleFalsePinMemory(torch.nn.Module): +class AtenEyeMModuleFalsePinMemory(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.eye(3, 4, dtype=torch.float32, pin_memory=False) @@ -377,14 +402,15 @@ def AtenEyeMModuleFalsePinMemory_basic(module, tu: TestUtils): class AtenEyeMModuleCPUDevice(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.eye(3, 4, device="cpu") @@ -393,21 +419,22 @@ class AtenEyeMModuleCPUDevice(torch.nn.Module): def AtenEyeMModuleCPUDevice_basic(module, tu: TestUtils): module.forward() + # ============================================================================== class EmptyContiguousModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): - return torch.empty((3, 4), - memory_format=torch.contiguous_format).fill_(0) + return torch.empty((3, 4), memory_format=torch.contiguous_format).fill_(0) @register_test_case(module_factory=lambda: EmptyContiguousModule()) @@ -416,14 +443,15 @@ def EmptyModule_contiguous(module, tu: TestUtils): class EmptyDefaultDtypeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.empty((3, 4)).fill_(0) @@ -434,14 +462,15 @@ def EmptyModule_defaultDtype(module, tu: TestUtils): class EmptyIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.empty((3, 4), dtype=torch.int64).fill_(0) @@ -452,14 +481,15 @@ def EmptyModule_int(module, tu: TestUtils): class EmptyUInt8Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): empty = torch.ops.aten.empty([1], dtype=torch.uint8) return torch.ops.aten.zeros_like(empty).to(torch.int8) @@ -471,14 +501,15 @@ def EmptyModule_uint8(module, tu: TestUtils): class EmptyFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.empty((3, 4), dtype=torch.float32).fill_(0) @@ -489,17 +520,17 @@ def EmptyModule_float(module, tu: TestUtils): class EmptyFalsePinMemoryModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): - return torch.empty((3, 4), dtype=torch.float32, - pin_memory=False).fill_(0) + return torch.empty((3, 4), dtype=torch.float32, pin_memory=False).fill_(0) @register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule()) @@ -511,15 +542,16 @@ def EmptyModule_falsePinMemory(module, tu: TestUtils): class EmptyLikeDefaultDtypeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.empty_like(a).fill_(0) @@ -530,15 +562,16 @@ def EmptyLikeModule_defaultDtype(module, tu: TestUtils): class EmptyLikeIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.empty_like(a, dtype=torch.int32).fill_(0) @@ -549,18 +582,18 @@ def EmptyLikeModule_int(module, tu: TestUtils): class EmptyLikeMemoryFormatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): - return torch.empty_like(a, - memory_format=torch.preserve_format).fill_(0) + return torch.empty_like(a, memory_format=torch.preserve_format).fill_(0) @register_test_case(module_factory=lambda: EmptyLikeMemoryFormatModule()) @@ -569,15 +602,16 @@ def EmptyLikeMemoryFormatModule_basic(module, tu: TestUtils): class EmptyLikeFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.empty_like(a, dtype=torch.float32).fill_(0) @@ -588,18 +622,18 @@ def EmptyLikeModule_float(module, tu: TestUtils): class EmptyLikeFalsePinMemoryModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): - return torch.empty_like(a, dtype=torch.float64, - pin_memory=False).fill_(0) + return torch.empty_like(a, dtype=torch.float64, pin_memory=False).fill_(0) @register_test_case(module_factory=lambda: EmptyLikeFalsePinMemoryModule()) @@ -611,15 +645,16 @@ def EmptyLikeModule_falsePinMemory(module, tu: TestUtils): class ZerosLikeDefaultDtypeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.zeros_like(a) @@ -630,15 +665,16 @@ def ZerosLikeModule_defaultDtype(module, tu: TestUtils): class ZerosLikeIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.zeros_like(a, dtype=torch.int32) @@ -649,15 +685,16 @@ def ZerosLikeModule_int(module, tu: TestUtils): class ZerosLikeFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.zeros_like(a, dtype=torch.float32) @@ -668,15 +705,16 @@ def ZerosLikeModule_float(module, tu: TestUtils): class ZerosLikeFalsePinMemoryModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.zeros_like(a, dtype=torch.float64, pin_memory=False) @@ -690,15 +728,16 @@ def ZerosLikeModule_falsePinMemory(module, tu: TestUtils): class OnesLikeDefaultDtypeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ones_like(a) @@ -709,15 +748,16 @@ def OnesLikeModule_defaultDtype(module, tu: TestUtils): class OnesLikeIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ones_like(a, dtype=torch.int32) @@ -728,15 +768,16 @@ def OnesLikeModule_int(module, tu: TestUtils): class OnesLikeFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ones_like(a, dtype=torch.float32) @@ -747,15 +788,16 @@ def OnesLikeModule_float(module, tu: TestUtils): class OnesLikeFalsePinMemoryModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ones_like(a, dtype=torch.float64, pin_memory=False) @@ -769,15 +811,16 @@ def OnesLikeModule_falsePinMemory(module, tu: TestUtils): class NewZerosModuleDefaultDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.new_zeros(a, [3, 4]) @@ -788,15 +831,16 @@ def NewZerosModuleDefaultDtype_basic(module, tu: TestUtils): class NewZerosModuleInt2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.int64) @@ -807,15 +851,16 @@ def NewZerosModuleInt2D_basic(module, tu: TestUtils): class NewZerosModuleInt3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.new_zeros(a, [3, 4, 5], dtype=torch.int64) @@ -826,15 +871,16 @@ def NewZerosModuleInt3D_basic(module, tu: TestUtils): class NewZerosModuleFloat2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.float32) @@ -845,15 +891,16 @@ def NewZerosModuleFloat2D_basic(module, tu: TestUtils): class NewZerosModuleFloat3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.new_zeros(a, [3, 4, 5], dtype=torch.float32) @@ -864,19 +911,20 @@ def NewZerosModuleFloat3D_basic(module, tu: TestUtils): class NewZerosModuleFalsePinMemory(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_zeros(a, [3, 4], - dtype=torch.float32, - pin_memory=False) + return torch.ops.aten.new_zeros( + a, [3, 4], dtype=torch.float32, pin_memory=False + ) @register_test_case(module_factory=lambda: NewZerosModuleFalsePinMemory()) @@ -885,15 +933,16 @@ def NewZerosModuleFalsePinMemory_basic(module, tu: TestUtils): class NewZerosStaticModuleLayoutStrided(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 4], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([1, 4], torch.int64, True), + ] + ) def forward(self, a): return a.new_zeros(a.shape) @@ -902,19 +951,21 @@ class NewZerosStaticModuleLayoutStrided(torch.nn.Module): def NewZerosStaticModuleLayoutStrided_basic(module, tu: TestUtils): module.forward(tu.randint(1, 4, high=10)) + # ============================================================================== class NewOnesModuleDefaultDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.new_ones(a, [3, 4]) @@ -925,15 +976,16 @@ def NewOnesModuleDefaultDtype_basic(module, tu: TestUtils): class NewOnesModuleInt2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.int64) @@ -944,15 +996,16 @@ def NewOnesModuleInt2D_basic(module, tu: TestUtils): class NewOnesModuleInt3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.new_ones(a, [3, 4, 5], dtype=torch.int64) @@ -963,15 +1016,16 @@ def NewOnesModuleInt3D_basic(module, tu: TestUtils): class NewOnesModuleFloat2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.float32) @@ -982,15 +1036,16 @@ def NewOnesModuleFloat2D_basic(module, tu: TestUtils): class NewOnesModuleFloat3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.new_ones(a, [3, 4, 5], dtype=torch.float32) @@ -1001,19 +1056,18 @@ def NewOnesModuleFloat3D_basic(module, tu: TestUtils): class NewOnesModuleFalsePinMemory(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_ones(a, [3, 4], - dtype=torch.float32, - pin_memory=False) + return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.float32, pin_memory=False) @register_test_case(module_factory=lambda: NewOnesModuleFalsePinMemory()) @@ -1025,14 +1079,15 @@ def NewOnesModuleFalsePinMemory_basic(module, tu: TestUtils): class FullModuleDefaultDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.aten.full([2, 3], 5.0) @@ -1043,14 +1098,15 @@ def FullModuleDefaultDtype_basic(module, tu: TestUtils): class FullModuleInt2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.aten.full([10, 5], 10.5, dtype=torch.int64) @@ -1061,14 +1117,15 @@ def FullModuleInt2D_basic(module, tu: TestUtils): class FullModuleInt3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.aten.full([2, 3, 4], 5) @@ -1079,14 +1136,15 @@ def FullModuleInt3D_basic(module, tu: TestUtils): class FullModuleFloat2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.aten.full([10, 5], 10, dtype=torch.float32) @@ -1097,14 +1155,15 @@ def FullModuleFloat2D_basic(module, tu: TestUtils): class FullModuleFloat3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.ops.aten.full([2, 3, 4], 5.0) @@ -1115,19 +1174,17 @@ def FullModuleFloat3D_basic(module, tu: TestUtils): class FullModuleFalsePinMemory(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): - return torch.ops.aten.full([2, 3], - 5.0, - dtype=torch.int64, - pin_memory=False) + return torch.ops.aten.full([2, 3], 5.0, dtype=torch.int64, pin_memory=False) @register_test_case(module_factory=lambda: FullModuleFalsePinMemory()) @@ -1139,15 +1196,16 @@ def FullModuleFalsePinMemory_basic(module, tu: TestUtils): class FullLikeModuleDefaultDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.full_like(a, 5) @@ -1158,15 +1216,16 @@ def FullLikeModuleDefaultDtype_basic(module, tu: TestUtils): class FullLikeModuleInt2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.full_like(a, 10.5) @@ -1177,15 +1236,16 @@ def FullLikeModuleInt2D_basic(module, tu: TestUtils): class FullLikeModuleInt3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.ops.aten.full_like(a, 5.0, dtype=torch.int64) @@ -1196,15 +1256,16 @@ def FullLikeModuleInt3D_basic(module, tu: TestUtils): class FullLikeModuleInt2DStatic(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 5], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([4, 5], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.full_like(a, 10) @@ -1215,15 +1276,16 @@ def FullLikeModuleInt2DStatic_basic(module, tu: TestUtils): class FullLikeModuleFloat2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.full_like(a, 10) @@ -1234,15 +1296,16 @@ def FullLikeModuleFloat2D_basic(module, tu: TestUtils): class FullLikeModuleFloat3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, a): return torch.ops.aten.full_like(a, 15, dtype=torch.float32) @@ -1253,15 +1316,16 @@ def FullLikeModuleFloat3D_basic(module, tu: TestUtils): class FullLikeModuleFloat3DStatic(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 5], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 5], torch.float64, True), + ] + ) def forward(self, a): return torch.ops.aten.full_like(a, 15.3, dtype=torch.float32) @@ -1272,41 +1336,41 @@ def FullLikeModuleFloat3DStatic_basic(module, tu: TestUtils): class FullLikeModuleFalsePinMemory(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): - return torch.ops.aten.full_like(a, - 5, - dtype=torch.int64, - pin_memory=False) + return torch.ops.aten.full_like(a, 5, dtype=torch.int64, pin_memory=False) @register_test_case(module_factory=lambda: FullLikeModuleFalsePinMemory()) def FullLikeModuleFalsePinMemory_basic(module, tu: TestUtils): module.forward(tu.randint(10, 4, high=100)) + # ============================================================================== class NewFullModuleDefaultDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_full(a, (3,4), 5) + return torch.ops.aten.new_full(a, (3, 4), 5) @register_test_case(module_factory=lambda: NewFullModuleDefaultDtype()) @@ -1315,17 +1379,18 @@ def NewFullModuleDefaultDtype_basic(module, tu: TestUtils): class NewFullModuleInt2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_full(a, (3,4), 10.5) + return torch.ops.aten.new_full(a, (3, 4), 10.5) @register_test_case(module_factory=lambda: NewFullModuleInt2D()) @@ -1334,17 +1399,18 @@ def NewFullModuleInt2D_basic(module, tu: TestUtils): class NewFullModuleInt3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_full(a, (3,4), 5.0, dtype=torch.int64) + return torch.ops.aten.new_full(a, (3, 4), 5.0, dtype=torch.int64) @register_test_case(module_factory=lambda: NewFullModuleInt3D()) @@ -1353,17 +1419,18 @@ def NewFullModuleInt3D_basic(module, tu: TestUtils): class NewFullModuleFloat3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_full(a, (3,4), 15, dtype=torch.float32) + return torch.ops.aten.new_full(a, (3, 4), 15, dtype=torch.float32) @register_test_case(module_factory=lambda: NewFullModuleFloat3D()) @@ -1372,17 +1439,18 @@ def NewFullModuleFloat3D_basic(module, tu: TestUtils): class NewFullModuleFloat3DStatic(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 5], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 5], torch.float64, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_full(a, (3,4), 15.3, dtype=torch.float32) + return torch.ops.aten.new_full(a, (3, 4), 15.3, dtype=torch.float32) @register_test_case(module_factory=lambda: NewFullModuleFloat3DStatic()) @@ -1391,21 +1459,20 @@ def NewFullModuleFloat3DStatic_basic(module, tu: TestUtils): class NewFullModuleFalsePinMemory(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_full(a, - (3,4), - 5, - dtype=torch.int64, - pin_memory=False) + return torch.ops.aten.new_full( + a, (3, 4), 5, dtype=torch.int64, pin_memory=False + ) @register_test_case(module_factory=lambda: NewFullModuleFalsePinMemory()) @@ -1417,15 +1484,16 @@ def NewFullModuleFalsePinMemory_basic(module, tu: TestUtils): class ZeroFloat32Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, tensor): return torch.ops.aten.zero_(tensor) @@ -1436,15 +1504,16 @@ def ZeroFloat32Module_basic(module, tu: TestUtils): class ZeroInt32Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, tensor): return torch.ops.aten.zero_(tensor) @@ -1455,15 +1524,16 @@ def ZeroInt32Module_basic(module, tu: TestUtils): class ZeroInt64Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, tensor): return torch.ops.aten.zero_(tensor) @@ -1477,15 +1547,16 @@ def ZeroInt64Module_basic(module, tu: TestUtils): class NewEmptyModuleDefaultDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.new_empty(a, [3, 4]).fill_(0) @@ -1496,15 +1567,16 @@ def NewEmptyModuleDefaultDtype_basic(module, tu: TestUtils): class NewEmptyModuleInt2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.int64).fill_(0) @@ -1515,18 +1587,18 @@ def NewEmptyModuleInt2D_basic(module, tu: TestUtils): class NewEmptyModuleInt3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_empty(a, [3, 4, 5], - dtype=torch.int64).fill_(0) + return torch.ops.aten.new_empty(a, [3, 4, 5], dtype=torch.int64).fill_(0) @register_test_case(module_factory=lambda: NewEmptyModuleInt3D()) @@ -1535,18 +1607,18 @@ def NewEmptyModuleInt3D_basic(module, tu: TestUtils): class NewEmptyModuleFloat2D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_empty(a, [3, 4], - dtype=torch.float32).fill_(0) + return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.float32).fill_(0) @register_test_case(module_factory=lambda: NewEmptyModuleFloat2D()) @@ -1555,18 +1627,18 @@ def NewEmptyModuleFloat2D_basic(module, tu: TestUtils): class NewEmptyModuleFloat3D(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_empty(a, [3, 4, 5], - dtype=torch.float32).fill_(0) + return torch.ops.aten.new_empty(a, [3, 4, 5], dtype=torch.float32).fill_(0) @register_test_case(module_factory=lambda: NewEmptyModuleFloat3D()) @@ -1575,19 +1647,20 @@ def NewEmptyModuleFloat3D_basic(module, tu: TestUtils): class NewEmptyModuleFalsePinMemory(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): - return torch.ops.aten.new_empty(a, [3, 4], - dtype=torch.float32, - pin_memory=False).fill_(0) + return torch.ops.aten.new_empty( + a, [3, 4], dtype=torch.float32, pin_memory=False + ).fill_(0) @register_test_case(module_factory=lambda: NewEmptyModuleFalsePinMemory()) @@ -1596,35 +1669,36 @@ def NewEmptyModuleFalsePinMemory_basic(module, tu: TestUtils): class NewEmptyModuleNonDefaultFloatDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float64, True), + ] + ) def forward(self, a): return torch.ops.aten.new_empty(a, [3, 4]).fill_(0) -@register_test_case( - module_factory=lambda: NewEmptyModuleNonDefaultFloatDtype()) +@register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultFloatDtype()) def NewEmptyModuleNonDefaultFloatDtype_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3).to(torch.float64)) class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.ops.aten.new_empty(a, [3, 4]).fill_(0) @@ -1635,15 +1709,16 @@ def NewEmptyModuleNonDefaultIntDtype_basic(module, tu: TestUtils): class NewEmptyModuleLayoutIntDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.ops.aten.new_empty(a, [3, 4], layout=0).fill_(0) @@ -1657,167 +1732,180 @@ def NewEmptyModuleLayoutIntDtype_basic(module, tu: TestUtils): class MaskedFillScalarDefaultModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.bool, True), + ] + ) def forward(self, x, mask): return torch.ops.aten.masked_fill(x, mask, value=0.5) @register_test_case(module_factory=lambda: MaskedFillScalarDefaultModule()) def MaskedFillScalarDefaultModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), - tu.randint(2, 3, high=2).to(dtype=torch.bool)) + module.forward(tu.rand(2, 3), tu.randint(2, 3, high=2).to(dtype=torch.bool)) class MaskedFillScalarIntValueModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.bool, True), + ] + ) def forward(self, x, mask): return torch.ops.aten.masked_fill(x, mask, value=5) @register_test_case(module_factory=lambda: MaskedFillScalarIntValueModule()) def MaskedFillScalarIntValueModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), - tu.randint(2, 3, high=2).to(dtype=torch.bool)) + module.forward(tu.rand(2, 3), tu.randint(2, 3, high=2).to(dtype=torch.bool)) class MaskedFillScalarFloatValueModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.bool, True), + ] + ) def forward(self, x, mask): return torch.ops.aten.masked_fill(x, mask, value=-0.01) @register_test_case(module_factory=lambda: MaskedFillScalarFloatValueModule()) def MaskedFillScalarFloatValueModule_basic(module, tu: TestUtils): - module.forward(tu.randint(2, 3, low=-10, high=10), - tu.randint(2, 3, high=2).to(dtype=torch.bool)) + module.forward( + tu.randint(2, 3, low=-10, high=10), + tu.randint(2, 3, high=2).to(dtype=torch.bool), + ) class MaskedFillScalarFloatValueStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3], torch.int64, True), - ([2, 3], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([2, 3], torch.int64, True), + ([2, 3], torch.bool, True), + ] + ) def forward(self, x, mask): return torch.ops.aten.masked_fill(x, mask, value=-0.01) @register_test_case(module_factory=lambda: MaskedFillScalarFloatValueStaticModule()) def MaskedFillScalarFloatValueStaticModule_basic(module, tu: TestUtils): - module.forward(tu.randint(2, 3, low=-10, high=10), - tu.randint(2, 3, high=2).to(dtype=torch.bool)) + module.forward( + tu.randint(2, 3, low=-10, high=10), + tu.randint(2, 3, high=2).to(dtype=torch.bool), + ) class MaskedFillTensorFloatValueModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.bool, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.bool, True), + ([], torch.float32, True), + ] + ) def forward(self, x, mask, value): return torch.ops.aten.masked_fill(x, mask, value=value) @register_test_case(module_factory=lambda: MaskedFillTensorFloatValueModule()) def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils): - module.forward(tu.randint(2, 3, low=-10, high=10), - tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.rand()) + module.forward( + tu.randint(2, 3, low=-10, high=10), + tu.randint(2, 3, high=2).to(dtype=torch.bool), + tu.rand(), + ) class MaskedFillScalarIntValueStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3], torch.int64, True), - ([2, 3], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([2, 3], torch.int64, True), + ([2, 3], torch.bool, True), + ] + ) def forward(self, x, mask): return torch.ops.aten.masked_fill(x, mask, value=5) @register_test_case(module_factory=lambda: MaskedFillScalarIntValueStaticModule()) def MaskedFillScalarIntValueStaticModule_basic(module, tu: TestUtils): - module.forward(tu.randint(2, 3), - tu.randint(2, 3, high=2).to(dtype=torch.bool)) + module.forward(tu.randint(2, 3), tu.randint(2, 3, high=2).to(dtype=torch.bool)) class MaskedFillTensorIntValueStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3], torch.int64, True), - ([2, 3], torch.bool, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([2, 3], torch.int64, True), + ([2, 3], torch.bool, True), + ([], torch.int64, True), + ] + ) def forward(self, x, mask, value): return torch.ops.aten.masked_fill(x, mask, value=value) @register_test_case(module_factory=lambda: MaskedFillTensorIntValueStaticModule()) def MaskedFillTensorIntValueStaticModule_basic(module, tu: TestUtils): - module.forward(tu.randint(2, 3), - tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.randint()) + module.forward( + tu.randint(2, 3), tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.randint() + ) # ============================================================================== class NewEmptyStridedModuleDefaultDtype(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3, 4], torch.float32, True), + ] + ) def forward(self, a): x = torch.ops.aten.new_empty_strided(a, size=[2, 3, 4], stride=[12, 4, 1]) y = x.copy_(a) @@ -1828,19 +1916,21 @@ class NewEmptyStridedModuleDefaultDtype(torch.nn.Module): def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + # ============================================================================== class EmptyStridedModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3, 4], torch.float32, True), + ] + ) def forward(self, a): x = torch.ops.aten.empty_strided(a.size(), stride=[12, 4, 1]) y = x.copy_(a) @@ -1851,19 +1941,21 @@ class EmptyStridedModule(torch.nn.Module): def EmptyStridedModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 3, 4)) + # ============================================================================== class EmptyStridedSizeIntStrideModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3, -1], torch.float32, True), + ] + ) def forward(self, a): x = torch.ops.aten.empty_strided(a.size(), stride=[12, a.size(2), 1]) y = x.copy_(a) @@ -1874,22 +1966,23 @@ class EmptyStridedSizeIntStrideModule(torch.nn.Module): def EmptyStridedSizeIntStrideModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + # ============================================================================== class AtenDiagEmbedDefaultDiag(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diag_embed(a) - @register_test_case(module_factory=lambda: AtenDiagEmbedDefaultDiag()) def AtenDiagEmbedDefaultDiag_basic(module, tu: TestUtils): @@ -1897,18 +1990,18 @@ class AtenDiagEmbedDefaultDiag(torch.nn.Module): class AtenDiagEmbedDimDiag(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diag_embed(a, offset=0, dim1=1, dim2=3) - @register_test_case(module_factory=lambda: AtenDiagEmbedDimDiag()) def AtenDiagEmbedDimDiag_basic(module, tu: TestUtils): @@ -1916,18 +2009,18 @@ class AtenDiagEmbedDimDiag(torch.nn.Module): class AtenDiagEmbedOffsetDiag(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diag_embed(a, offset=1, dim1=1, dim2=3) - @register_test_case(module_factory=lambda: AtenDiagEmbedOffsetDiag()) def AtenDiagEmbedOffsetDiag_basic(module, tu: TestUtils): @@ -1935,18 +2028,18 @@ class AtenDiagEmbedOffsetDiag(torch.nn.Module): class AtenDiagEmbedRevDimDiag(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diag_embed(a, offset=1, dim1=3, dim2=1) - @register_test_case(module_factory=lambda: AtenDiagEmbedRevDimDiag()) def AtenDiagEmbedRevDimDiag_basic(module, tu: TestUtils): @@ -1954,37 +2047,38 @@ class AtenDiagEmbedRevDimDiag(torch.nn.Module): class AtenDiagEmbedNegOffsetDiag(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diag_embed(a, offset=-1, dim1=1, dim2=3) - @register_test_case(module_factory=lambda: AtenDiagEmbedNegOffsetDiag()) def AtenDiagEmbedNegOffsetDiag_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) -class AtenDiagEmbedNonDefault4DDiag(torch.nn.Module): +class AtenDiagEmbedNonDefault4DDiag(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diag_embed(a, offset=-2, dim1=1, dim2=-3) - @register_test_case(module_factory=lambda: AtenDiagEmbedNonDefault4DDiag()) def AtenDiagEmbedNonDefault4DDiag_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4, 5)) \ No newline at end of file + module.forward(tu.rand(2, 3, 4, 5)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py index d40a77bb6..a04114043 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py @@ -13,15 +13,13 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== + class TorchPrimLoopForLikeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True) - ]) + @annotate_args([None, ([-1, -1], torch.int64, True)]) def forward(self, x): x_val = x.size(0) sum = 0 @@ -34,20 +32,18 @@ class TorchPrimLoopForLikeModule(torch.nn.Module): def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils): module.forward(tu.randint(6, 8, high=10)) + # ============================================================================== class TorchPrimLoopWhileLikeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True) - ]) + @annotate_args([None, ([-1, -1], torch.int64, True)]) def forward(self, x): x_val = x.size(0) sum = 0 - while(x_val > sum): + while x_val > sum: sum += 1 return sum @@ -59,20 +55,24 @@ def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils): # ============================================================================== + class TorchPrimLoopForLikeTensorArgModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([7,9], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([7, 9], torch.float32, True), + ] + ) def forward(self, x: torch.Tensor) -> torch.Tensor: for i in range(50): x = x + i return x + @register_test_case(module_factory=lambda: TorchPrimLoopForLikeTensorArgModule()) def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils): x_test = torch.zeros([7, 9]).float() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 5872df170..9600b0900 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -12,7 +12,6 @@ from torch_mlir_e2e_test.annotations import annotate_args, export class Conv2dNoPaddingModule(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) @@ -20,10 +19,12 @@ class Conv2dNoPaddingModule(torch.nn.Module): self.train(False) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.conv(x) @@ -35,7 +36,6 @@ def Conv2dNoPaddingModule_basic(module, tu: TestUtils): class Conv2dBiasNoPaddingModule(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) @@ -43,10 +43,12 @@ class Conv2dBiasNoPaddingModule(torch.nn.Module): self.train(False) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.conv(x) @@ -58,7 +60,6 @@ def Conv2dBiasNoPaddingModule_basic(module, tu: TestUtils): class Conv2dWithPaddingModule(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) @@ -66,10 +67,12 @@ class Conv2dWithPaddingModule(torch.nn.Module): self.train(False) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.conv(x) @@ -81,111 +84,141 @@ def Conv2dWithPaddingModule_basic(module, tu: TestUtils): class Conv2dWithPaddingDilationStrideModule(torch.nn.Module): - def __init__(self): super().__init__() torch.manual_seed(0) - self.conv = torch.nn.Conv2d(in_channels=2, - out_channels=10, - kernel_size=3, - padding=3, - stride=2, - dilation=3, - bias=False) + self.conv = torch.nn.Conv2d( + in_channels=2, + out_channels=10, + kernel_size=3, + padding=3, + stride=2, + dilation=3, + bias=False, + ) self.train(False) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.conv(x) -@register_test_case( - module_factory=lambda: Conv2dWithPaddingDilationStrideModule()) +@register_test_case(module_factory=lambda: Conv2dWithPaddingDilationStrideModule()) def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils): t = tu.rand(5, 2, 10, 20) module.forward(t) class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module): - def __init__(self, out_channels, groups): super().__init__() torch.manual_seed(0) - self.conv = torch.nn.Conv2d(in_channels=4, - out_channels=out_channels, - kernel_size=3, - padding=3, - stride=2, - dilation=3, - bias=False, - groups=groups) + self.conv = torch.nn.Conv2d( + in_channels=4, + out_channels=out_channels, + kernel_size=3, + padding=3, + stride=2, + dilation=3, + bias=False, + groups=groups, + ) self.train(False) @export - @annotate_args([ - None, - ([5, 4, 10, 20], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 4, 10, 20], torch.float32, True), + ] + ) def forward(self, x): return self.conv(x) @register_test_case( - module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=10, groups=1)) + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule( + out_channels=10, groups=1 + ) +) def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 4, 10, 20)) @register_test_case( - module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=4)) + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule( + out_channels=4, groups=4 + ) +) def Conv2dWithPaddingDilationStrideStaticModule_depthwise(module, tu: TestUtils): module.forward(tu.rand(5, 4, 10, 20)) @register_test_case( - module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=4)) -def Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier(module, tu: TestUtils): + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule( + out_channels=8, groups=4 + ) +) +def Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier( + module, tu: TestUtils +): module.forward(tu.rand(5, 4, 10, 20)) @register_test_case( - module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=2)) + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule( + out_channels=4, groups=2 + ) +) def Conv2dWithPaddingDilationStrideStaticModule_grouped(module, tu: TestUtils): module.forward(tu.rand(5, 4, 10, 20)) @register_test_case( - module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=2)) -def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier(module, tu: TestUtils): + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule( + out_channels=8, groups=2 + ) +) +def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier( + module, tu: TestUtils +): module.forward(tu.rand(5, 4, 10, 20)) # ============================================================================== + class Convolution2DModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten.convolution(inputVec, - weight, - bias=None, - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1) + return torch.ops.aten.convolution( + inputVec, + weight, + bias=None, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + ) + @register_test_case(module_factory=lambda: Convolution2DModule()) def Convolution2DModule_basic(module, tu: TestUtils): @@ -197,462 +230,559 @@ class Convolution2DStaticModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([3, 3, 10, 10], torch.float32, True), - ([3, 3, 2, 2], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 3, 10, 10], torch.float32, True), + ([3, 3, 2, 2], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten.convolution(inputVec, - weight, - bias=None, - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1) + return torch.ops.aten.convolution( + inputVec, + weight, + bias=None, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + ) + @register_test_case(module_factory=lambda: Convolution2DStaticModule()) def Convolution2DStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + class Convolution2DStridedModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten.convolution(inputVec, - weight, - bias=None, - stride=[3, 3], - padding=[2, 2], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1) + return torch.ops.aten.convolution( + inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + ) + @register_test_case(module_factory=lambda: Convolution2DStridedModule()) def Convolution2DStridedModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + class _Convolution2DAllFalseModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten._convolution(inputVec, - weight, - bias=None, - stride=[3, 3], - padding=[2, 2], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - benchmark=False, - deterministic=False, - cudnn_enabled=False, - allow_tf32=False) + return torch.ops.aten._convolution( + inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=False, + cudnn_enabled=False, + allow_tf32=False, + ) + @register_test_case(module_factory=lambda: _Convolution2DAllFalseModule()) def _Convolution2DAllFalseModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + class _Convolution2DBenchmarkModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten._convolution(inputVec, - weight, - bias=None, - stride=[3, 3], - padding=[2, 2], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - benchmark=True, - deterministic=False, - cudnn_enabled=False, - allow_tf32=False) + return torch.ops.aten._convolution( + inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=True, + deterministic=False, + cudnn_enabled=False, + allow_tf32=False, + ) + @register_test_case(module_factory=lambda: _Convolution2DBenchmarkModule()) def _Convolution2DBenchmarkModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + class _Convolution2DDeterministicModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten._convolution(inputVec, - weight, - bias=None, - stride=[3, 3], - padding=[2, 2], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - benchmark=False, - deterministic=True, - cudnn_enabled=False, - allow_tf32=False) + return torch.ops.aten._convolution( + inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=True, + cudnn_enabled=False, + allow_tf32=False, + ) + @register_test_case(module_factory=lambda: _Convolution2DDeterministicModule()) def _Convolution2DDeterministicModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + class _Convolution2DCudnnModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten._convolution(inputVec, - weight, - bias=None, - stride=[3, 3], - padding=[2, 2], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - benchmark=False, - deterministic=False, - cudnn_enabled=True, - allow_tf32=False) + return torch.ops.aten._convolution( + inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=False, + cudnn_enabled=True, + allow_tf32=False, + ) + @register_test_case(module_factory=lambda: _Convolution2DCudnnModule()) def _Convolution2DCudnnModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + class _Convolution2DTF32Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten._convolution(inputVec, - weight, - bias=None, - stride=[3, 3], - padding=[2, 2], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - benchmark=False, - deterministic=False, - cudnn_enabled=False, - allow_tf32=True) + return torch.ops.aten._convolution( + inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=False, + cudnn_enabled=False, + allow_tf32=True, + ) + @register_test_case(module_factory=lambda: _Convolution2DTF32Module()) def _Convolution2DTF32Module_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten._convolution(inputVec, - weight, - bias=None, - stride=[3, 3], - padding=[2, 2], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - benchmark=False, - deterministic=False, - cudnn_enabled=False) + return torch.ops.aten._convolution( + inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=False, + cudnn_enabled=False, + ) + @register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule()) def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten._convolution(inputVec, - weight, - bias=None, - stride=[3, 3], - padding=[2, 2], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - benchmark=True, - deterministic=False, - cudnn_enabled=False) + return torch.ops.aten._convolution( + inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=True, + deterministic=False, + cudnn_enabled=False, + ) + @register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule()) def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten._convolution(inputVec, - weight, - bias=None, - stride=[3, 3], - padding=[2, 2], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - benchmark=False, - deterministic=True, - cudnn_enabled=False) + return torch.ops.aten._convolution( + inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=True, + cudnn_enabled=False, + ) -@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule()) + +@register_test_case( + module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule() +) def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten._convolution(inputVec, - weight, - bias=None, - stride=[3, 3], - padding=[2, 2], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - benchmark=False, - deterministic=False, - cudnn_enabled=True) + return torch.ops.aten._convolution( + inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=False, + cudnn_enabled=True, + ) + @register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule()) def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + class ConvolutionModule2DGroups(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten.convolution(inputVec, - weight, - bias=None, - stride=[3, 3], - padding=[2, 2], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=4) + return torch.ops.aten.convolution( + inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=4, + ) + @register_test_case(module_factory=lambda: ConvolutionModule2DGroups()) def ConvolutionModule2DGroups_basic(module, tu: TestUtils): module.forward(tu.rand(1, 32, 4, 4), tu.rand(32, 8, 3, 3)) + # ============================================================================== -class ConvolutionModule2DTranspose(torch.nn.Module): +class ConvolutionModule2DTranspose(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten.convolution(inputVec, - weight, - bias=None, - stride=[1, 1], - padding=[1, 1], - dilation=[1, 1], - transposed=True, - output_padding=[0, 0], - groups=1) + return torch.ops.aten.convolution( + inputVec, + weight, + bias=None, + stride=[1, 1], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1, + ) @register_test_case(module_factory=lambda: ConvolutionModule2DTranspose()) def ConvolutionModule2DTranspose_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 4, 4), tu.rand(3, 3, 2, 2)) -class ConvolutionModule2DTransposeStrided(torch.nn.Module): +class ConvolutionModule2DTransposeStrided(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten.convolution(inputVec, - weight, - bias=None, - stride=[2, 2], - padding=[1, 1], - dilation=[1, 1], - transposed=True, - output_padding=[0, 0], - groups=1) + return torch.ops.aten.convolution( + inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1, + ) @register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStrided()) def ConvolutionModule2DTransposeStrided_basic(module, tu: TestUtils): module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2)) -class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module): +class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, 2, 5, 6], torch.float32, True), - ([2, 5, 2, 2], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 2, 5, 6], torch.float32, True), + ([2, 5, 2, 2], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten.convolution(inputVec, - weight, - bias=None, - stride=[2, 2], - padding=[1, 1], - dilation=[1, 1], - transposed=True, - output_padding=[0, 0], - groups=1) + return torch.ops.aten.convolution( + inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1, + ) @register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStridedStatic()) def ConvolutionModule2DTransposeStridedStatic_basic(module, tu: TestUtils): module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2)) + class ConvolutionModule2DTransposeNonUnitOutputPadding(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten.convolution(inputVec, - weight, - bias=None, - stride=[2, 2], - padding=[1, 1], - dilation=[1, 1], - transposed=True, - output_padding=[1, 1], - groups=1) + return torch.ops.aten.convolution( + inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[1, 1], + groups=1, + ) -@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeNonUnitOutputPadding()) + +@register_test_case( + module_factory=lambda: ConvolutionModule2DTransposeNonUnitOutputPadding() +) def ConvolutionModule2DTransposeNonUnitOutputPadding_basic(module, tu: TestUtils): module.forward(tu.rand(1, 2, 4, 4), tu.rand(2, 2, 3, 3)) class Conv_Transpose2dModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec, weight): - return torch.ops.aten.conv_transpose2d(inputVec, - weight, - bias=None, - stride=[2, 2], - padding=[1, 1], - dilation=[1, 1], - output_padding=[0, 0], - groups=1) + return torch.ops.aten.conv_transpose2d( + inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + output_padding=[0, 0], + groups=1, + ) @register_test_case(module_factory=lambda: Conv_Transpose2dModule()) @@ -661,41 +791,42 @@ def Conv_Transpose2dModule_basic(module, tu: TestUtils): class UpSampleNearest2d(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float64, True), + ] + ) def forward(self, input): - return torch.ops.aten.upsample_nearest2d(input, - output_size=[18, 48], - scales_h=3.0, - scales_w=4.0) + return torch.ops.aten.upsample_nearest2d( + input, output_size=[18, 48], scales_h=3.0, scales_w=4.0 + ) @register_test_case(module_factory=lambda: UpSampleNearest2d()) def UpSampleNearest2d_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 6, 12).to(torch.float64)) -class UpSampleNearest2dSameSize(torch.nn.Module): +class UpSampleNearest2dSameSize(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec): - return torch._C._nn.upsample_nearest2d(inputVec, - output_size=[11, 11], - scales_h=None, - scales_w=None) + return torch._C._nn.upsample_nearest2d( + inputVec, output_size=[11, 11], scales_h=None, scales_w=None + ) @register_test_case(module_factory=lambda: UpSampleNearest2dSameSize()) @@ -704,17 +835,15 @@ def UpSampleNearest2dStaticSize_basic(module, tu: TestUtils): class UpSampleNearest2dDiffSize(torch.nn.Module): - def __init__(self): super().__init__() @export @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)]) def forward(self, inputVec): - return torch._C._nn.upsample_nearest2d(inputVec, - output_size=[8, 11], - scales_h=None, - scales_w=None) + return torch._C._nn.upsample_nearest2d( + inputVec, output_size=[8, 11], scales_h=None, scales_w=None + ) @register_test_case(module_factory=lambda: UpSampleNearest2dDiffSize()) @@ -723,17 +852,15 @@ def UpSampleNearest2dDynamicSize_basic(module, tu: TestUtils): class UpSampleNearest2dDiffFactor(torch.nn.Module): - def __init__(self): super().__init__() @export @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)]) def forward(self, inputVec): - return torch._C._nn.upsample_nearest2d(inputVec, - output_size=[6, 10], - scales_h=2.3, - scales_w=4.7) + return torch._C._nn.upsample_nearest2d( + inputVec, output_size=[6, 10], scales_h=2.3, scales_w=4.7 + ) @register_test_case(module_factory=lambda: UpSampleNearest2dDiffFactor()) @@ -742,44 +869,46 @@ def UpSampleNearest2dDynamicFactor_basic(module, tu: TestUtils): class UpSampleNearest2dSameFactor(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, inputVec): - return torch._C._nn.upsample_nearest2d(inputVec, - output_size=[8, 8], - scales_h=2.0, - scales_w=2.0) + return torch._C._nn.upsample_nearest2d( + inputVec, output_size=[8, 8], scales_h=2.0, scales_w=2.0 + ) @register_test_case(module_factory=lambda: UpSampleNearest2dSameFactor()) def UpSampleNearest2dStaticFactor_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 4)) + + class Conv1dModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, inputVec, weight, bias): - return torch.ops.aten.conv1d(inputVec, - weight, - bias=bias, - stride=[1], - padding=[0], - dilation=[1], - groups=1) + return torch.ops.aten.conv1d( + inputVec, weight, bias=bias, stride=[1], padding=[0], dilation=[1], groups=1 + ) + + @register_test_case(module_factory=lambda: Conv1dModule()) def Conv1dModule_basic(module, tu: TestUtils): inputVec = tu.rand(2, 2, 6) @@ -787,25 +916,32 @@ def Conv1dModule_basic(module, tu: TestUtils): bias = torch.randn(8) module.forward(inputVec, weight, bias) + class Conv2dModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, inputVec, weight, bias): - return torch.ops.aten.conv2d(inputVec, - weight, - bias=bias, - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1], - groups=1) + return torch.ops.aten.conv2d( + inputVec, + weight, + bias=bias, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + groups=1, + ) + + @register_test_case(module_factory=lambda: Conv2dModule()) def Conv2dModule_basic(module, tu: TestUtils): inputVec = tu.rand(2, 2, 6, 6) @@ -813,25 +949,32 @@ def Conv2dModule_basic(module, tu: TestUtils): bias = torch.randn(8) module.forward(inputVec, weight, bias) + class Conv3dModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, inputVec, weight, bias): - return torch.ops.aten.conv3d(inputVec, - weight, - bias=bias, - stride=[1, 1, 1], - padding=[0, 0, 0], - dilation=[1, 1, 1], - groups=1) + return torch.ops.aten.conv3d( + inputVec, + weight, + bias=bias, + stride=[1, 1, 1], + padding=[0, 0, 0], + dilation=[1, 1, 1], + groups=1, + ) + + @register_test_case(module_factory=lambda: Conv3dModule()) def Conv3dModule_basic(module, tu: TestUtils): inputVec = tu.rand(2, 2, 6, 6, 6) @@ -839,36 +982,43 @@ def Conv3dModule_basic(module, tu: TestUtils): bias = torch.randn(8) module.forward(inputVec, weight, bias) + class ConvTbcModule(torch.nn.Module): def __init__(self): super().__init__() # shapes from https://github.com/pytorch/pytorch/blob/3e8c8ce37bbfaafa8581fb48506c0a70ea54463d/test/nn/test_convolution.py#L623 @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, weight, bias): return torch.conv_tbc(x, weight, bias) + @register_test_case(module_factory=lambda: ConvTbcModule()) def ConvTbcModule_basic(module, tu: TestUtils): module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6)) + class Conv2dQInt8Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.int8, True), - ([-1, -1, -1, -1], torch.int8, True), - ([-1], torch.float, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.int8, True), + ([-1, -1, -1, -1], torch.int8, True), + ([-1], torch.float, True), + ] + ) def forward(self, inputVec, weight, bias): inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7) inputVec = torch.dequantize(inputVec) @@ -879,13 +1029,17 @@ class Conv2dQInt8Module(torch.nn.Module): bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32) bias = torch.dequantize(bias) - return torch.ops.aten.conv2d(inputVec, - weight, - bias=bias, - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1], - groups=1) + return torch.ops.aten.conv2d( + inputVec, + weight, + bias=bias, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + groups=1, + ) + + @register_test_case(module_factory=lambda: Conv2dQInt8Module()) def Conv2dQInt8Module_basic(module, tu: TestUtils): inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/custom_op_example.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/custom_op_example.py index 3d08708d7..c2f87d319 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/custom_op_example.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/custom_op_example.py @@ -17,15 +17,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # the PyTorch op registry permanently. import torch_mlir._torch_mlir_custom_op_example + class CustomOpExampleModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops._torch_mlir_custom_op_example.identity(a) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py index d54bd11cb..6371f9a8d 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py @@ -10,16 +10,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== + class DiagonalModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diagonal(a) @@ -28,96 +30,122 @@ class DiagonalModule(torch.nn.Module): def DiagonalModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3)) + @register_test_case(module_factory=lambda: DiagonalModule()) def DiagonalModule_nonsquare(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + # ============================================================================== + class DiagonalTransposedModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diagonal(a, dim1=1, dim2=0) + @register_test_case(module_factory=lambda: DiagonalTransposedModule()) def DiagonalModule_transposed(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + # ============================================================================== + class DiagonalWithDimsModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diagonal(a, dim1=0, dim2=1) + @register_test_case(module_factory=lambda: DiagonalWithDimsModule()) def DiagonalModule_with_dims(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class DiagonalWithNegativeDimsModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diagonal(a, dim1=-2, dim2=-1) + @register_test_case(module_factory=lambda: DiagonalWithNegativeDimsModule()) def DiagonalModule_with_negative_dims(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class DiagonalWithOffsetModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diagonal(a, offset=1) + @register_test_case(module_factory=lambda: DiagonalWithOffsetModule()) def DiagonalModule_with_offset(module, tu: TestUtils): module.forward(tu.rand(4, 6)) + # ============================================================================== + class DiagonalWithDimsOffsetModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.diagonal(a, dim1=0, dim2=1, offset=-1) + @register_test_case(module_factory=lambda: DiagonalWithDimsOffsetModule()) def DiagonalModule_with_dims_and_offset(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index cbd2868b7..b1816664b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -20,15 +20,16 @@ from torch_mlir_e2e_test.annotations import annotate_args, export class ElementwiseUnaryModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.tanh(a) @@ -42,15 +43,16 @@ def ElementwiseUnaryModule_basic(module, tu: TestUtils): class ElementwiseUnaryIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.tanh(a) @@ -64,15 +66,16 @@ def ElementwiseUnaryIntModule_basic(module, tu: TestUtils): class ElementwiseSinhModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.sinh(a) @@ -86,15 +89,16 @@ def ElementwiseSinhModule_basic(module, tu: TestUtils): class ElementwiseSinhIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.sinh(a) @@ -108,15 +112,16 @@ def ElementwiseSinhIntModule_basic(module, tu: TestUtils): class ElementwiseCoshModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.cosh(a) @@ -130,15 +135,16 @@ def ElementwiseCoshModule_basic(module, tu: TestUtils): class ElementwiseCoshIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.cosh(a) @@ -152,15 +158,16 @@ def ElementwiseCoshIntModule_basic(module, tu: TestUtils): class ElementwiseAcoshModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.acosh(a) @@ -174,15 +181,16 @@ def ElementwiseAcoshModule_basic(module, tu: TestUtils): class ElementwiseAcoshIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.acosh(a) @@ -196,15 +204,16 @@ def ElementwiseAcoshIntModule_basic(module, tu: TestUtils): class ElementwiseAsinModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.asin(a) @@ -218,15 +227,16 @@ def ElementwiseAsinModule_basic(module, tu: TestUtils): class ElementwiseAsinIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.asin(a) @@ -240,15 +250,16 @@ def ElementwiseAsinIntModule_basic(module, tu: TestUtils): class ElementwiseAsinhModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.asinh(a) @@ -262,15 +273,16 @@ def ElementwiseAsinhModule_basic(module, tu: TestUtils): class ElementwiseAsinhIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.asinh(a) @@ -284,15 +296,16 @@ def ElementwiseAsinhIntModule_basic(module, tu: TestUtils): class ElementwiseAtanhModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.atanh(a) @@ -306,15 +319,16 @@ def ElementwiseAtanhModule_basic(module, tu: TestUtils): class ElementwiseAtanhIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.atanh(a) @@ -328,16 +342,17 @@ def ElementwiseAtanhIntModule_basic(module, tu: TestUtils): class ElementwiseBinaryModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, a, b): return a * b @@ -351,22 +366,22 @@ def ElementwiseBinaryModule_basic(module, tu: TestUtils): class ElementwiseBinaryStaticShapeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, 4, 3, 3, 1], torch.float32, True), - ([4, 3, 1, 2], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 4, 3, 3, 1], torch.float32, True), + ([4, 3, 1, 2], torch.float32, True), + ] + ) def forward(self, a, b): return a * b -@register_test_case( - module_factory=lambda: ElementwiseBinaryStaticShapeModule()) +@register_test_case(module_factory=lambda: ElementwiseBinaryStaticShapeModule()) def ElementwiseBinaryStaticShapeModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 4, 3, 3, 1), tu.rand(4, 3, 1, 2)) @@ -375,17 +390,18 @@ def ElementwiseBinaryStaticShapeModule_basic(module, tu: TestUtils): class ElementwiseTernaryModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, a, b, c): return torch.lerp(a, b, c) @@ -399,41 +415,45 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils): class ElementwiseAtenWhereSelfModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 5, 5], torch.bool, True), - ([1, 12, 5, 5], torch.float32, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 5, 5], torch.bool, True), + ([1, 12, 5, 5], torch.float32, True), + ([], torch.float32, True), + ] + ) def forward(self, a, b, c): return torch.ops.aten.where(a, b, c) @register_test_case(module_factory=lambda: ElementwiseAtenWhereSelfModule()) def ElementwiseAtenWhereSelfModule_basic(module, tu: TestUtils): - module.forward(torch.zeros(1, 1, 5, 5, dtype=torch.bool), tu.rand(1, 12, 5, 5), tu.rand()) + module.forward( + torch.zeros(1, 1, 5, 5, dtype=torch.bool), tu.rand(1, 12, 5, 5), tu.rand() + ) # ============================================================================== class ElementwiseWhereSelfModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, a, b, c): return torch.where(a > 0.5, b, c) @@ -447,15 +467,16 @@ def ElementwiseWhereSelfModule_basic(module, tu: TestUtils): class ElementwiseWhereScalarModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.where(a > 0.5, 4.0, 8.0).to(torch.float) @@ -469,16 +490,17 @@ def ElementwiseWhereScalarModule_basic(module, tu: TestUtils): class ElementwiseWhereScalarOtherModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ([-1, -1], torch.float64, True), + ] + ) def forward(self, a, b): return torch.where(a > 0.5, b, 8.0) @@ -492,16 +514,17 @@ def ElementwiseWhereScalarOtherModule_basic(module, tu: TestUtils): class ElementwiseWhereScalarOtherStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 5], torch.float64, True), - ([4, 5], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 5], torch.float64, True), + ([4, 5], torch.float64, True), + ] + ) def forward(self, a, b): return torch.where(a > 0.5, b, 8) @@ -515,16 +538,17 @@ def ElementwiseWhereScalarOtherStaticModule_basic(module, tu: TestUtils): class ElementwiseWhereScalarSelfModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ([-1, -1], torch.float64, True), + ] + ) def forward(self, a, b): return torch.where(a > 0.5, 4.0, b) @@ -533,20 +557,22 @@ class ElementwiseWhereScalarSelfModule(torch.nn.Module): def ElementwiseWhereScalarSelfModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double()) + # ============================================================================== class ElementwiseWhereScalarSelfStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 5], torch.float64, True), - ([4, 5], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 5], torch.float64, True), + ([4, 5], torch.float64, True), + ] + ) def forward(self, a, b): return torch.where(a > 0.5, 4.0, b) @@ -560,27 +586,26 @@ def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils): class ElementwiseNanToNumModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.float32, True) - ]) + @annotate_args([None, ([3, 4], torch.float32, True)]) def forward(self, a): return torch.ops.aten.nan_to_num(a, 0.0, 1.0, -1.0) + @register_test_case(module_factory=lambda: ElementwiseNanToNumModule()) def ElementwiseNanToNumModule_Basic(module, tu: TestUtils): - module.forward(torch.tensor( - [ - [float('nan'), 0.0, float('nan'), 0.0], - [float('inf'), 0.0, float('inf'), 0.0], - [float('-inf'), 0.0, float('-inf'), 0.0] - ] - )) + module.forward( + torch.tensor( + [ + [float("nan"), 0.0, float("nan"), 0.0], + [float("inf"), 0.0, float("inf"), 0.0], + [float("-inf"), 0.0, float("-inf"), 0.0], + ] + ) + ) # ============================================================================== @@ -589,16 +614,17 @@ def ElementwiseNanToNumModule_Basic(module, tu: TestUtils): # Addition is an interesting special case of a binary op, because under the hood # it carries a third scalar "alpha" parameter, which needs special handling. class ElementwiseAddModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([], torch.float32, True), + ] + ) def forward(self, a, b): return a + b @@ -612,40 +638,40 @@ def ElementwiseAddModule_basic(module, tu: TestUtils): class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([], torch.float32, True), + ] + ) def forward(self, a, b): return a * b.unsqueeze(0) -@register_test_case( - module_factory=lambda: ElementwiseUnsqueezeBroadcastModule()) +@register_test_case(module_factory=lambda: ElementwiseUnsqueezeBroadcastModule()) def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils): module.forward(tu.rand(4), tu.rand()) - # ============================================================================== class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): # As mentioned in `unsqueeze` docstring, # valid dim values are [-input.dim()-1, input.dim()+1). @@ -662,16 +688,17 @@ def ElementwiseUnsqueezeNegDimsModule_basic(module, tu: TestUtils): class ElementwiseFlattenBroadcastModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([], torch.float32, True), + ] + ) def forward(self, a, b): return a * b.flatten(-1, -1) @@ -685,15 +712,16 @@ def ElementwiseFlattenBroadcastModule_basic(module, tu: TestUtils): class ElementwiseReluModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.relu(x) @@ -705,80 +733,95 @@ def ElementwiseReluModule_basic(module, tu: TestUtils): # ============================================================================== -class QuantizedReluInt8(torch.nn.Module): +class QuantizedReluInt8(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int8, True), + ] + ) def forward(self, x): qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) qx = torch.dequantize(qx) return torch.relu(qx) + @register_test_case(module_factory=lambda: QuantizedReluInt8()) def QuantizedReluInt8_basic(module, tu: TestUtils): module.forward(tu.randint(7, 4, low=-128, high=127).to(torch.int8)) - + + # ============================================================================== -class QuantizedReluUint8(torch.nn.Module): +class QuantizedReluUint8(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.uint8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.uint8, True), + ] + ) def forward(self, x): qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190) qx = torch.dequantize(qx) return torch.relu(qx) + @register_test_case(module_factory=lambda: QuantizedReluUint8()) def QuantizedReluUint8_basic(module, tu: TestUtils): module.forward(tu.randint(7, 4, low=0, high=255).to(torch.uint8)) - + + # ============================================================================== -class QuantizedReluInt32(torch.nn.Module): +class QuantizedReluInt32(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190) qx = torch.dequantize(qx) return torch.relu(qx) + @register_test_case(module_factory=lambda: QuantizedReluInt32()) def QuantizedReluInt32_basic(module, tu: TestUtils): - module.forward(tu.randint(7, 4, low=(-2**31), high=(2**31 - 1)).to(torch.int32)) - + module.forward( + tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32) + ) + + # ============================================================================== class ElementwiseRelu6Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.relu6(x) @@ -792,15 +835,16 @@ def ElementwiseRelu6Module_basic(module, tu: TestUtils): class ElementwiseLeakyReluModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.leaky_relu(x, negative_slope=0.1) @@ -811,15 +855,16 @@ def ElementwiseLeakyReluModule_basic(module, tu: TestUtils): class ElementwiseLeakyReluStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.leaky_relu(x, negative_slope=0.1) @@ -833,143 +878,156 @@ def ElementwiseLeakyReluStaticModule_basic(module, tu: TestUtils): class ElementwiseLerpScalarIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.ops.aten.lerp(a, b, weight=2) + @register_test_case(module_factory=lambda: ElementwiseLerpScalarIntModule()) def ElementwiseLerpScalarIntModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5,3), tu.rand(5,3)) + module.forward(tu.rand(5, 3), tu.rand(5, 3)) class ElementwiseLerpScalarFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.ops.aten.lerp(a, b, weight=0.5) + @register_test_case(module_factory=lambda: ElementwiseLerpScalarFloatModule()) def ElementwiseLerpScalarFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5,3), tu.rand(5,3)) + module.forward(tu.rand(5, 3), tu.rand(5, 3)) # ============================================================================== class ElementwiseEluNonDefaultModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.elu(x, scale=1.5, alpha=2.0, input_scale=3.0) + @register_test_case(module_factory=lambda: ElementwiseEluNonDefaultModule()) def ElementwiseEluNonDefaultModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5,3, low=-1, high=1)) + module.forward(tu.rand(5, 3, low=-1, high=1)) # ============================================================================== class ElementwiseEluModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.elu(x) + @register_test_case(module_factory=lambda: ElementwiseEluModule()) def ElementwiseEluModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5,3, low=-1, high=1)) + module.forward(tu.rand(5, 3, low=-1, high=1)) # ============================================================================== class ElementwisePreluModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, weight): return torch.ops.aten.prelu(x, weight) + @register_test_case(module_factory=lambda: ElementwisePreluModule()) def ElementwisePreluModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3, 2, 1, low=-1, high=1), tu.rand(1) ) + module.forward(tu.rand(5, 4, 3, 2, 1, low=-1, high=1), tu.rand(1)) # ============================================================================== class ElementwisePreluStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, 4, 3, 2, 1], torch.float32, True), - ([1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 4, 3, 2, 1], torch.float32, True), + ([1], torch.float32, True), + ] + ) def forward(self, x, weight): return torch.ops.aten.prelu(x, weight) + @register_test_case(module_factory=lambda: ElementwisePreluStaticModule()) def ElementwisePreluStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3, 2, 1, low=-1, high=1), tu.rand(1) ) + module.forward(tu.rand(5, 4, 3, 2, 1, low=-1, high=1), tu.rand(1)) # ============================================================================== class ElementwiseGeluModule(torch.nn.Module): - def __init__(self): super().__init__() self.gelu = torch.nn.GELU() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.gelu(x) @@ -983,16 +1041,17 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils): class ElementwiseGeluApproximateTanhModule(torch.nn.Module): - def __init__(self): super().__init__() self.gelu = torch.nn.GELU(approximate="tanh") @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.gelu(x) @@ -1006,18 +1065,20 @@ def ElementwiseGeluApproximateTanhModule_basic(module, tu: TestUtils): class ElementwiseSeluModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.selu(x) + @register_test_case(module_factory=lambda: ElementwiseSeluModule()) def ElementwiseSeluModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 3, low=-1, high=1)) @@ -1027,15 +1088,16 @@ def ElementwiseSeluModule_basic(module, tu: TestUtils): class ElementwiseSigmoidModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.sigmoid(x) @@ -1049,15 +1111,16 @@ def ElementwiseSigmoidModule_basic(module, tu: TestUtils): class ElementwiseSigmoidIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.sigmoid(x) @@ -1071,16 +1134,17 @@ def ElementwiseSigmoidIntModule_basic(module, tu: TestUtils): class ElementwiseMinimumModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.minimum(x, y) @@ -1094,16 +1158,17 @@ def ElementwiseMinimumModule_basic(module, tu: TestUtils): class ElementwiseMinimumIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.minimum(x, y) @@ -1117,16 +1182,17 @@ def ElementwiseMinimumIntModule_basic(module, tu: TestUtils): class ElementwiseMinOtherModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return x.min(y) @@ -1140,16 +1206,17 @@ def ElementwiseMinOtherModule_basic(module, tu: TestUtils): class ElementwiseMinOtherIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return x.min(y) @@ -1163,16 +1230,17 @@ def ElementwiseMinOtherIntModule_basic(module, tu: TestUtils): class ElementwiseMaximumModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.maximum(x, y) @@ -1186,16 +1254,17 @@ def ElementwiseMaximumModule_basic(module, tu: TestUtils): class ElementwiseMaximumIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.maximum(x, y) @@ -1209,16 +1278,17 @@ def ElementwiseMaximumIntModule_basic(module, tu: TestUtils): class ElementwiseMaxOtherModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return x.max(y) @@ -1232,16 +1302,17 @@ def ElementwiseMaxOtherModule_basic(module, tu: TestUtils): class ElementwiseMaxOtherIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return x.max(y) @@ -1255,15 +1326,16 @@ def ElementwiseMaxOtherIntModule_basic(module, tu: TestUtils): class ElementwiseClampModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): float_min = torch.clamp(x, min=-2.0) int_min = torch.clamp(x, min=-3) @@ -1282,15 +1354,16 @@ def ElementwiseClampModule_basic(module, tu: TestUtils): class ElementwiseClampMinModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): float_min = torch.ops.aten.clamp_min(x, min=-2.0) int_min = torch.ops.aten.clamp_min(x, min=2) @@ -1307,15 +1380,16 @@ def ElementwiseClampMinModule_basic(module, tu: TestUtils): class ElementwiseClampMaxModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): float_max = torch.ops.aten.clamp_max(x, max=2.0) int_max = torch.ops.aten.clamp_max(x, max=3) @@ -1332,17 +1406,18 @@ def ElementwiseClampMaxModule_basic(module, tu: TestUtils): class ElementwiseClampTensorFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([], torch.float32, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([], torch.float32, True), + ([], torch.float32, True), + ] + ) def forward(self, x, min, max): min_clamp = torch.clamp(x, min) max_clamp = torch.clamp(x, max=max) @@ -1352,24 +1427,27 @@ class ElementwiseClampTensorFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseClampTensorFloatModule()) def ElementwiseClampTensorFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 5, low=-10, high=10), torch.tensor([-5.0]), torch.tensor([5.0])) + module.forward( + tu.rand(3, 5, low=-10, high=10), torch.tensor([-5.0]), torch.tensor([5.0]) + ) # ============================================================================== class ElementwiseClampTensorIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([], torch.int64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, x, min, max): min_clamp = torch.clamp(x, min) max_clamp = torch.clamp(x, max=max) @@ -1379,22 +1457,20 @@ class ElementwiseClampTensorIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseClampTensorIntModule()) def ElementwiseClampTensorIntModule_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 5, low=-10, high=10), torch.tensor([-5]), torch.tensor([5])) + module.forward( + tu.randint(3, 5, low=-10, high=10), torch.tensor([-5]), torch.tensor([5]) + ) # ============================================================================== class ElementwiseClampTensorInt8Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int8, True) - ]) + @annotate_args([None, ([-1, -1], torch.int8, True)]) def forward(self, x): min = -5 max = 5 @@ -1412,18 +1488,18 @@ def ElementwiseClampTensorInt8Module_basic(module, tu: TestUtils): # ============================================================================== - class ElementwiseClampMinTensorFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([], torch.float32, True), + ] + ) def forward(self, x, min): return torch.ops.aten.clamp_min(x, min=min) @@ -1437,16 +1513,17 @@ def ElementwiseClampMinTensorFloatModule_basic(module, tu: TestUtils): class ElementwiseClampMinTensorIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, x, min): return torch.ops.aten.clamp_min(x, min=min) @@ -1460,15 +1537,16 @@ def ElementwiseClampMinTensorIntModule_basic(module, tu: TestUtils): class RsubFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.rsub(x, 3.0, alpha=1.0) @@ -1482,15 +1560,16 @@ def RsubFloatModule_basic(module, tu: TestUtils): class RsubFloatModule_noalpha(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.rsub(x, 2.0) @@ -1504,15 +1583,16 @@ def RsubFloatModule_noalpha_basic(module, tu: TestUtils): class RsubIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.rsub(x, 2, alpha=3) @@ -1526,35 +1606,38 @@ def RsubIntModule_basic(module, tu: TestUtils): class RsubIntModule_noalpha(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): - return torch.rsub(x, 2.) + return torch.rsub(x, 2.0) @register_test_case(module_factory=lambda: RsubIntModule_noalpha()) def RsubIntModule_noalpha_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, high=100)) + # ============================================================================== class RsubInt0d_NumToTensor_Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): x = torch.ops.prim.NumToTensor(5) return torch.rsub(x, 2) @@ -1564,19 +1647,21 @@ class RsubInt0d_NumToTensor_Module(torch.nn.Module): def RsubInt0d_NumToTensor_Module_basic(module, tu: TestUtils): module.forward() + # ============================================================================== class ElementwiseMulScalarIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.mul(x, 4) @@ -1590,15 +1675,16 @@ def ElementwiseMulScalarModule_int(module, tu: TestUtils): class ElementwiseMulScalarFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.mul(x, 100.0) @@ -1612,15 +1698,16 @@ def ElementwiseMulScalarModule_float(module, tu: TestUtils): class ElementwiseMulScalarModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.mul(x, 8.0) @@ -1634,16 +1721,17 @@ def ElementwiseMulScalarModule_basic(module, tu: TestUtils): class ElementwiseMulTensorFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float64, True), + ] + ) def forward(self, a, b): return torch.mul(a, b) @@ -1657,39 +1745,41 @@ def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils): class ElementwiseMulTensorIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int32, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int32, True), + ([-1], torch.int64, True), + ] + ) def forward(self, a, b): return torch.mul(a, b) @register_test_case(module_factory=lambda: ElementwiseMulTensorIntModule()) def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils): - module.forward( - tu.randint(4, high=10).type(torch.int32), tu.randint(4, high=10)) + module.forward(tu.randint(4, high=10).type(torch.int32), tu.randint(4, high=10)) # ============================================================================== -class ElementwiseMulTensorComplexModule(torch.nn.Module): +class ElementwiseMulTensorComplexModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.complex64, True), - ([-1], torch.complex64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.complex64, True), + ([-1], torch.complex64, True), + ] + ) def forward(self, a, b): return torch.mul(a, b) @@ -1697,21 +1787,25 @@ class ElementwiseMulTensorComplexModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseMulTensorComplexModule()) def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils): module.forward( - tu.randint(4, high=10).type(torch.complex64), tu.randint(4, high=10).type(torch.complex64)) + tu.randint(4, high=10).type(torch.complex64), + tu.randint(4, high=10).type(torch.complex64), + ) # ============================================================================== -class ElementwiseMishModule(torch.nn.Module): +class ElementwiseMishModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.mish(x) @@ -1725,15 +1819,16 @@ def ElementwiseMishModule_basic(module, tu: TestUtils): class ElementwiseAtanTensorFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.atan(a) @@ -1747,39 +1842,40 @@ def ElementwiseAtanTensorFloatModule_basic(module, tu: TestUtils): class ElementwiseAtanTensorIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int32, True), + ] + ) def forward(self, a): return torch.atan(a) @register_test_case(module_factory=lambda: ElementwiseAtanTensorIntModule()) def ElementwiseAtanTensorIntModule_basic(module, tu: TestUtils): - module.forward( - tu.randint(4, low=1, high=10).type(torch.int32)) + module.forward(tu.randint(4, low=1, high=10).type(torch.int32)) # ============================================================================== class ElementwiseAtan2TensorFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.atan2(a, b) @@ -1793,19 +1889,20 @@ def ElementwiseAtan2TensorFloatModule_basic(module, tu: TestUtils): class ElementwiseAtan2TensorFloatStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6], torch.float32, True), - ([4, 5, 6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ([4, 5, 6], torch.float32, True), + ] + ) def forward(self, a, b): return torch.atan2(a, b) - + @register_test_case(module_factory=lambda: ElementwiseAtan2TensorFloatStaticModule()) def ElementwiseAtan2TensorFloatStaticModule_basic(module, tu: TestUtils): @@ -1814,17 +1911,19 @@ def ElementwiseAtan2TensorFloatStaticModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseAtan2TensorIntModule(torch.nn.Module): +class ElementwiseAtan2TensorIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int32, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int32, True), + ([-1], torch.int64, True), + ] + ) def forward(self, a, b): return torch.atan2(a, b) @@ -1832,94 +1931,103 @@ class ElementwiseAtan2TensorIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntModule()) def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils): module.forward( - tu.randint(4, low=1, high=10).type(torch.int32), tu.randint(4, low=1, high=10)) + tu.randint(4, low=1, high=10).type(torch.int32), tu.randint(4, low=1, high=10) + ) # ============================================================================== class ElementwiseAtan2TensorIntStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6], torch.int32, True), - ([4, 5, 6], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([4, 5, 6], torch.int32, True), + ([4, 5, 6], torch.int64, True), + ] + ) def forward(self, a, b): return torch.atan2(a, b) - + @register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntStaticModule()) def ElementwiseAtan2TensorIntStaticModule_basic(module, tu: TestUtils): module.forward( - tu.randint(4, 5, 6, low=1, high=10).type(torch.int32), tu.randint(4, 5, 6, low=1, high=10)) + tu.randint(4, 5, 6, low=1, high=10).type(torch.int32), + tu.randint(4, 5, 6, low=1, high=10), + ) # ============================================================================== class ElementwiseAtan2FloatIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.float64, True), + ] + ) def forward(self, a, b): return torch.atan2(a, b) @register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntModule()) def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils): - module.forward(tu.randint(4, 4, low=1, high=10).to(torch.int32), - tu.rand(4, 4).double()) + module.forward( + tu.randint(4, 4, low=1, high=10).to(torch.int32), tu.rand(4, 4).double() + ) # ============================================================================== class ElementwiseAtan2FloatIntStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6], torch.int32, True), - ([4, 5, 6], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([4, 5, 6], torch.int32, True), + ([4, 5, 6], torch.float64, True), + ] + ) def forward(self, a, b): return torch.atan2(a, b) - + @register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntStaticModule()) def ElementwiseAtan2FloatIntStaticModule_basic(module, tu: TestUtils): - module.forward(tu.randint(4, 5, 6, low=1, high=10).to(torch.int32), - tu.rand(4, 5, 6).double()) - + module.forward( + tu.randint(4, 5, 6, low=1, high=10).to(torch.int32), tu.rand(4, 5, 6).double() + ) + # ============================================================================== class ElementwiseLogModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.log(a) @@ -1933,15 +2041,16 @@ def ElementwiseLogModule_basic(module, tu: TestUtils): class ElementwiseLogIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.log(a) @@ -1950,19 +2059,21 @@ class ElementwiseLogIntModule(torch.nn.Module): def ElementwiseLogIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + # ============================================================================== class ElementwiseLog1pModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.log1p(a) @@ -1976,15 +2087,16 @@ def ElementwiseLog1pModule_basic(module, tu: TestUtils): class ElementwiseLogitModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.logit(a, eps=1e-7) @@ -1998,15 +2110,16 @@ def ElementwiseLogitModule_basic(module, tu: TestUtils): class ElementwiseErfModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.erf(a) @@ -2020,15 +2133,16 @@ def ElementwiseErfModule_basic(module, tu: TestUtils): class ElementwiseErfIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.ops.aten.erf(a) @@ -2042,15 +2156,16 @@ def ElementwiseErfIntModule_basic(module, tu: TestUtils): class ElementwiseSqrtModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.sqrt(a) @@ -2064,15 +2179,16 @@ def ElementwiseSqrtModule_basic(module, tu: TestUtils): class ElementwiseSqrtIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.sqrt(a) @@ -2086,15 +2202,16 @@ def ElementwiseSqrtIntModule_basic(module, tu: TestUtils): class ElementwiseFloorModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.floor(a) @@ -2103,16 +2220,18 @@ class ElementwiseFloorModule(torch.nn.Module): def ElementwiseFloorModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) -class ElementwiseFloorIntModule(torch.nn.Module): +class ElementwiseFloorIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.floor(a) @@ -2126,15 +2245,16 @@ def ElementwiseFloorIntModule_basic(module, tu: TestUtils): class ElementwiseCeilModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ceil(a) @@ -2148,15 +2268,16 @@ def ElementwiseCeilModule_basic(module, tu: TestUtils): class ElementwiseTruncModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 6], torch.float32, True), + ] + ) def forward(self, a): return torch.trunc(a) @@ -2170,15 +2291,16 @@ def ElementwiseTruncModule_basic(module, tu: TestUtils): class ElementwiseTruncIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.trunc(a) @@ -2192,38 +2314,41 @@ def ElementwiseTruncIntModule_basic(module, tu: TestUtils): class ElementwiseSignModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.sign(a) @register_test_case(module_factory=lambda: ElementwiseSignModule()) def ElementwiseSignModule_basic(module, tu: TestUtils): - module.forward(torch.tensor([[-2.0, 0.0, 1.1, 2.0], - [6.0, -0.0, torch.inf, -torch.inf]])) + module.forward( + torch.tensor([[-2.0, 0.0, 1.1, 2.0], [6.0, -0.0, torch.inf, -torch.inf]]) + ) # ============================================================================== class ElementwiseSignIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.sign(a) @@ -2237,15 +2362,16 @@ def ElementwiseSignIntModule_basic(module, tu: TestUtils): class ElementwiseSgnModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.sgn(a) @@ -2259,15 +2385,16 @@ def ElementwiseSgnModule_basic(module, tu: TestUtils): class ElementwisePowModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.pow(a, 2.0) @@ -2281,16 +2408,17 @@ def ElementwisePowModule_basic(module, tu: TestUtils): class ElementwisePowTensorModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.pow(a, b) @@ -2304,16 +2432,17 @@ def ElementwisePowTensorModule_basic(module, tu: TestUtils): class ElementwisePowTensorStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.float32, True), - ([1, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.float32, True), + ([1, 1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.pow(a, b) @@ -2327,16 +2456,17 @@ def ElementwisePowTensorStaticModule_basic(module, tu: TestUtils): class ElementwisePowTensorBroadcastModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, 1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, 1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.pow(a, b) @@ -2350,16 +2480,17 @@ def ElementwisePowTensorBroadcastModule_basic(module, tu: TestUtils): class ElementwisePowTensorBroadcastStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 1], torch.float32, True), - ([3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 1], torch.float32, True), + ([3, 4], torch.float32, True), + ] + ) def forward(self, a, b): return torch.pow(a, b) @@ -2373,15 +2504,16 @@ def ElementwisePowTensorBroadcastStaticModule_basic(module, tu: TestUtils): class ElementwisePowScalarModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.float32, True), + ] + ) def forward(self, exp): return torch.pow(2.0, exp) @@ -2395,7 +2527,6 @@ def ElementwisePowScalarModule_basic(module, tu: TestUtils): class ElementwiseToDtypeF32ToI64Module(torch.nn.Module): - def __init__(self): super().__init__() @@ -2414,7 +2545,6 @@ def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils): class ElementwiseToDtypeIdentityModule(torch.nn.Module): - def __init__(self): super().__init__() @@ -2433,7 +2563,6 @@ def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils): class ElementwiseToDtypeI64ToI8Module(torch.nn.Module): - def __init__(self): super().__init__() @@ -2452,7 +2581,6 @@ def ElementwiseToDtypeI64ToI8Module_basic(module, tu: TestUtils): class ElementwiseToDtypeI64ToUI8Module(torch.nn.Module): - def __init__(self): super().__init__() @@ -2471,15 +2599,16 @@ def ElementwiseToDtypeI64ToUI8Module_basic(module, tu: TestUtils): class ElementwiseLog2Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.log2(a) @@ -2493,15 +2622,16 @@ def ElementwiseLog2Module_basic(module, tu: TestUtils): class ElementwiseLog2IntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.log2(a) @@ -2513,16 +2643,18 @@ def ElementwiseLog2IntModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseLog10Module(torch.nn.Module): +class ElementwiseLog10Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.log10(a) @@ -2534,16 +2666,18 @@ def ElementwiseLog10Module_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseLog10IntModule(torch.nn.Module): +class ElementwiseLog10IntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.log10(a) @@ -2557,15 +2691,16 @@ def ElementwiseLog10IntModule_basic(module, tu: TestUtils): class ElementwiseRsqrtModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.rsqrt(a) @@ -2579,15 +2714,16 @@ def ElementwiseRsqrtModule_basic(module, tu: TestUtils): class ElementwiseRsqrtIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.rsqrt(a) @@ -2601,15 +2737,16 @@ def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils): class ElementwiseAbsFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.abs(a) @@ -2623,15 +2760,16 @@ def ElementwiseAbsFloatModule_basic(module, tu: TestUtils): class ElementwiseAbsIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.abs(a) @@ -2645,15 +2783,16 @@ def ElementwiseAbsIntModule_basic(module, tu: TestUtils): class ElementwiseReciprocalModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, a): return torch.reciprocal(a) @@ -2667,15 +2806,16 @@ def ElementwiseReciprocalModule_basic(module, tu: TestUtils): class ElementwiseReciprocalIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int32, True), + ] + ) def forward(self, a): return torch.reciprocal(a) @@ -2689,15 +2829,16 @@ def ElementwiseReciprocalIntModule_basic(module, tu: TestUtils): class ElementwiseDivScalarModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.div(x, 10.0) @@ -2711,15 +2852,16 @@ def ElementwiseDivScalarModule_basic(module, tu: TestUtils): class ElementwiseAtenDivIntScalarModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.ops.aten.div(x, 128) @@ -2728,19 +2870,21 @@ class ElementwiseAtenDivIntScalarModule(torch.nn.Module): def ElementwiseAtenDivIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4)) + # ============================================================================== class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int32, True), + ] + ) def forward(self, x): return torch.remainder(x, 2.0) @@ -2754,15 +2898,16 @@ def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils): class ElementwiseRemainderScalarModule_Float(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.remainder(x, 2.0) @@ -2774,16 +2919,18 @@ def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseRemainderScalarModule_Int(torch.nn.Module): +class ElementwiseRemainderScalarModule_Int(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.remainder(x, 2) @@ -2792,18 +2939,21 @@ class ElementwiseRemainderScalarModule_Int(torch.nn.Module): def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils): module.forward(tu.randint(3, 2, high=10).to(torch.int32)) + # ============================================================================== -class ElementwiseRemainderScalarModule_Bool(torch.nn.Module): +class ElementwiseRemainderScalarModule_Bool(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.bool, True), + ] + ) def forward(self, x): return torch.remainder(x, 2) @@ -2814,18 +2964,14 @@ def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils): # ============================================================================== - -class ElementwiseFmodTensor_Float(torch.nn.Module): + +class ElementwiseFmodTensor_Float(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float32, True) - ]) + @annotate_args([None, ([-1], torch.float32, True), ([-1], torch.float32, True)]) def forward(self, x, y): return torch.fmod(x, y) @@ -2833,62 +2979,69 @@ class ElementwiseFmodTensor_Float(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseFmodTensor_Float()) def ElementwiseFmodTensor_Float_basic(module, tu: TestUtils): module.forward(tu.rand(100, low=-10, high=10), tu.rand(100, low=-10, high=10)) - -# ============================================================================== - -class ElementwiseFmodTensor_Int_Float(torch.nn.Module): + +# ============================================================================== + + +class ElementwiseFmodTensor_Int_Float(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int32, True), - ([-1], torch.float32, True) - ]) + @annotate_args([None, ([-1], torch.int32, True), ([-1], torch.float32, True)]) def forward(self, x, y): return torch.fmod(x, y) @register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int_Float()) def ElementwiseFmodTensor_Int_Float_basic(module, tu: TestUtils): - module.forward(tu.randint(100, low=-10, high=10).to(torch.int32), tu.rand(100, low=-10, high=10)) + module.forward( + tu.randint(100, low=-10, high=10).to(torch.int32), + tu.rand(100, low=-10, high=10), + ) + # ============================================================================== - -class ElementwiseFmodTensor_Int(torch.nn.Module): + +class ElementwiseFmodTensor_Int(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int32, True), - ([-1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int32, True), + ([-1], torch.int32, True), + ] + ) def forward(self, x, y): return torch.fmod(x, y) @register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int()) def ElementwiseFmodTensor_Int_basic(module, tu: TestUtils): - module.forward(tu.randint(100, low=0, high=1000).to(torch.int32), tu.randint(100, low=1, high=1000).to(torch.int32)) + module.forward( + tu.randint(100, low=0, high=1000).to(torch.int32), + tu.randint(100, low=1, high=1000).to(torch.int32), + ) # ============================================================================== class ElementwiseRemainderTensorModule_Int_Float(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.remainder(a, b) @@ -2902,16 +3055,17 @@ def ElementwiseRemainderTensorModule_Int_Float_basic(module, tu: TestUtils): class ElementwiseRemainderTensorModule_Float(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.remainder(a, b) @@ -2923,39 +3077,46 @@ def ElementwiseRemainderTensorModule_Float_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseRemainderTensorModule_Int(torch.nn.Module): +class ElementwiseRemainderTensorModule_Int(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a, b): return torch.remainder(a, b) @register_test_case(module_factory=lambda: ElementwiseRemainderTensorModule_Int()) def ElementwiseRemainderTensorModule_Int_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, high=10, dtype=torch.int32), tu.randint(3, 4, high=10, dtype=torch.int32)) + module.forward( + tu.randint(3, 4, high=10, dtype=torch.int32), + tu.randint(3, 4, high=10, dtype=torch.int32), + ) + # ============================================================================== class ElementwiseDivTensorFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float64, True), + ] + ) def forward(self, a, b): return torch.div(a, b) @@ -2969,311 +3130,353 @@ def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils): class ElementwiseDivTensorIntegerModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a, b): return torch.div(a, b) @register_test_case(module_factory=lambda: ElementwiseDivTensorIntegerModule()) def ElementwiseDivTensorIntegerModule_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-10, high=10), tu.randint(3, 4, low=-10, high=10).type(torch.int32)) + module.forward( + tu.randint(3, 4, low=-10, high=10), + tu.randint(3, 4, low=-10, high=10).type(torch.int32), + ) # ============================================================================== class ElementwiseDivTensorUnsignedIntegerModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.uint8, True), - ([-1, -1], torch.uint8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.uint8, True), + ([-1, -1], torch.uint8, True), + ] + ) def forward(self, a, b): return torch.div(a, b) @register_test_case(module_factory=lambda: ElementwiseDivTensorUnsignedIntegerModule()) def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=0, high=10).to(torch.uint8), tu.randint(3, 4, low=0, high=10).type(torch.uint8)) + module.forward( + tu.randint(3, 4, low=0, high=10).to(torch.uint8), + tu.randint(3, 4, low=0, high=10).type(torch.uint8), + ) # ============================================================================== - class ElementwiseDivScalarRoundingModeTruncModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, a): return torch.div(a, 0.5, rounding_mode="trunc") @register_test_case( - module_factory=lambda: ElementwiseDivScalarRoundingModeTruncModule()) + module_factory=lambda: ElementwiseDivScalarRoundingModeTruncModule() +) def ElementwiseDivScalarRoundingModeTruncModule_basic(module, tu: TestUtils): module.forward(tu.rand(4)) class ElementwiseDivScalarRoundingModeFloorModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.div(a, 0.5, rounding_mode="floor") @register_test_case( - module_factory=lambda: ElementwiseDivScalarRoundingModeFloorModule()) + module_factory=lambda: ElementwiseDivScalarRoundingModeFloorModule() +) def ElementwiseDivScalarRoundingModeFloorModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) -class ElementwiseDivScalarRoundingModeTruncStaticModule(torch.nn.Module): +class ElementwiseDivScalarRoundingModeTruncStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4], torch.float32, True), + ] + ) def forward(self, a): return torch.div(a, 0.5, rounding_mode="trunc") @register_test_case( - module_factory=lambda: ElementwiseDivScalarRoundingModeTruncStaticModule()) + module_factory=lambda: ElementwiseDivScalarRoundingModeTruncStaticModule() +) def ElementwiseDivScalarRoundingModeTruncStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(4)) class ElementwiseDivScalarRoundingModeFloorStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.float32, True), + ] + ) def forward(self, a): return torch.div(a, 0.5, rounding_mode="floor") @register_test_case( - module_factory=lambda: ElementwiseDivScalarRoundingModeFloorStaticModule()) + module_factory=lambda: ElementwiseDivScalarRoundingModeFloorStaticModule() +) def ElementwiseDivScalarRoundingModeFloorStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) -class ElementwiseDivScalarRoundingModeTruncIntStaticModule(torch.nn.Module): +class ElementwiseDivScalarRoundingModeTruncIntStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int32, True), + ] + ) def forward(self, a): return torch.div(a, 3, rounding_mode="trunc") @register_test_case( - module_factory=lambda: ElementwiseDivScalarRoundingModeTruncIntStaticModule()) + module_factory=lambda: ElementwiseDivScalarRoundingModeTruncIntStaticModule() +) def ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32)) class ElementwiseDivScalarRoundingModeFloorIntStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int32, True), + ] + ) def forward(self, a): return torch.div(a, 3, rounding_mode="floor") @register_test_case( - module_factory=lambda: ElementwiseDivScalarRoundingModeFloorIntStaticModule()) + module_factory=lambda: ElementwiseDivScalarRoundingModeFloorIntStaticModule() +) def ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32)) - + # ============================================================================== class ElementwiseDivTensorRoundingModeTruncModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float64, True), + ] + ) def forward(self, a, b): return torch.div(a, b, rounding_mode="trunc") @register_test_case( - module_factory=lambda: ElementwiseDivTensorRoundingModeTruncModule()) + module_factory=lambda: ElementwiseDivTensorRoundingModeTruncModule() +) def ElementwiseDivTensorRoundingModeTruncModule_basic(module, tu: TestUtils): module.forward(tu.rand(4), tu.rand(4).type(torch.float64)) class ElementwiseDivTensorRoundingModeFloorModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float64, True), + ] + ) def forward(self, a, b): return torch.div(a, b, rounding_mode="floor") @register_test_case( - module_factory=lambda: ElementwiseDivTensorRoundingModeFloorModule()) + module_factory=lambda: ElementwiseDivTensorRoundingModeFloorModule() +) def ElementwiseDivTensorRoundingModeFloorModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64)) -class ElementwiseDivTensorRoundingModeTruncStaticModule(torch.nn.Module): +class ElementwiseDivTensorRoundingModeTruncStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4], torch.float32, True), - ([4], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([4], torch.float32, True), + ([4], torch.float64, True), + ] + ) def forward(self, a, b): return torch.div(a, b, rounding_mode="trunc") @register_test_case( - module_factory=lambda: ElementwiseDivTensorRoundingModeTruncStaticModule()) + module_factory=lambda: ElementwiseDivTensorRoundingModeTruncStaticModule() +) def ElementwiseDivTensorRoundingModeTruncStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(4), tu.rand(4).type(torch.float64)) class ElementwiseDivTensorRoundingModeFloorStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.float32, True), - ([3, 4], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.float32, True), + ([3, 4], torch.float64, True), + ] + ) def forward(self, a, b): return torch.div(a, b, rounding_mode="floor") @register_test_case( - module_factory=lambda: ElementwiseDivTensorRoundingModeFloorStaticModule()) + module_factory=lambda: ElementwiseDivTensorRoundingModeFloorStaticModule() +) def ElementwiseDivTensorRoundingModeFloorStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64)) -class ElementwiseDivTensorRoundingModeTruncIntStaticModule(torch.nn.Module): +class ElementwiseDivTensorRoundingModeTruncIntStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int32, True), - ([3, 4], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int32, True), + ([3, 4], torch.int64, True), + ] + ) def forward(self, a, b): return torch.div(a, b, rounding_mode="trunc") @register_test_case( - module_factory=lambda: ElementwiseDivTensorRoundingModeTruncIntStaticModule()) + module_factory=lambda: ElementwiseDivTensorRoundingModeTruncIntStaticModule() +) def ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64)) + module.forward( + tu.randint(3, 4, low=-10, high=10).type(torch.int32), + tu.randint(3, 4, low=1, high=10).type(torch.int64), + ) class ElementwiseDivTensorRoundingModeFloorIntStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int32, True), - ([3, 4], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int32, True), + ([3, 4], torch.int64, True), + ] + ) def forward(self, a, b): return torch.div(a, b, rounding_mode="floor") @register_test_case( - module_factory=lambda: ElementwiseDivTensorRoundingModeFloorIntStaticModule()) + module_factory=lambda: ElementwiseDivTensorRoundingModeFloorIntStaticModule() +) def ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64)) + module.forward( + tu.randint(3, 4, low=-10, high=10).type(torch.int32), + tu.randint(3, 4, low=1, high=10).type(torch.int64), + ) # ============================================================================== class ElementwiseBitwiseAndModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.bitwise_and(x, y) @@ -3282,23 +3485,25 @@ class ElementwiseBitwiseAndModule(torch.nn.Module): def ElementwiseBitwiseAndModule_basic(module, tu: TestUtils): module.forward( tu.randint(3, 4, low=-10, high=10).to(torch.int32), - tu.randint(3, 4, low=-10, high=10)) + tu.randint(3, 4, low=-10, high=10), + ) # ============================================================================== class ElementwiseBitwiseAndStaticShapeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int32, True), - ([4], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int32, True), + ([4], torch.int64, True), + ] + ) def forward(self, x, y): return torch.bitwise_and(x, y) @@ -3307,23 +3512,25 @@ class ElementwiseBitwiseAndStaticShapeModule(torch.nn.Module): def ElementwiseBitwiseAndStaticShapeModule_basic(module, tu: TestUtils): module.forward( tu.randint(3, 4, low=-10, high=10).to(torch.int32), - tu.randint(4, low=-10, high=10)) + tu.randint(4, low=-10, high=10), + ) # ============================================================================== class ElementwiseBitwiseOrModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.bitwise_or(x, y) @@ -3332,23 +3539,25 @@ class ElementwiseBitwiseOrModule(torch.nn.Module): def ElementwiseBitwiseOrModule_basic(module, tu: TestUtils): module.forward( tu.randint(3, 4, low=-10, high=10).to(torch.int32), - tu.randint(3, 4, low=-10, high=10)) + tu.randint(3, 4, low=-10, high=10), + ) # ============================================================================== class ElementwiseBitwiseOrStaticShapeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int32, True), - ([4], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int32, True), + ([4], torch.int64, True), + ] + ) def forward(self, x, y): return torch.bitwise_or(x, y) @@ -3357,23 +3566,25 @@ class ElementwiseBitwiseOrStaticShapeModule(torch.nn.Module): def ElementwiseBitwiseOrStaticShapeModule_basic(module, tu: TestUtils): module.forward( tu.randint(3, 4, low=-10, high=10).to(torch.int32), - tu.randint(4, low=-10, high=10)) + tu.randint(4, low=-10, high=10), + ) # ============================================================================== class ElementwiseOrTensorModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.__or__(x, y) @@ -3382,23 +3593,25 @@ class ElementwiseOrTensorModule(torch.nn.Module): def ElementwiseOrTensorModule_basic(module, tu: TestUtils): module.forward( tu.randint(3, 4, low=-10, high=10).to(torch.int32), - tu.randint(3, 4, low=-10, high=10)) + tu.randint(3, 4, low=-10, high=10), + ) # ============================================================================== class ElementwiseOrTensorStaticShapeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int32, True), - ([4], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int32, True), + ([4], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.__or__(x, y) @@ -3407,68 +3620,66 @@ class ElementwiseOrTensorStaticShapeModule(torch.nn.Module): def ElementwiseOrTensorStaticShapeModule_basic(module, tu: TestUtils): module.forward( tu.randint(3, 4, low=-10, high=10).to(torch.int32), - tu.randint(4, low=-10, high=10)) + tu.randint(4, low=-10, high=10), + ) # ============================================================================== class ElementwiseAndscalarModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.ops.aten.__and__(x, 12) @register_test_case(module_factory=lambda: ElementwiseAndscalarModule()) def ElementwiseAndScalarModule_basic(module, tu: TestUtils): - module.forward( - tu.randint(3, 4, low=-10, high=10).to(torch.int32)) + module.forward(tu.randint(3, 4, low=-10, high=10).to(torch.int32)) # ============================================================================== class ElementwiseAndScalarStaticShapeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int32, True) - ]) + @annotate_args([None, ([3, 4], torch.int32, True)]) def forward(self, x): return torch.ops.aten.__and__(x, 12) @register_test_case(module_factory=lambda: ElementwiseAndScalarStaticShapeModule()) def ElementwiseAndScalarStaticShapeModule_basic(module, tu: TestUtils): - module.forward( - tu.randint(3, 4, low=-10, high=10).to(torch.int32)) + module.forward(tu.randint(3, 4, low=-10, high=10).to(torch.int32)) + # ============================================================================== class ElementwiseBitwiseXorModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.bitwise_xor(x, y) @@ -3477,23 +3688,25 @@ class ElementwiseBitwiseXorModule(torch.nn.Module): def ElementwiseBitwiseXorModule_basic(module, tu: TestUtils): module.forward( tu.randint(3, 4, low=-10, high=10).to(torch.int32), - tu.randint(3, 4, low=-10, high=10)) + tu.randint(3, 4, low=-10, high=10), + ) # ============================================================================== class ElementwiseBitwiseXorStaticShapeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int32, True), - ([4], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int32, True), + ([4], torch.int64, True), + ] + ) def forward(self, x, y): return torch.bitwise_xor(x, y) @@ -3502,22 +3715,24 @@ class ElementwiseBitwiseXorStaticShapeModule(torch.nn.Module): def ElementwiseBitwiseXorStaticShapeModule_basic(module, tu: TestUtils): module.forward( tu.randint(3, 4, low=-10, high=10).to(torch.int32), - tu.randint(4, low=-10, high=10)) + tu.randint(4, low=-10, high=10), + ) # ============================================================================== class ElementwiseBitwiseNotInt64Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.bitwise_not(x) @@ -3531,15 +3746,16 @@ def ElementwiseBitwiseNotInt64Module_basic(module, tu: TestUtils): class ElementwiseBitwiseNotInt32Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.bitwise_not(x) @@ -3553,40 +3769,43 @@ def ElementwiseBitwiseNotInt32Module_basic(module, tu: TestUtils): class ElementwiseSubTensorInt8Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int8, True), - ([-1, -1], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int8, True), + ([-1, -1], torch.int8, True), + ] + ) def forward(self, x, y): return torch.sub(x, y, alpha=2) @register_test_case(module_factory=lambda: ElementwiseSubTensorInt8Module()) def ElementwiseSubTensorInt8Module_basic(module, tu: TestUtils): - module.forward( + module.forward( tu.randint(3, 4, high=10).to(dtype=torch.int8), - tu.randint(3, 4, high=10).to(dtype=torch.int8)) + tu.randint(3, 4, high=10).to(dtype=torch.int8), + ) # ============================================================================== class ElementwiseSubScalarIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.sub(x, 2.1, alpha=2) @@ -3600,15 +3819,16 @@ def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils): class ElementwiseSubScalarFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.sub(x, 2.1) @@ -3622,15 +3842,16 @@ def ElementwiseSubScalarFloatModule_basic(module, tu: TestUtils): class ElementwiseAddScalarInt64Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.add(x, 3.0) @@ -3644,15 +3865,16 @@ def ElementwiseAddScalarInt64Module_basic(module, tu: TestUtils): class ElementwiseAddScalarIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.add(x, 3.0) @@ -3666,15 +3888,16 @@ def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils): class ElementwiseAddScalarFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.add(x, 3.0, alpha=2) @@ -3688,21 +3911,23 @@ def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils): class ElementwiseAddScalar_NumToTensorFloat_Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): x = torch.ops.prim.NumToTensor(5.0) return torch.add(x, 3) @register_test_case( - module_factory=lambda: ElementwiseAddScalar_NumToTensorFloat_Module()) + module_factory=lambda: ElementwiseAddScalar_NumToTensorFloat_Module() +) def ElementwiseAddScalar_NumToTensorFloat_Module_basic(module, tu: TestUtils): module.forward() @@ -3711,21 +3936,23 @@ def ElementwiseAddScalar_NumToTensorFloat_Module_basic(module, tu: TestUtils): class ElementwiseAddScalar_TensorLiteralInt32_Module(torch.nn.Module): - def __init__(self): super().__init__() self.x = torch.tensor(2, dtype=torch.int32) @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return torch.add(self.x, 3) @register_test_case( - module_factory=lambda: ElementwiseAddScalar_TensorLiteralInt32_Module()) + module_factory=lambda: ElementwiseAddScalar_TensorLiteralInt32_Module() +) def ElementwiseAddScalar_TensorLiteralInt32_Module_basic(module, tu: TestUtils): module.forward() @@ -3734,15 +3961,16 @@ def ElementwiseAddScalar_TensorLiteralInt32_Module_basic(module, tu: TestUtils): class ElementwiseAddScalarInt8Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int8, True), + ] + ) def forward(self, x): return torch.add(x, 3, 2) @@ -3756,15 +3984,16 @@ def ElementwiseAddScalarInt8Module_basic(module, tu: TestUtils): class ElementwiseCloneModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.clone(x) @@ -3778,15 +4007,16 @@ def ElementwiseCloneModule_basic(module, tu: TestUtils): class ElementwiseCloneContiguousModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.clone(x, memory_format=torch.contiguous_format) @@ -3800,23 +4030,24 @@ def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils): class ElementwiseCloneChannelsLastMemoryFormatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.clone(x, memory_format=torch.channels_last) @register_test_case( - module_factory=lambda: ElementwiseCloneChannelsLastMemoryFormatModule()) -def ElementwiseCloneChannelsLastMemoryFormatModule_basic( - module, tu: TestUtils): + module_factory=lambda: ElementwiseCloneChannelsLastMemoryFormatModule() +) +def ElementwiseCloneChannelsLastMemoryFormatModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 5)) @@ -3824,15 +4055,16 @@ def ElementwiseCloneChannelsLastMemoryFormatModule_basic( class LiftFreshCopyModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.lift_fresh_copy(x) @@ -3846,15 +4078,16 @@ def LiftFreshCopyModule_basic(module, tu: TestUtils): class ElementwiseExpModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.exp(a) @@ -3868,15 +4101,16 @@ def ElementwiseExpModule_basic(module, tu: TestUtils): class ElementwiseExpIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.exp(a) @@ -3890,15 +4124,16 @@ def ElementwiseExpIntModule_basic(module, tu: TestUtils): class ElementwiseExpm1Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.special.expm1(a) @@ -3912,15 +4147,16 @@ def ElementwiseExpm1Module_basic(module, tu: TestUtils): class ElementwiseExpm1IntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.special.expm1(a) @@ -3934,15 +4170,16 @@ def ElementwiseExpm1IntModule_basic(module, tu: TestUtils): class ElementwiseSinModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.sin(a) @@ -3956,15 +4193,16 @@ def ElementwiseSinModule_basic(module, tu: TestUtils): class ElementwiseSinIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.sin(a) @@ -3978,15 +4216,16 @@ def ElementwiseSinIntModule_basic(module, tu: TestUtils): class ElementwiseCosModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.cos(a) @@ -4000,15 +4239,16 @@ def ElementwiseCosModule_basic(module, tu: TestUtils): class ElementwiseCosIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.cos(a) @@ -4017,19 +4257,21 @@ class ElementwiseCosIntModule(torch.nn.Module): def ElementwiseCosIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + # ============================================================================== class ElementwiseAcosModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.acos(a) @@ -4038,19 +4280,21 @@ class ElementwiseAcosModule(torch.nn.Module): def ElementwiseAcosModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + # ============================================================================== class ElementwiseAcosIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.acos(a) @@ -4059,18 +4303,21 @@ class ElementwiseAcosIntModule(torch.nn.Module): def ElementwiseAcosIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + # ============================================================================== -class ElementwiseTanModule(torch.nn.Module): +class ElementwiseTanModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.tan(a) @@ -4079,18 +4326,21 @@ class ElementwiseTanModule(torch.nn.Module): def ElementwiseTanModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + # ============================================================================== -class ElementwiseTanIntModule(torch.nn.Module): +class ElementwiseTanIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.tan(a) @@ -4099,18 +4349,21 @@ class ElementwiseTanIntModule(torch.nn.Module): def ElementwiseTanIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + # ============================================================================== -class ElementwiseNegModule(torch.nn.Module): +class ElementwiseNegModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.neg(a) @@ -4119,152 +4372,198 @@ class ElementwiseNegModule(torch.nn.Module): def ElementwiseNegModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + # ============================================================================== + class ElementwiseAtenLogicalOrOpModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.bool, True), - ([-1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.bool, True), + ([-1], torch.bool, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_or(x, y) + @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpModule()) def ElementwiseAtenLogicalOrOpModule_basic(module, tu: TestUtils): module.forward(torch.tensor([False, True]), torch.tensor([False, False])) + class ElementwiseAtenLogicalOrOpDiffArgs1Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_or(x, y) + @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs1Module()) def ElementwiseAtenLogicalOrOpDiffArgs1Module_basic(module, tu: TestUtils): module.forward(torch.tensor([0.2, 0.1]), torch.tensor([0, 1])) + # ============================================================================== + class ElementwiseAtenLogicalOrOpDiffArgs2Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.bool, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.bool, True), + ([-1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_or(x, y) + @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs2Module()) def ElementwiseAtenLogicalOrOpDiffArgs2Module_basic(module, tu: TestUtils): module.forward(torch.tensor([True, False]), torch.tensor([0, 1])) + # ============================================================================== + class ElementwiseAtenLogicalOrOpDiffArgs3Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ([-1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ([-1], torch.bool, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_or(x, y) + @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs3Module()) def ElementwiseAtenLogicalOrOpDiffArgs3Module_basic(module, tu: TestUtils): module.forward(torch.tensor([1, 2]), torch.tensor([False, True])) + # ============================================================================== + class ElementwiseAtenLogicalOrOpRandomModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.int64, True), - ([-1, -1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.int64, True), + ([-1, -1, -1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_or(x, y) + @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomModule()) def ElementwiseAtenLogicalOrOpRandomModule_basic(module, tu: TestUtils): - module.forward(tu.randint(2, 3, 4, 5, low=3, high=10), tu.randint(2, 3, 4, 5, low=10, high=100)) + module.forward( + tu.randint(2, 3, 4, 5, low=3, high=10), tu.randint(2, 3, 4, 5, low=10, high=100) + ) + # ============================================================================== + class ElementwiseAtenLogicalOrOpRandomFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_or(x, y) -@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomFloatModule()) + +@register_test_case( + module_factory=lambda: ElementwiseAtenLogicalOrOpRandomFloatModule() +) def ElementwiseAtenLogicalOrOpRandomFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 3, 5), tu.rand(2, 3, 3, 5)) + # ============================================================================== + class ElementwiseAtenLogicalOrOpNegativeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.int64, True), - ([-1, -1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.int64, True), + ([-1, -1, -1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_or(x, y) + @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpNegativeModule()) def ElementwiseAtenLogicalOrOpNegativeModule_basic(module, tu: TestUtils): - module.forward(torch.neg(tu.randint(2, 3, 4, 5, low=3, high=10)), torch.neg(tu.randint(2, 3, 4, 5, low=10, high=100))) + module.forward( + torch.neg(tu.randint(2, 3, 4, 5, low=3, high=10)), + torch.neg(tu.randint(2, 3, 4, 5, low=10, high=100)), + ) + # ============================================================================== + class ElementwiseAtenLogicalOrOpBrodcastModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_or(x, y) + @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpBrodcastModule()) def ElementwiseAtenLogicalOrOpBrodcastModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, high=3), tu.randint(4, 3, high=3)) @@ -4278,16 +4577,23 @@ class ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule(torch.nn.Modul super().__init__() @export - @annotate_args([ - None, - ([256], torch.float32, True), - ([3, 256], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([256], torch.float32, True), + ([3, 256], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_or(x, y) -@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule()) -def ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic(module, tu: TestUtils): + +@register_test_case( + module_factory=lambda: ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule() +) +def ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic( + module, tu: TestUtils +): module.forward(tu.rand(256), tu.randint(3, 256, low=-1, high=2)) @@ -4299,14 +4605,17 @@ class ElementwiseAtenLogicalAndOpModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.bool, True), - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ([-1, -1], torch.bool, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_and(x, y) + @register_test_case(module_factory=lambda: ElementwiseAtenLogicalAndOpModule()) def ElementwiseAtenLogicalAndOpModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, 5, high=2).bool(), tu.randint(4, 5, high=2).bool()) @@ -4320,15 +4629,20 @@ class ElementwiseAtenLogicalAndOpPromoteBroadcastModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_and(x, y) -@register_test_case(module_factory=lambda: ElementwiseAtenLogicalAndOpPromoteBroadcastModule()) + +@register_test_case( + module_factory=lambda: ElementwiseAtenLogicalAndOpPromoteBroadcastModule() +) def ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic(module, tu: TestUtils): module.forward(tu.rand(256), tu.randint(3, 256, low=-1, high=2)) @@ -4341,16 +4655,23 @@ class ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule(torch.nn.Modu super().__init__() @export - @annotate_args([ - None, - ([256], torch.float32, True), - ([3, 256], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([256], torch.float32, True), + ([3, 256], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_and(x, y) -@register_test_case(module_factory=lambda: ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule()) -def ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic(module, tu: TestUtils): + +@register_test_case( + module_factory=lambda: ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule() +) +def ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic( + module, tu: TestUtils +): module.forward(tu.rand(256), tu.randint(3, 256, low=-1, high=2)) @@ -4362,14 +4683,17 @@ class ElementwiseAtenLogicalXorOpModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.bool, True), - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ([-1, -1], torch.bool, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_xor(x, y) + @register_test_case(module_factory=lambda: ElementwiseAtenLogicalXorOpModule()) def ElementwiseAtenLogicalXorOpModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, 5, high=2).bool(), tu.randint(4, 5, high=2).bool()) @@ -4383,15 +4707,20 @@ class ElementwiseAtenLogicalXorOpPromoteBroadcastModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_xor(x, y) -@register_test_case(module_factory=lambda: ElementwiseAtenLogicalXorOpPromoteBroadcastModule()) + +@register_test_case( + module_factory=lambda: ElementwiseAtenLogicalXorOpPromoteBroadcastModule() +) def ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic(module, tu: TestUtils): module.forward(tu.rand(256), tu.randint(3, 256, low=-1, high=2)) @@ -4404,16 +4733,23 @@ class ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule(torch.nn.Modu super().__init__() @export - @annotate_args([ - None, - ([256], torch.float32, True), - ([3, 256], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([256], torch.float32, True), + ([3, 256], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.logical_xor(x, y) -@register_test_case(module_factory=lambda: ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule()) -def ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic(module, tu: TestUtils): + +@register_test_case( + module_factory=lambda: ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule() +) +def ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic( + module, tu: TestUtils +): module.forward(tu.rand(256), tu.randint(3, 256, low=-1, high=2)) @@ -4425,13 +4761,16 @@ class ElementwiseAtenLogicalNotOpModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ] + ) def forward(self, x): return torch.ops.aten.logical_not(x) + @register_test_case(module_factory=lambda: ElementwiseAtenLogicalNotOpModule()) def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, 5, high=2).bool()) @@ -4439,24 +4778,28 @@ def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils): # ============================================================================== + class ElementwiseAtenIsinfOpModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 5], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.isinf(x) + @register_test_case(module_factory=lambda: ElementwiseAtenIsinfOpModule()) def ElementwiseAtenIsinfOpModule_basic(module, tu: TestUtils): test_input = torch.tensor( [ - [1, float('inf'), 2, float('-inf'), float('nan')], - [1, float('inf'), float('-inf'), float('nan'), 3], + [1, float("inf"), 2, float("-inf"), float("nan")], + [1, float("inf"), float("-inf"), float("nan"), 3], ] ) module.forward(test_input) @@ -4470,19 +4813,22 @@ class ElementwiseAtenIsneginfOpModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([2, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 5], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.isneginf(x) + @register_test_case(module_factory=lambda: ElementwiseAtenIsneginfOpModule()) -def ElementwiseAtenIsneginfOpModule_basic(module, tu:TestUtils): +def ElementwiseAtenIsneginfOpModule_basic(module, tu: TestUtils): test_input = torch.tensor( [ - [1, float('-inf'), 2, float('inf'), float('nan')], - [1, float('-inf'), float('inf'), float('nan'), 3], + [1, float("-inf"), 2, float("inf"), float("nan")], + [1, float("-inf"), float("inf"), float("nan"), 3], ] ) module.forward(test_input) @@ -4496,19 +4842,22 @@ class ElementwiseAtenIsposinfOpModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([2, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 5], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.isposinf(x) + @register_test_case(module_factory=lambda: ElementwiseAtenIsposinfOpModule()) -def ElementwiseAtenIsposinfOpModule_basic(module, tu:TestUtils): +def ElementwiseAtenIsposinfOpModule_basic(module, tu: TestUtils): test_input = torch.tensor( [ - [1, float('-inf'), 2, float('inf'), float('nan')], - [1, float('-inf'), float('inf'), float('nan'), 3], + [1, float("-inf"), 2, float("inf"), float("nan")], + [1, float("-inf"), float("inf"), float("nan"), 3], ] ) module.forward(test_input) @@ -4522,13 +4871,16 @@ class ElementwiseAtenLogicalNotOpPromoteModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.ops.aten.logical_not(x) + @register_test_case(module_factory=lambda: ElementwiseAtenLogicalNotOpPromoteModule()) def ElementwiseAtenLogicalNotOpPromoteModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, 5, low=-1, high=2)) @@ -4538,15 +4890,16 @@ def ElementwiseAtenLogicalNotOpPromoteModule_basic(module, tu: TestUtils): class ElementwiseAtenFloorDivideScalarModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.floor_divide(x, 0.14) @@ -4557,84 +4910,93 @@ def ElementwiseAtenFloorDivideScalarModule_basic(module, tu: TestUtils): class ElementwiseAtenFloorDivideScalarNegativeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.floor_divide(x, 0.14) -@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideScalarNegativeModule()) +@register_test_case( + module_factory=lambda: ElementwiseAtenFloorDivideScalarNegativeModule() +) def ElementwiseAtenFloorDivideScalarNegativeModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 3, low=-10.0, high=10.0)) - + # ============================================================================== class ElementwiseAtenFloorDivideTensorNegativeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) - def forward(self, x, y): - return torch.ops.aten.floor_divide(x, y) - - -@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideTensorNegativeModule()) -def ElementwiseAtenFloorDivideTensorNegativeModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 3, low= -1, high=0), tu.rand(4, 3, low= 0, high=1)) - - -class ElementwiseAtenFloorDivideTensorPositiveModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) - def forward(self, x, y): - return torch.ops.aten.floor_divide(x, y) - - -@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideTensorPositiveModule()) -def ElementwiseAtenFloorDivideTensorPositiveModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 3), tu.rand(4, 3)) - - -class ElementwiseAtenFloorDivideBroadcastModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.floor_divide(x, y) @register_test_case( - module_factory=lambda: ElementwiseAtenFloorDivideBroadcastModule()) + module_factory=lambda: ElementwiseAtenFloorDivideTensorNegativeModule() +) +def ElementwiseAtenFloorDivideTensorNegativeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 3, low=-1, high=0), tu.rand(4, 3, low=0, high=1)) + + +class ElementwiseAtenFloorDivideTensorPositiveModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.floor_divide(x, y) + + +@register_test_case( + module_factory=lambda: ElementwiseAtenFloorDivideTensorPositiveModule() +) +def ElementwiseAtenFloorDivideTensorPositiveModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 3), tu.rand(4, 3)) + + +class ElementwiseAtenFloorDivideBroadcastModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.floor_divide(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideBroadcastModule()) def ElementwiseAtenFloorDivideBroadcastModule_basic(module, tu: TestUtils): module.forward(tu.rand(3), tu.rand(4, 3)) @@ -4643,15 +5005,16 @@ def ElementwiseAtenFloorDivideBroadcastModule_basic(module, tu: TestUtils): class AtenTriuModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.triu(x) @@ -4665,15 +5028,16 @@ def AtenTriuModule_basic(module, tu: TestUtils): class AtenTriuWithPosDiagonalModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.triu(x, diagonal=2) @@ -4691,20 +5055,26 @@ class TriuModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([4,5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4, 5], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.triu(x, 1) @register_test_case(module_factory=lambda: TriuModule()) def TriuModule_basic(module, tu: TestUtils): - x=torch.tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2], - [-0.2447, 0.9556, -1.2919, 1.3378, 0.3], - [ 0.4333, 0.3146, 0.6576, -1.0432, 0.4], - [-0.9888, torch.nan, torch.inf, -torch.inf, 0.5]]) + x = torch.tensor( + [ + [0.5876, -0.0794, -1.8373, 0.6654, 0.2], + [-0.2447, 0.9556, -1.2919, 1.3378, 0.3], + [0.4333, 0.3146, 0.6576, -1.0432, 0.4], + [-0.9888, torch.nan, torch.inf, -torch.inf, 0.5], + ] + ) module.forward(x) @@ -4716,32 +5086,35 @@ class TriuBroadcastModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([3,4,5,6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 5, 6], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.triu(x, 2) @register_test_case(module_factory=lambda: TriuBroadcastModule()) def TriuBroadcastModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3,4,5,6)) + module.forward(tu.rand(3, 4, 5, 6)) # ============================================================================== class AtenTriuWithNegDiagonalModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.triu(x, diagonal=-4) @@ -4755,15 +5128,16 @@ def AtenTriuWithNegDiagonalModule_basic(module, tu: TestUtils): class AtenTrilModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.tril(x) @@ -4777,15 +5151,16 @@ def AtenTrilModule_basic(module, tu: TestUtils): class AtenTrilWithPosDiagonalModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.tril(x, diagonal=2) @@ -4799,15 +5174,16 @@ def AtenTrilWithPosDiagonalModule_basic(module, tu: TestUtils): class AtenTrilWithNegDiagonalModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.tril(x, diagonal=-4) @@ -4821,34 +5197,36 @@ def AtenTrilWithNegDiagonalModule_basic(module, tu: TestUtils): class AtenRoundFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.round(x) @register_test_case(module_factory=lambda: AtenRoundFloatModule()) def AtenRoundFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 5, low = -3.0, high = 3.0)) + module.forward(tu.rand(5, 5, low=-3.0, high=3.0)) class AtenRoundFloatHalfToEvenModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.round(x) @@ -4859,37 +5237,39 @@ def AtenRoundFloatHalfToEvenModule_basic(module, tu: TestUtils): class AtenRoundIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.ops.aten.round(x) @register_test_case(module_factory=lambda: AtenRoundIntModule()) def AtenRoundIntModule_basic(module, tu: TestUtils): - module.forward(tu.randint(5, 5, low = -10)) + module.forward(tu.randint(5, 5, low=-10)) # ============================================================================== class Fill_TensorFloat64WithFloat32(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, tensor): return torch.ops.aten.fill_(tensor, 3.0) @@ -4900,15 +5280,16 @@ def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils): class Fill_TensorFloat64WithFloat32Static(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 2, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 2, 4], torch.float32, True), + ] + ) def forward(self, tensor): return torch.ops.aten.fill_(tensor, 3.0) @@ -4919,15 +5300,16 @@ def Fill_TensorFloat64WithFloat32Static_basic(module, tu: TestUtils): class Fill_TensorFloat64WithFloat64(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, tensor): return torch.ops.aten.fill_(tensor, 3.0) @@ -4938,15 +5320,16 @@ def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils): class Fill_TensorFloat64WithInt64(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, tensor): return torch.ops.aten.fill_(tensor, 3) @@ -4957,15 +5340,16 @@ def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils): class Fill_TensorFloat64WithInt64Static(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 2, 4], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([3, 2, 4], torch.float64, True), + ] + ) def forward(self, tensor): return torch.ops.aten.fill_(tensor, 3) @@ -4979,57 +5363,63 @@ def Fill_TensorFloat64WithInt64Static_basic(module, tu: TestUtils): class Fill_TensorFloat32WithFloat32(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([], torch.float32, True), + ] + ) def forward(self, tensor, value): return torch.ops.aten.fill_(tensor, value) + @register_test_case(module_factory=lambda: Fill_TensorFloat32WithFloat32()) def Fill_TensorFloat32WithFloat32_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2, 4), tu.rand()) class Fill_TensorFloat32WithFloat64(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([], torch.float64, True), + ] + ) def forward(self, tensor, value): return torch.ops.aten.fill_(tensor, value) + @register_test_case(module_factory=lambda: Fill_TensorFloat32WithFloat64()) def Fill_TensorFloat32WithFloat64_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2, 4), tu.rand().to(torch.float64)) class Fill_TensorFloat32WithInt64(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([], torch.int64, True), + ] + ) def forward(self, tensor, value): return torch.ops.aten.fill_(tensor, value) + @register_test_case(module_factory=lambda: Fill_TensorFloat32WithInt64()) def Fill_TensorFloat32WithInt64_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2, 4), tu.randint()) @@ -5043,12 +5433,13 @@ class TupleModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a, b): cond = True if cond: @@ -5068,141 +5459,164 @@ def TupleModule_basic(module, tu: TestUtils): class ElementwiseBitwiseRightShiftInt64Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return torch.bitwise_right_shift(lhs, rhs) @register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt64Module()) def ElementwiseBitwiseRightShiftInt64Module_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64)) + module.forward( + tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64) + ) class ElementwiseBitwiseRightShiftInt32Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, 4], torch.int32, True), - ([-1, 1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, 4], torch.int32, True), + ([-1, 1], torch.int32, True), + ] + ) def forward(self, lhs, rhs): return torch.bitwise_right_shift(lhs, rhs) @register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt32Module()) def ElementwiseBitwiseRightShiftInt32Module_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), tu.randint(3, 1, low=0, high=32).to(torch.int32)) + module.forward( + tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), + tu.randint(3, 1, low=0, high=32).to(torch.int32), + ) class ElementwiseBitwiseRightShiftInt8Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int8, True), - ([-1, -1], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int8, True), + ([-1, -1], torch.int8, True), + ] + ) def forward(self, lhs, rhs): return torch.bitwise_right_shift(lhs, rhs) @register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt8Module()) def ElementwiseBitwiseRightShiftInt8Module_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int8), tu.randint(3, 4, low=0, high=8).to(torch.int8)) + module.forward( + tu.randint(3, 4, low=-100, high=100).to(torch.int8), + tu.randint(3, 4, low=0, high=8).to(torch.int8), + ) # ============================================================================== class ElementwiseBitwiseLeftShiftInt64Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return torch.bitwise_left_shift(lhs, rhs) @register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt64Module()) def ElementwiseBitwiseLeftShiftInt64Module_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64)) + module.forward( + tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64) + ) class ElementwiseBitwiseLeftShiftInt32Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, 4], torch.int32, True), - ([-1, 1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, 4], torch.int32, True), + ([-1, 1], torch.int32, True), + ] + ) def forward(self, lhs, rhs): return torch.bitwise_left_shift(lhs, rhs) @register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt32Module()) def ElementwiseBitwiseLeftShiftInt32Module_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), tu.randint(3, 1, low=0, high=32).to(torch.int32)) + module.forward( + tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), + tu.randint(3, 1, low=0, high=32).to(torch.int32), + ) class ElementwiseBitwiseLeftShiftInt8Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int8, True), - ([-1, -1], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int8, True), + ([-1, -1], torch.int8, True), + ] + ) def forward(self, lhs, rhs): return torch.bitwise_left_shift(lhs, rhs) @register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt8Module()) def ElementwiseBitwiseLeftShiftInt8Module_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int8), tu.randint(3, 4, low=0, high=8).to(torch.int8)) + module.forward( + tu.randint(3, 4, low=-100, high=100).to(torch.int8), + tu.randint(3, 4, low=0, high=8).to(torch.int8), + ) # ============================================================================== class ElementwiseBitwiseAndScalarInt64Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.bitwise_and(x, 15) @@ -5213,15 +5627,16 @@ def ElementwiseBitwiseAndScalarInt64Module_basic(module, tu: TestUtils): class ElementwiseBitwiseAndScalarInt32Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.bitwise_and(x, 100) @@ -5232,15 +5647,16 @@ def ElementwiseBitwiseAndScalarInt32Module_basic(module, tu: TestUtils): class ElementwiseBitwiseAndScalarInt8Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int8, True), + ] + ) def forward(self, x): return torch.bitwise_and(x, 100) @@ -5249,18 +5665,21 @@ class ElementwiseBitwiseAndScalarInt8Module(torch.nn.Module): def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int8)) + # ============================================================================== -class ElementwiseQuantizePerTensorModule(torch.nn.Module): +class ElementwiseQuantizePerTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float, True), + ] + ) def forward(self, x): scale = 0.04 zp = -110 @@ -5269,22 +5688,26 @@ class ElementwiseQuantizePerTensorModule(torch.nn.Module): q = torch.quantize_per_tensor(x, scale, zp, dtype).int_repr() return q + @register_test_case(module_factory=lambda: ElementwiseQuantizePerTensorModule()) def ElementwiseQuantizePerTensorModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + # ============================================================================== -class ElementwiseQuantizePerTensorUIntModule(torch.nn.Module): +class ElementwiseQuantizePerTensorUIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float, True), + ] + ) def forward(self, x): scale = 0.04 zp = 11 @@ -5294,6 +5717,7 @@ class ElementwiseQuantizePerTensorUIntModule(torch.nn.Module): q = q.to(torch.int8) return q + @register_test_case(module_factory=lambda: ElementwiseQuantizePerTensorUIntModule()) def ElementwiseQuantizePerTensorUIntModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) @@ -5301,113 +5725,125 @@ def ElementwiseQuantizePerTensorUIntModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseDequantizePerTensorModule(torch.nn.Module): +class ElementwiseDequantizePerTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int8, True), + ] + ) def forward(self, x): qx = torch._make_per_tensor_quantized_tensor(x, 0.1, 8) qx = torch.dequantize(qx) return qx + @register_test_case(module_factory=lambda: ElementwiseDequantizePerTensorModule()) def ElementwiseDequantizePerTensorModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8)) + # ============================================================================== -class ElementwiseDequantizePerChannelModule(torch.nn.Module): +class ElementwiseDequantizePerChannelModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int8, True), - ([4], torch.int8, True), - ([4], torch.float, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int8, True), + ([4], torch.int8, True), + ([4], torch.float, True), + ] + ) def forward(self, x, zeropoint, scale): qx = torch._make_per_channel_quantized_tensor(x, scale, zeropoint, axis=1) qx = torch.dequantize(qx) return qx + @register_test_case(module_factory=lambda: ElementwiseDequantizePerChannelModule()) def ElementwiseDequantizePerChannelModule_basic(module, tu: TestUtils): module.forward( tu.randint(3, 4, low=-128, high=127).to(torch.int8), tu.randint(4, low=-128, high=127).to(torch.int8), - tu.rand(4) + tu.rand(4), ) + # ============================================================================== + class GluStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 24, 5], torch.float32, True) - ]) + @annotate_args([None, ([3, 24, 5], torch.float32, True)]) def forward(self, x): return torch.ops.aten.glu(x, dim=1) + @register_test_case(module_factory=lambda: GluStaticModule()) def GluStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 24, 5)) + # ============================================================================== + class FakeQuantizePerTensorAffineModule(torch.nn.Module): def __init__(self): super().__init__() + @export - @annotate_args([ - None, - ([4, 50], torch.float32, True) - ]) + @annotate_args([None, ([4, 50], torch.float32, True)]) def forward(self, x): - return torch.ops.aten.fake_quantize_per_tensor_affine(x, 0.1, 1, 0, 255) + return torch.ops.aten.fake_quantize_per_tensor_affine(x, 0.1, 1, 0, 255) + @register_test_case(module_factory=lambda: FakeQuantizePerTensorAffineModule()) def FakeQuantizePerTensorAffineModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 50)) + class FakeQuantizePerTensorAffineDynamicShapeModule(torch.nn.Module): def __init__(self): super().__init__() - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True) - ]) - def forward(self, x): - return torch.ops.aten.fake_quantize_per_tensor_affine(x, 0.1, 1, 0, 255) -@register_test_case(module_factory=lambda: FakeQuantizePerTensorAffineDynamicShapeModule()) + @export + @annotate_args([None, ([-1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.fake_quantize_per_tensor_affine(x, 0.1, 1, 0, 255) + + +@register_test_case( + module_factory=lambda: FakeQuantizePerTensorAffineDynamicShapeModule() +) def FakeQuantizePerTensorAffineDynamicShapeModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 50)) + class FakeQuantizePerTensorAffineRoundToEvenModule(torch.nn.Module): def __init__(self): super().__init__() - @export - @annotate_args([ - None, - ([4], torch.float32, True) - ]) - def forward(self, x): - return torch.ops.aten.fake_quantize_per_tensor_affine(x, 0.1, 0, -128, 127) -@register_test_case(module_factory=lambda: FakeQuantizePerTensorAffineRoundToEvenModule()) + @export + @annotate_args([None, ([4], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.fake_quantize_per_tensor_affine(x, 0.1, 0, -128, 127) + + +@register_test_case( + module_factory=lambda: FakeQuantizePerTensorAffineRoundToEvenModule() +) def FakeQuantizePerTensorAffineRoundToEvenModule_basic(module, tu: TestUtils): module.forward(torch.FloatTensor([0.5, 1.5, -0.5, -1.5])) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index 6248ef5aa..c283c5455 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -11,15 +11,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== + class ElementwiseGtFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.gt(x, 0.6) @@ -28,17 +31,21 @@ class ElementwiseGtFloatScalarModule(torch.nn.Module): def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) + # ============================================================================== + class ElementwiseGtIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.gt(x, 10) @@ -47,17 +54,21 @@ class ElementwiseGtIntScalarModule(torch.nn.Module): def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-10, high=15)) + # ============================================================================== + class ElementwiseGtMixed2ScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.gt(x, 7) @@ -66,17 +77,21 @@ class ElementwiseGtMixed2ScalarModule(torch.nn.Module): def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) + # ============================================================================== + class ElementwiseGeFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ge(x, 0.6) @@ -85,17 +100,21 @@ class ElementwiseGeFloatScalarModule(torch.nn.Module): def ElementwiseGeFloatScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) + # ============================================================================== + class ElementwiseGeIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.ge(x, 10) @@ -104,17 +123,21 @@ class ElementwiseGeIntScalarModule(torch.nn.Module): def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-10, high=15)) + # ============================================================================== + class ElementwiseGeMixedIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.ge(x, 7) @@ -123,17 +146,21 @@ class ElementwiseGeMixedIntScalarModule(torch.nn.Module): def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) + # ============================================================================== + class ElementwiseGeFloatIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ge(x, 7) @@ -142,18 +169,22 @@ class ElementwiseGeFloatIntScalarModule(torch.nn.Module): def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) + # ============================================================================== + class ElementwiseGeFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ge(x, y) @@ -162,20 +193,25 @@ class ElementwiseGeFloatTensorModule(torch.nn.Module): def ElementwiseGeFloatTensorModule_basic(module, tu: TestUtils): module.forward( torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), - torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32), + ) + # ============================================================================== + class ElementwiseGeIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ge(x, y) @@ -184,18 +220,22 @@ class ElementwiseGeIntTensorModule(torch.nn.Module): def ElementwiseGeIntTensorModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) + # ============================================================================== + class ElementwiseGtFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.gt(x, y) @@ -204,20 +244,25 @@ class ElementwiseGtFloatTensorModule(torch.nn.Module): def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils): module.forward( torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), - torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32), + ) + # ============================================================================== + class ElementwiseGtIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.gt(x, y) @@ -226,17 +271,21 @@ class ElementwiseGtIntTensorModule(torch.nn.Module): def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) + # ============================================================================== + class ElementwiseLtFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.lt(x, 0.6) @@ -245,17 +294,21 @@ class ElementwiseLtFloatScalarModule(torch.nn.Module): def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) + # ============================================================================== + class ElementwiseLtIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.lt(x, 0) @@ -264,37 +317,44 @@ class ElementwiseLtIntScalarModule(torch.nn.Module): def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-10, high=15)) + # ============================================================================== + class ElementwiseLtDiffWidthScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.lt(x, 2) -@register_test_case( - module_factory=lambda: ElementwiseLtDiffWidthScalarModule()) +@register_test_case(module_factory=lambda: ElementwiseLtDiffWidthScalarModule()) def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) + # ============================================================================== + class ElementwiseLeFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.le(x, 0.6) @@ -303,17 +363,21 @@ class ElementwiseLeFloatScalarModule(torch.nn.Module): def ElementwiseLeFloatScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) + # ============================================================================== + class ElementwiseLeIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.le(x, 10) @@ -322,17 +386,21 @@ class ElementwiseLeIntScalarModule(torch.nn.Module): def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-10, high=15)) + # ============================================================================== + class ElementwiseLeMixedIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.le(x, 7) @@ -341,17 +409,21 @@ class ElementwiseLeMixedIntScalarModule(torch.nn.Module): def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) + # ============================================================================== + class ElementwiseLeFloatIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.le(x, 7) @@ -360,18 +432,22 @@ class ElementwiseLeFloatIntScalarModule(torch.nn.Module): def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) + # ============================================================================== + class ElementwiseLeFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.le(x, y) @@ -380,18 +456,22 @@ class ElementwiseLeFloatTensorModule(torch.nn.Module): def ElementwiseLeFloatTensorModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5), tu.rand(5)) + # ============================================================================== + class ElementwiseLeFloatTensorNanModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.le(x, y) @@ -400,20 +480,25 @@ class ElementwiseLeFloatTensorNanModule(torch.nn.Module): def ElementwiseLeFloatTensorNanModule_basic(module, tu: TestUtils): module.forward( torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), - torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32), + ) + # ============================================================================== + class ElementwiseLeIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.le(x, y) @@ -422,18 +507,22 @@ class ElementwiseLeIntTensorModule(torch.nn.Module): def ElementwiseLeIntTensorModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) + # ============================================================================== + class ElementwiseLtFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.lt(x, y) @@ -442,20 +531,25 @@ class ElementwiseLtFloatTensorModule(torch.nn.Module): def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils): module.forward( torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), - torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32), + ) + # ============================================================================== + class ElementwiseLtIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.lt(x, y) @@ -464,17 +558,21 @@ class ElementwiseLtIntTensorModule(torch.nn.Module): def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) + # ============================================================================== + class ElementwiseEqFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.eq(x, 6.0) @@ -482,19 +580,24 @@ class ElementwiseEqFloatScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule()) def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils): module.forward( - torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32)) + torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32) + ) + # ============================================================================== + class ElementwiseEqIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.eq(x, 2) @@ -503,17 +606,21 @@ class ElementwiseEqIntScalarModule(torch.nn.Module): def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(5, 8, low=2, high=4)) + # ============================================================================== + class ElementwiseEqBoolScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ] + ) def forward(self, x): return torch.eq(x, 1) @@ -525,36 +632,42 @@ def ElementwiseEqBoolScalarModule_basic(module, tu: TestUtils): # ============================================================================== + class ElementwiseEqDiffWidthScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, x): return torch.eq(x, 2) -@register_test_case( - module_factory=lambda: ElementwiseEqDiffWidthScalarModule()) +@register_test_case(module_factory=lambda: ElementwiseEqDiffWidthScalarModule()) def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(5, 8, low=2, high=4).to(torch.int32)) + # ============================================================================== + class ElementwiseEqFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.eq(x, y) @@ -563,20 +676,25 @@ class ElementwiseEqFloatTensorModule(torch.nn.Module): def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils): module.forward( torch.tensor([[1.0, 2.2, 6.0], [torch.nan, 2.0, 3.1]]).to(torch.float32), - torch.tensor([1.0, 2.4, 6.0]).to(torch.float32)) + torch.tensor([1.0, 2.4, 6.0]).to(torch.float32), + ) + # ============================================================================== + class ElementwiseEqIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.eq(x, y) @@ -585,17 +703,21 @@ class ElementwiseEqIntTensorModule(torch.nn.Module): def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils): module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4)) + # ============================================================================== + class ElementwiseNeFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ne(x, 2.0) @@ -603,19 +725,24 @@ class ElementwiseNeFloatScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule()) def ElementwiseNeFloatScalarModule_basic(module, tu: TestUtils): module.forward( - torch.tensor([[1.0, 2.2, 2.0], [torch.nan, 2.0, 3.1]]).to(torch.float32)) + torch.tensor([[1.0, 2.2, 2.0], [torch.nan, 2.0, 3.1]]).to(torch.float32) + ) + # ============================================================================== + class ElementwiseNeIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.ne(x, 3) @@ -624,18 +751,22 @@ class ElementwiseNeIntScalarModule(torch.nn.Module): def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(8, 5, low=2, high=4)) + # ============================================================================== + class ElementwiseNeFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ne(x, y) @@ -644,20 +775,25 @@ class ElementwiseNeFloatTensorModule(torch.nn.Module): def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils): module.forward( torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32), - torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32)) + torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32), + ) + # ============================================================================== + class ElementwiseNeIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ne(x, y) @@ -666,18 +802,22 @@ class ElementwiseNeIntTensorModule(torch.nn.Module): def ElementwiseNeIntTensorModule_basic(module, tu: TestUtils): module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4)) + # ============================================================================== + class ElementwiseNeFloatTensorStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3], torch.float32, True), - ([2, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3], torch.float32, True), + ([2, 3], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ne(x, y) @@ -686,20 +826,25 @@ class ElementwiseNeFloatTensorStaticModule(torch.nn.Module): def ElementwiseNeFloatTensorStaticModule_basic(module, tu: TestUtils): module.forward( torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32), - torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32)) + torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32), + ) + # ============================================================================== + class ElementwiseNeIntTensorStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([8, 5], torch.int64, True), - ([5], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([8, 5], torch.int64, True), + ([5], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ne(x, y) @@ -708,16 +853,20 @@ class ElementwiseNeIntTensorStaticModule(torch.nn.Module): def ElementwiseNeIntTensorStaticModule_basic(module, tu: TestUtils): module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4)) + # ============================================================================== + class AnyBoolTrueModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): input = [False, False, True] return torch.ops.aten.any(input) @@ -733,9 +882,11 @@ class AnyBoolFalseModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): input = [False, False, False] return torch.ops.aten.any(input) @@ -745,17 +896,20 @@ class AnyBoolFalseModule(torch.nn.Module): def AnyBoolFalseModule_basic(module, tu: TestUtils): module.forward() + # ================================================================================= -class AllBoolTrueModule(torch.nn.Module): +class AllBoolTrueModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): input = [True, True, True, True, True] return torch.ops.aten.all(input) @@ -765,36 +919,44 @@ class AllBoolTrueModule(torch.nn.Module): def AllBoolTrueModule_basic(module, tu: TestUtils): module.forward() + # ================================================================================= -class AllBoolFalseModule(torch.nn.Module): +class AllBoolFalseModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): input = [True, False, True, True, False] return torch.ops.aten.all(input) + @register_test_case(module_factory=lambda: AllBoolFalseModule()) def AllBoolFalseModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class ElementwiseIsnanModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.isnan(x) @@ -804,17 +966,21 @@ def ElementwiseIsnanModule_basic(module, tu: TestUtils): x = torch.tensor([1.0, torch.nan, torch.inf, -torch.inf]) module.forward(x) + # ============================================================================== + class ElementwiseIsinfModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.isinf(x) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py index 815f64eed..7b1699c64 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py @@ -11,92 +11,107 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== -class GridSamplerBasic1(torch.nn.Module): +class GridSamplerBasic1(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([7, 8, 12, 4], torch.float32, True), - ([7, 11, 13, 2], torch.float32, True) - ]) + @annotate_args( + [ + None, + ([7, 8, 12, 4], torch.float32, True), + ([7, 11, 13, 2], torch.float32, True), + ] + ) def forward(self, x, g): - interpolation_mode=0, - padding_mode=0, - align_corners=True, - tRes = torch.ops.aten.grid_sampler(x, g, interpolation_mode[0], - padding_mode[0], align_corners[0]) + interpolation_mode = (0,) + padding_mode = (0,) + align_corners = (True,) + tRes = torch.ops.aten.grid_sampler( + x, g, interpolation_mode[0], padding_mode[0], align_corners[0] + ) return tRes -@register_test_case( - module_factory=lambda: GridSamplerBasic1()) -def GridSamplerBasic1_basic( - module, tu: TestUtils): - inp = torch.rand(7,8,12,4) - grd = torch.rand(7,11,13,2)*2.0-1.0 + +@register_test_case(module_factory=lambda: GridSamplerBasic1()) +def GridSamplerBasic1_basic(module, tu: TestUtils): + inp = torch.rand(7, 8, 12, 4) + grd = torch.rand(7, 11, 13, 2) * 2.0 - 1.0 module.forward(inp, grd) class GridSamplerBasic2(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 4, 4], torch.float32, True), - ([1, 1, 3, 2], torch.float32, True) - ]) + @annotate_args( + [None, ([1, 1, 4, 4], torch.float32, True), ([1, 1, 3, 2], torch.float32, True)] + ) def forward(self, x, g): - interpolation_mode=0, - padding_mode=0, - align_corners=True, - tRes = torch.ops.aten.grid_sampler(x, g, interpolation_mode[0], - padding_mode[0], align_corners[0]) + interpolation_mode = (0,) + padding_mode = (0,) + align_corners = (True,) + tRes = torch.ops.aten.grid_sampler( + x, g, interpolation_mode[0], padding_mode[0], align_corners[0] + ) return tRes -@register_test_case( - module_factory=lambda: GridSamplerBasic2()) -def GridSamplerBasic2_basic( - module, tu: TestUtils): - inp = torch.tensor([[[[0.4963, 0.7682, 0.0885, 0.1320], - [0.3074, 0.6341, 0.4901, 0.8964], - [0.4556, 0.6323, 0.3489, 0.4017], - [0.0223, 0.1689, 0.2939, 0.5185]]]]).type(torch.FloatTensor) - grd = torch.tensor([[[[-0.3498, -0.8196],[-0.2127, 0.2138],[-0.6515, -0.0513]]]]).type(torch.FloatTensor) + +@register_test_case(module_factory=lambda: GridSamplerBasic2()) +def GridSamplerBasic2_basic(module, tu: TestUtils): + inp = torch.tensor( + [ + [ + [ + [0.4963, 0.7682, 0.0885, 0.1320], + [0.3074, 0.6341, 0.4901, 0.8964], + [0.4556, 0.6323, 0.3489, 0.4017], + [0.0223, 0.1689, 0.2939, 0.5185], + ] + ] + ] + ).type(torch.FloatTensor) + grd = torch.tensor( + [[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]] + ).type(torch.FloatTensor) module.forward(inp, grd) class GridSamplerBasic3(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 4, 4], torch.float32, True), - ([1, 1, 3, 2], torch.float32, True) - ]) + @annotate_args( + [None, ([1, 1, 4, 4], torch.float32, True), ([1, 1, 3, 2], torch.float32, True)] + ) def forward(self, x, g): - interpolation_mode=0, - padding_mode=0, - align_corners=False, - tRes = torch.ops.aten.grid_sampler(x, g, interpolation_mode[0], - padding_mode[0], align_corners[0]) + interpolation_mode = (0,) + padding_mode = (0,) + align_corners = (False,) + tRes = torch.ops.aten.grid_sampler( + x, g, interpolation_mode[0], padding_mode[0], align_corners[0] + ) return tRes -@register_test_case( - module_factory=lambda: GridSamplerBasic3()) -def GridSamplerBasic3_basic( - module, tu: TestUtils): - inp = torch.tensor([[[[0.4963, 0.7682, 0.0885, 0.1320], - [0.3074, 0.6341, 0.4901, 0.8964], - [0.4556, 0.6323, 0.3489, 0.4017], - [0.0223, 0.1689, 0.2939, 0.5185]]]]).type(torch.FloatTensor) - grd = torch.tensor([[[[-0.3498, -0.8196],[-0.2127, 0.2138],[-0.6515, -0.0513]]]]).type(torch.FloatTensor) - module.forward(inp, grd) +@register_test_case(module_factory=lambda: GridSamplerBasic3()) +def GridSamplerBasic3_basic(module, tu: TestUtils): + inp = torch.tensor( + [ + [ + [ + [0.4963, 0.7682, 0.0885, 0.1320], + [0.3074, 0.6341, 0.4901, 0.8964], + [0.4556, 0.6323, 0.3489, 0.4017], + [0.0223, 0.1689, 0.2939, 0.5185], + ] + ] + ] + ).type(torch.FloatTensor) + grd = torch.tensor( + [[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]] + ).type(torch.FloatTensor) + module.forward(inp, grd) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py index 9e6e2588b..54426924b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py @@ -16,6 +16,7 @@ NUM_SEGMENTS = 42 NUM_BINS = 5000 NUM_LOGITS = 5000 + class HistogramBinningCalibrationByFeature(torch.nn.Module): def __init__(self): super().__init__() @@ -45,31 +46,34 @@ class HistogramBinningCalibrationByFeature(torch.nn.Module): self._iteration = 0 @export - @annotate_args([ - None, - ([-1], torch.int32, True), - ([-1], torch.int32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int32, True), + ([-1], torch.int32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, segment_value, segment_lengths, logit): - origin_prediction = torch.sigmoid( - logit + torch.log(self.positive_weight)) + origin_prediction = torch.sigmoid(logit + torch.log(self.positive_weight)) # TODO: If in the future this test is removed from xfail for LTC, we will probably hit some device related # issues below when new tensors are created on the CPU, which is currently the default behaviour. # The solution would be to move these tensors to ensure they are on the same device as the existing ones. dense_segment_value = torch.zeros(logit.numel(), dtype=torch.int32) validoffsets = torch.gt( - segment_lengths[1:self._num_logits+1], segment_lengths[0:self._num_logits]) + segment_lengths[1 : self._num_logits + 1], + segment_lengths[0 : self._num_logits], + ) gathered_segment_values = ( - segment_value[segment_lengths[0:self._num_logits].long()]+1).int() + segment_value[segment_lengths[0 : self._num_logits].long()] + 1 + ).int() dense_segment_value = torch.where( - validoffsets, gathered_segment_values, dense_segment_value) - zeros = torch.empty_like( - dense_segment_value, dtype=torch.int32).fill_(0) + validoffsets, gathered_segment_values, dense_segment_value + ) + zeros = torch.empty_like(dense_segment_value, dtype=torch.int32).fill_(0) isnotvalid = torch.gt(dense_segment_value, self._num_segments) - dense_segment_value = torch.where( - isnotvalid, zeros, dense_segment_value) - bin_ids_data = torch.ceil(origin_prediction/self.step)-1 + dense_segment_value = torch.where(isnotvalid, zeros, dense_segment_value) + bin_ids_data = torch.ceil(origin_prediction / self.step) - 1 bin_ids_data = bin_ids_data.long() curr_segment_value = dense_segment_value * self._num_bins bin_ids_data2 = bin_ids_data @@ -78,12 +82,14 @@ class HistogramBinningCalibrationByFeature(torch.nn.Module): curr_bin_num_examples = self._bin_num_examples[bin_ids_data] curr_segment_value = curr_segment_value / curr_bin_num_examples curr_segment_value = curr_segment_value.float() - curr_segment_value = curr_segment_value * self.bin_ctr_weight_value + \ - origin_prediction * self.oneminusbin_ctr_weight_value - isvalid = torch.gt(curr_bin_num_examples, - self.bin_ctr_in_use_after) + curr_segment_value = ( + curr_segment_value * self.bin_ctr_weight_value + + origin_prediction * self.oneminusbin_ctr_weight_value + ) + isvalid = torch.gt(curr_bin_num_examples, self.bin_ctr_in_use_after) calibrated_prediction_data = torch.where( - isvalid, curr_segment_value, origin_prediction.float()) + isvalid, curr_segment_value, origin_prediction.float() + ) return calibrated_prediction_data, bin_ids_data @@ -92,11 +98,11 @@ def HBC_basic(module, tu: TestUtils): logits = tu.rand(NUM_LOGITS) segment_lengths: Tensor = tu.randint(NUM_LOGITS, high=2).to(torch.int) segment_offsets: Tensor = torch.cumsum(segment_lengths, 0) - segment_offsets: Tensor = torch.cat( - (torch.tensor([0]), segment_offsets), 0) + segment_offsets: Tensor = torch.cat((torch.tensor([0]), segment_offsets), 0) num_values: int = int(torch.sum(segment_lengths).item()) segment_values: Tensor = tu.randint(num_values, high=NUM_SEGMENTS) segment_values = torch.cat( - (segment_values, torch.zeros(NUM_LOGITS-segment_values.numel())), 0) + (segment_values, torch.zeros(NUM_LOGITS - segment_values.numel())), 0 + ) module.forward(segment_values.int(), segment_offsets.int(), logits) - #input shape (5000, 5001, 5000) + # input shape (5000, 5001, 5000) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/index_select.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/index_select.py index c25b563aa..ba0ac1922 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/index_select.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/index_select.py @@ -17,15 +17,17 @@ class IndexSelectSingleIdxModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6], torch.float32, True), - ([1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ([1], torch.int64, True), + ] + ) def forward(self, input, indices): return torch.index_select(input, 1, indices) + @register_test_case(module_factory=lambda: IndexSelectSingleIdxModule()) def IndexSelectSingleIdxModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6), torch.tensor([2])) @@ -36,33 +38,38 @@ class IndexSelectRank0IdxModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6], torch.float32, True), - ([], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ([], torch.int64, True), + ] + ) def forward(self, input, indices): return torch.index_select(input, 1, indices) + @register_test_case(module_factory=lambda: IndexSelectRank0IdxModule()) def IndexSelectRank0IdxModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6), torch.tensor(2)) + class IndexSelectNegativeDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6], torch.float32, True), - ([1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ([1], torch.int64, True), + ] + ) def forward(self, input, indices): return torch.index_select(input, -1, indices) + @register_test_case(module_factory=lambda: IndexSelectNegativeDimModule()) def IndexSelectNegativeDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6), torch.tensor([2])) @@ -73,15 +80,17 @@ class IndexSelectTwoIdxModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6], torch.float32, True), - ([2], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ([2], torch.int64, True), + ] + ) def forward(self, input, indices): return torch.index_select(input, 2, indices) + @register_test_case(module_factory=lambda: IndexSelectTwoIdxModule()) def IndexSelectTwoIdxModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6), torch.tensor([2, 4])) @@ -92,15 +101,17 @@ class IndexSelectWholeDimensionModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6], torch.float32, True), - ([4], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ([4], torch.int64, True), + ] + ) def forward(self, input, indices): return torch.index_select(input, 0, indices) + @register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule()) def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6), torch.tensor([0, 1, 2, 3])) @@ -111,15 +122,17 @@ class IndexSelectWholeTensorModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([3], torch.float32, True), - ([3], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([3], torch.float32, True), + ([3], torch.int64, True), + ] + ) def forward(self, input, indices): return torch.index_select(input, 0, indices) + @register_test_case(module_factory=lambda: IndexSelectWholeTensorModule()) def IndexSelectWholeTensorModule_basic(module, tu: TestUtils): module.forward(tu.rand(3), torch.tensor([0, 1, 2])) @@ -130,15 +143,17 @@ class IndexSelectDynamicModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1], torch.int64, True), + ] + ) def forward(self, input, indices): return torch.index_select(input, 2, indices) + @register_test_case(module_factory=lambda: IndexSelectDynamicModule()) def IndexSelectDynamicModulebasic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6), torch.tensor([0, 4])) @@ -149,15 +164,17 @@ class IndexSelectDynamicInputSizeModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([2], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([2], torch.int64, True), + ] + ) def forward(self, input, indices): return torch.index_select(input, 2, indices) + @register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule()) def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6), torch.tensor([0, 2])) @@ -168,15 +185,17 @@ class IndexSelectDynamicIndexSizeModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6], torch.float32, True), - ([-1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ([-1], torch.int64, True), + ] + ) def forward(self, input, indices): return torch.index_select(input, 1, indices) + @register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule()) def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6), torch.tensor([1, 2])) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py index 6184b587c..0093f13ce 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -11,16 +11,19 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== + class MatmulDot(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.matmul(lhs, rhs) @@ -29,18 +32,22 @@ class MatmulDot(torch.nn.Module): def Matmul_dot(module, tu: TestUtils): module.forward(tu.rand(3), tu.rand(3)) + # ============================================================================== + class Matmul2D(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.matmul(lhs, rhs) @@ -49,18 +56,22 @@ class Matmul2D(torch.nn.Module): def Matmul_2d(module, tu: TestUtils): module.forward(tu.rand(3, 4), tu.rand(4, 5)) + # ============================================================================== + class MatmulVecMat(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.matmul(lhs, rhs) @@ -69,18 +80,22 @@ class MatmulVecMat(torch.nn.Module): def Matmul_vecmat(module, tu: TestUtils): module.forward(tu.rand(4), tu.rand(4, 5)) + # ============================================================================== + class MatmulMatVec(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.matmul(lhs, rhs) @@ -89,18 +104,22 @@ class MatmulMatVec(torch.nn.Module): def Matmul_matvec(module, tu: TestUtils): module.forward(tu.rand(4, 5), tu.rand(5)) + # ============================================================================== + class Matmul3D(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.matmul(lhs, rhs) @@ -109,18 +128,22 @@ class Matmul3D(torch.nn.Module): def Matmul_3d(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4)) + # ============================================================================== + class Matmul4d(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.matmul(lhs, rhs) @@ -129,18 +152,22 @@ class Matmul4d(torch.nn.Module): def Matmul_4d(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6)) + # ============================================================================== + class Matmul4dStatic(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6, 7], torch.float32, True), - ([4, 5, 7, 6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4, 5, 6, 7], torch.float32, True), + ([4, 5, 7, 6], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.matmul(lhs, rhs) @@ -149,18 +176,22 @@ class Matmul4dStatic(torch.nn.Module): def Matmul4dStatic_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6)) + # ============================================================================== + class MatmulStaticBroadcast(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 1, 6, 7], torch.float32, True), - ([8, 1, 5, 7, 6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4, 1, 6, 7], torch.float32, True), + ([8, 1, 5, 7, 6], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.matmul(lhs, rhs) @@ -169,18 +200,22 @@ class MatmulStaticBroadcast(torch.nn.Module): def MatmulStaticBroadcast_basic(module, tu: TestUtils): module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6)) + # ============================================================================== + class MatmulSingleDynamicBatchDim(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, -1, -1, -1], torch.float32, True), - ([4, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4, -1, -1, -1], torch.float32, True), + ([4, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.matmul(lhs, rhs) @@ -189,18 +224,22 @@ class MatmulSingleDynamicBatchDim(torch.nn.Module): def MatmulSingleDynamicBatchDim_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6)) + # ============================================================================== + class MatmulBroadcastBatchDim(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, -1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4, -1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, lhs, rhs): return torch.matmul(lhs, rhs) @@ -209,16 +248,19 @@ class MatmulBroadcastBatchDim(torch.nn.Module): def MatmulBroadcastBatchDim_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6)) + # ============================================================================== -class Mv(torch.nn.Module): +class Mv(torch.nn.Module): @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, m, v): return torch.mv(m, v) @@ -227,16 +269,19 @@ class Mv(torch.nn.Module): def Mv_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2), tu.rand(2)) + # ============================================================================== -class AtenMmFloatTypes(torch.nn.Module): +class AtenMmFloatTypes(torch.nn.Module): @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.ops.aten.mm(a, b) @@ -245,16 +290,19 @@ class AtenMmFloatTypes(torch.nn.Module): def AtenMmFloatTypes_basic(module, tu: TestUtils): module.forward(tu.rand(8, 8), tu.rand(8, 8)) + # ============================================================================== -class AtenMmIntTypes(torch.nn.Module): +class AtenMmIntTypes(torch.nn.Module): @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a, b): return torch.ops.aten.mm(a, b) @@ -266,246 +314,299 @@ def AtenMmIntTypes_basic(module, tu: TestUtils): # ============================================================================== -class AtenMmQint8(torch.nn.Module): +class AtenMmQint8(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int8, True), - ([4, 3], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int8, True), + ([4, 3], torch.int8, True), + ] + ) def forward(self, x, y): qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) qx = torch.dequantize(qx) qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) + qz = torch.mm(qx, qy) return qz + @register_test_case(module_factory=lambda: AtenMmQint8()) def AtenMmQint8_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8), - tu.randint(4, 3, low=-128, high=127).to(torch.int8)) - + module.forward( + tu.randint(3, 4, low=-128, high=127).to(torch.int8), + tu.randint(4, 3, low=-128, high=127).to(torch.int8), + ) + + # ============================================================================== -class AtenMmQuint8(torch.nn.Module): +class AtenMmQuint8(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.uint8, True), - ([4, 3], torch.uint8, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.uint8, True), + ([4, 3], torch.uint8, True), + ] + ) def forward(self, x, y): qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65) qx = torch.dequantize(qx) qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160) qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) + qz = torch.mm(qx, qy) return qz + @register_test_case(module_factory=lambda: AtenMmQuint8()) def AtenMmQuint8_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=0, high=255).to(torch.uint8), - tu.randint(4, 3, low=0, high=255).to(torch.uint8)) - + module.forward( + tu.randint(3, 4, low=0, high=255).to(torch.uint8), + tu.randint(4, 3, low=0, high=255).to(torch.uint8), + ) + + # ============================================================================== -class AtenMmQMixedSigni8(torch.nn.Module): +class AtenMmQMixedSigni8(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.int8, True), - ([4, 3], torch.uint8, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.int8, True), + ([4, 3], torch.uint8, True), + ] + ) def forward(self, x, y): qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) qx = torch.dequantize(qx) qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) + qz = torch.mm(qx, qy) return qz + @register_test_case(module_factory=lambda: AtenMmQMixedSigni8()) def AtenMmQMixedSigni8_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8), - tu.randint(4, 3, low=0, high=255).to(torch.uint8)) - + module.forward( + tu.randint(3, 4, low=-128, high=127).to(torch.int8), + tu.randint(4, 3, low=0, high=255).to(torch.uint8), + ) + + # ============================================================================== -class AtenMatmulQint8VM(torch.nn.Module): +class AtenMatmulQint8VM(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int8, True), - ([-1,-1], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int8, True), + ([-1, -1], torch.int8, True), + ] + ) def forward(self, x, y): qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) qx = torch.dequantize(qx) qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) + qz = torch.matmul(qx, qy) return qz + @register_test_case(module_factory=lambda: AtenMatmulQint8VM()) def AtenMatmulQint8VM_basic(module, tu: TestUtils): - module.forward(tu.randint(9, low=-128, high=127).to(torch.int8), - tu.randint(9, 4, low=-128, high=127).to(torch.int8)) - + module.forward( + tu.randint(9, low=-128, high=127).to(torch.int8), + tu.randint(9, 4, low=-128, high=127).to(torch.int8), + ) + + # ============================================================================== class AtenMatmulQint8VV(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int8, True), - ([-1], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int8, True), + ([-1], torch.int8, True), + ] + ) def forward(self, x, y): qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) qx = torch.dequantize(qx) qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) + qz = torch.matmul(qx, qy) return qz + @register_test_case(module_factory=lambda: AtenMatmulQint8VV()) def AtenMatmulQint8VV_basic(module, tu: TestUtils): - module.forward(tu.randint(9, low=-128, high=127).to(torch.int8), - tu.randint(9, low=-128, high=127).to(torch.int8)) - + module.forward( + tu.randint(9, low=-128, high=127).to(torch.int8), + tu.randint(9, low=-128, high=127).to(torch.int8), + ) + + # ============================================================================== class AtenMatmulQint8MV(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int8, True), - ([-1], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int8, True), + ([-1], torch.int8, True), + ] + ) def forward(self, x, y): qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) qx = torch.dequantize(qx) qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) + qz = torch.matmul(qx, qy) return qz + @register_test_case(module_factory=lambda: AtenMatmulQint8MV()) def AtenMatmulQint8MV_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8), - tu.randint(4, low=-128, high=127).to(torch.int8)) - + module.forward( + tu.randint(3, 4, low=-128, high=127).to(torch.int8), + tu.randint(4, low=-128, high=127).to(torch.int8), + ) + + # ============================================================================== class AtenMatmulQint8(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, -1, 3, 4], torch.int8, True), - ([-1, 4, 3], torch.int8, True), - ]) + @annotate_args( + [ + None, + ([4, -1, 3, 4], torch.int8, True), + ([-1, 4, 3], torch.int8, True), + ] + ) def forward(self, x, y): qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) qx = torch.dequantize(qx) qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) + qz = torch.matmul(qx, qy) return qz + @register_test_case(module_factory=lambda: AtenMatmulQint8()) def AtenMatmulQint8_basic(module, tu: TestUtils): - module.forward(tu.randint(4, 7, 3, 4, low=-128, high=127).to(torch.int8), - tu.randint(7, 4, 3, low=-128, high=127).to(torch.int8)) - + module.forward( + tu.randint(4, 7, 3, 4, low=-128, high=127).to(torch.int8), + tu.randint(7, 4, 3, low=-128, high=127).to(torch.int8), + ) + + # ============================================================================== -class AtenMatmulQMixedSigni8(torch.nn.Module): +class AtenMatmulQMixedSigni8(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([7, -1, -1, -1], torch.int8, True), - ([-1, -1, -1], torch.uint8, True), - ]) + @annotate_args( + [ + None, + ([7, -1, -1, -1], torch.int8, True), + ([-1, -1, -1], torch.uint8, True), + ] + ) def forward(self, x, y): qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) qx = torch.dequantize(qx) qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) + qz = torch.matmul(qx, qy) return qz + @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8()) def AtenMatmulQMixedSigni8_basic(module, tu: TestUtils): - module.forward(tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8), - tu.randint(2, 4, 3, low=0, high=255).to(torch.uint8)) - + module.forward( + tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8), + tu.randint(2, 4, 3, low=0, high=255).to(torch.uint8), + ) + + # ============================================================================== -class AtenMatmulQMixedSigni8Transpose(torch.nn.Module): +class AtenMatmulQMixedSigni8Transpose(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([7, -1, -1, -1], torch.int8, True), - ([-1, -1, -1], torch.uint8, True), - ]) + @annotate_args( + [ + None, + ([7, -1, -1, -1], torch.int8, True), + ([-1, -1, -1], torch.uint8, True), + ] + ) def forward(self, x, y): qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) qx = torch.dequantize(qx) qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) qy = torch.dequantize(qy) qy = torch.transpose(qy, 1, 2) - qz = torch.matmul(qx, qy) + qz = torch.matmul(qx, qy) return qz + @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose()) def AtenMatmulQMixedSigni8Transpose_basic(module, tu: TestUtils): - module.forward(tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8), - tu.randint(2, 6, 4, low=0, high=255).to(torch.uint8)) - + module.forward( + tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8), + tu.randint(2, 6, 4, low=0, high=255).to(torch.uint8), + ) + + # ============================================================================== -class AtenLinalgCrossInt(torch.nn.Module): +class AtenLinalgCrossInt(torch.nn.Module): @export - @annotate_args([ - None, - ([2, 3], torch.int64, True), - ([2, 3], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([2, 3], torch.int64, True), + ([2, 3], torch.int64, True), + ] + ) def forward(self, a, b): return torch.ops.aten.linalg_cross(a, b) @@ -514,16 +615,19 @@ class AtenLinalgCrossInt(torch.nn.Module): def AtenLinalgCrossInt_basic(module, tu: TestUtils): module.forward(tu.randint(2, 3), tu.randint(2, 3)) + # ============================================================================== -class AtenLinalgCrossFloat(torch.nn.Module): +class AtenLinalgCrossFloat(torch.nn.Module): @export - @annotate_args([ - None, - ([2, 3], torch.float32, True), - ([2, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3], torch.float32, True), + ([2, 3], torch.float32, True), + ] + ) def forward(self, a, b): return torch.ops.aten.linalg_cross(a, b) @@ -535,14 +639,16 @@ def AtenLinalgCrossFloat_basic(module, tu: TestUtils): # ============================================================================== -class AtenLinalgCrossBroadcast(torch.nn.Module): +class AtenLinalgCrossBroadcast(torch.nn.Module): @export - @annotate_args([ - None, - ([1, 4, 3], torch.float32, True), - ([5, 4, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 4, 3], torch.float32, True), + ([5, 4, 3], torch.float32, True), + ] + ) def forward(self, a, b): return torch.ops.aten.linalg_cross(a, b) @@ -551,16 +657,19 @@ class AtenLinalgCrossBroadcast(torch.nn.Module): def AtenLinalgCrossBroadcast_basic(module, tu: TestUtils): module.forward(tu.rand(1, 4, 3), tu.rand(5, 4, 3)) + # ============================================================================== -class AtenLinalgCrossCustomDim(torch.nn.Module): +class AtenLinalgCrossCustomDim(torch.nn.Module): @export - @annotate_args([ - None, - ([1, 4, 3, 2, 2], torch.float32, True), - ([5, 4, 3, 2, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 4, 3, 2, 2], torch.float32, True), + ([5, 4, 3, 2, 1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.ops.aten.linalg_cross(a, b, dim=2) @@ -569,16 +678,19 @@ class AtenLinalgCrossCustomDim(torch.nn.Module): def AtenLinalgCrossCustomDim_basic(module, tu: TestUtils): module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1)) + # ============================================================================== -class AtenLinalgCrossNegativeDim(torch.nn.Module): +class AtenLinalgCrossNegativeDim(torch.nn.Module): @export - @annotate_args([ - None, - ([1, 4, 3, 2, 2], torch.float32, True), - ([5, 4, 3, 2, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 4, 3, 2, 2], torch.float32, True), + ([5, 4, 3, 2, 1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.ops.aten.linalg_cross(a, b, dim=-3) @@ -587,22 +699,26 @@ class AtenLinalgCrossNegativeDim(torch.nn.Module): def AtenLinalgCrossNegativeDim_basic(module, tu: TestUtils): module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1)) + # ============================================================================== + class AtenLinalgCrossDynamic(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.ops.aten.linalg_cross(a, b, dim=1) @register_test_case(module_factory=lambda: AtenLinalgCrossDynamic()) def AtenLinalgCrossDynamic_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1)) \ No newline at end of file + module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py index faebcadf3..96019742c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py @@ -14,6 +14,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # Multi-layer perceptron (MLP) models. + class Mlp1LayerModule(torch.nn.Module): def __init__(self): super().__init__() @@ -21,18 +22,23 @@ class Mlp1LayerModule(torch.nn.Module): torch.manual_seed(0) self.fc0 = nn.Linear(3, 5) self.tanh0 = nn.Tanh() + @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.tanh0(self.fc0(x)) + @register_test_case(module_factory=lambda: Mlp1LayerModule()) def Mlp1LayerModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 3)) + class Mlp2LayerModule(torch.nn.Module): def __init__(self): super().__init__() @@ -43,20 +49,25 @@ class Mlp2LayerModule(torch.nn.Module): self.tanh0 = nn.Tanh() self.fc1 = nn.Linear(N_HIDDEN, 2) self.tanh1 = nn.Tanh() + @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): x = self.tanh0(self.fc0(x)) x = self.tanh1(self.fc1(x)) return x + @register_test_case(module_factory=lambda: Mlp2LayerModule()) def Mlp2LayerModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 3)) + class Mlp2LayerModuleNoBias(torch.nn.Module): def __init__(self): super().__init__() @@ -67,20 +78,25 @@ class Mlp2LayerModuleNoBias(torch.nn.Module): self.tanh0 = nn.Tanh() self.fc1 = nn.Linear(N_HIDDEN, 2, bias=False) self.tanh1 = nn.Tanh() + @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): x = self.tanh0(self.fc0(x)) x = self.tanh1(self.fc1(x)) return x + @register_test_case(module_factory=lambda: Mlp2LayerModuleNoBias()) def Mlp2LayerModuleNoBias_basic(module, tu: TestUtils): module.forward(tu.rand(5, 3)) + class BatchMlpLayerModule(torch.nn.Module): def __init__(self): super().__init__() @@ -88,14 +104,18 @@ class BatchMlpLayerModule(torch.nn.Module): torch.manual_seed(0) self.fc0 = nn.Linear(3, 5) self.tanh0 = nn.Tanh() + @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.tanh0(self.fc0(x)) + @register_test_case(module_factory=lambda: BatchMlpLayerModule()) def BatchMlpLayerModule_basic(module, tu: TestUtils): module.forward(tu.rand(7, 5, 3)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py index 0cbe1c5fd..675d04249 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py @@ -13,506 +13,561 @@ from torch_mlir_e2e_test.annotations import annotate_args, export class NllLossModule(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ]) - # Here the 2nd index is ignored. - def forward(self, x, y): - return torch.ops.aten.nll_loss_forward(x, - target=y, - weight=None, - reduction=0, - ignore_index=2) + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ] + ) + # Here the 2nd index is ignored. + def forward(self, x, y): + return torch.ops.aten.nll_loss_forward( + x, target=y, weight=None, reduction=0, ignore_index=2 + ) @register_test_case(module_factory=lambda: NllLossModule()) def NllLossModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) + module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) class NllLossModule_mean(torch.nn.Module): - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ]) - # Here the 2nd index is ignored. - def forward(self, x, y): - return torch.ops.aten.nll_loss_forward(x, - target=y, - weight=None, - reduction=1, - ignore_index=2) + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ] + ) + # Here the 2nd index is ignored. + def forward(self, x, y): + return torch.ops.aten.nll_loss_forward( + x, target=y, weight=None, reduction=1, ignore_index=2 + ) @register_test_case(module_factory=lambda: NllLossModule_mean()) def NllLossModule_mean_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) + module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) class NllLossModule_sum(torch.nn.Module): - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ]) - # Here the 2nd index is ignored. - def forward(self, x, y): - return torch.ops.aten.nll_loss_forward(x, - target=y, - weight=None, - reduction=2, - ignore_index=2) + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ] + ) + # Here the 2nd index is ignored. + def forward(self, x, y): + return torch.ops.aten.nll_loss_forward( + x, target=y, weight=None, reduction=2, ignore_index=2 + ) @register_test_case(module_factory=lambda: NllLossModule_sum()) def NllLossModule_sum_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) + module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) class NllLossModule_1D(torch.nn.Module): - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([], torch.int64, True), - ]) - # Here the 2nd index is ignored. - def forward(self, x, y): - return torch.ops.aten.nll_loss_forward(x, - target=y, - weight=None, - reduction=0, - ignore_index=2) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([], torch.int64, True), + ] + ) + # Here the 2nd index is ignored. + def forward(self, x, y): + return torch.ops.aten.nll_loss_forward( + x, target=y, weight=None, reduction=0, ignore_index=2 + ) @register_test_case(module_factory=lambda: NllLossModule_1D()) def NllLossModule_1D_basic(module, tu: TestUtils): - module.forward(tu.rand(3), tu.randint(high=3)) + module.forward(tu.rand(3), tu.randint(high=3)) class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ]) - # None of the index is ignored here, since the ignored index is out of bounds. - def forward(self, x, y): - return torch.ops.aten.nll_loss_forward(x, - target=y, - weight=None, - reduction=0, - ignore_index=10) + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ] + ) + # None of the index is ignored here, since the ignored index is out of bounds. + def forward(self, x, y): + return torch.ops.aten.nll_loss_forward( + x, target=y, weight=None, reduction=0, ignore_index=10 + ) @register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds()) def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) + module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) + class NllLossModule_backward(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=None, - reduction=0, - ignore_index=10, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=None, + reduction=0, + ignore_index=10, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backward()) def NllLossModuleBackward_basic(module, tu: TestUtils): - module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]), - torch.tensor(3.)) + module.forward( + tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0) + ) class NllLossModule_backwardWeight(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, weight, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=weight, - reduction=0, - ignore_index=10, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, weight, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=weight, + reduction=0, + ignore_index=10, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backwardWeight()) def NllLossModuleBackwardWeight_basic(module, tu: TestUtils): - module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]), - tu.rand(4), torch.tensor(3.)) - + module.forward( + tu.rand(3), + tu.rand(3, 4), + torch.tensor([2, 3, 0]), + tu.rand(4), + torch.tensor(3.0), + ) class NllLossModule_backward_ignore_index(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=None, - reduction=0, - ignore_index=1, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=None, + reduction=0, + ignore_index=1, + total_weight=total_weight, + ) -@register_test_case( - module_factory=lambda: NllLossModule_backward_ignore_index()) +@register_test_case(module_factory=lambda: NllLossModule_backward_ignore_index()) def NllLossModuleBackward_ignore_index(module, tu: TestUtils): - module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]), - torch.tensor(3.)) + module.forward( + tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0) + ) class NllLossModule_backwardMean(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=None, - reduction=1, - ignore_index=1, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=None, + reduction=1, + ignore_index=1, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backwardMean()) def NllLossModuleBackwardMean_basic(module, tu: TestUtils): - module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), - torch.tensor(3.)) + module.forward( + tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0) + ) class NllLossModule_backwardMeanWeight(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, weight, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=weight, - reduction=1, - ignore_index=1, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, weight, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=weight, + reduction=1, + ignore_index=1, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backwardMeanWeight()) def NllLossModuleBackwardMeanWeight_basic(module, tu: TestUtils): - module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), - tu.rand(4), torch.tensor(3.)) + module.forward( + tu.rand(1), + tu.rand(3, 4), + torch.tensor([2, 3, 0]), + tu.rand(4), + torch.tensor(3.0), + ) class NllLossModule_backwardSum(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=None, - reduction=2, - ignore_index=1, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=None, + reduction=2, + ignore_index=1, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backwardSum()) def NllLossModuleBackwardSum_basic(module, tu: TestUtils): - module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), - torch.tensor(3.)) + module.forward( + tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0) + ) class NllLossModule_backwardSumWeight(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, weight, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=weight, - reduction=2, - ignore_index=1, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, weight, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=weight, + reduction=2, + ignore_index=1, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backwardSumWeight()) def NllLossModuleBackwardSumWeight_basic(module, tu: TestUtils): - module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), - tu.rand(4), torch.tensor(3.)) + module.forward( + tu.rand(1), + tu.rand(3, 4), + torch.tensor([2, 3, 0]), + tu.rand(4), + torch.tensor(3.0), + ) class NllLossModule_backward1D(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=None, - reduction=0, - ignore_index=10, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=None, + reduction=0, + ignore_index=10, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backward1D()) def NllLossModuleBackward1D_basic(module, tu: TestUtils): - module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), - torch.tensor(3.)) + module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), torch.tensor(3.0)) class NllLossModule_backward1DWeight(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, weight, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=weight, - reduction=0, - ignore_index=10, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, weight, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=weight, + reduction=0, + ignore_index=10, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backward1DWeight()) def NllLossModuleBackward1DWeight_basic(module, tu: TestUtils): - module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), - tu.rand(3), torch.tensor(3.)) + module.forward( + tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0) + ) class NllLossModule_backward1DMean(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=None, - reduction=1, - ignore_index=1, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=None, + reduction=1, + ignore_index=1, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backward1DMean()) def NllLossModuleBackward1DMean_basic(module, tu: TestUtils): - module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), - torch.tensor(3.)) + module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), torch.tensor(3.0)) class NllLossModule_backward1DMeanWeight(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, weight, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=weight, - reduction=1, - ignore_index=1, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, weight, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=weight, + reduction=1, + ignore_index=1, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight()) def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils): - module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), - tu.rand(3), torch.tensor(3.)) + module.forward( + tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0) + ) class NllLossModule_backward1DSum(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=None, - reduction=2, - ignore_index=1, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=None, + reduction=2, + ignore_index=1, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backward1DSum()) def NllLossModuleBackward1DSum_basic(module, tu: TestUtils): - module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), - torch.tensor(3.)) + module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), torch.tensor(3.0)) class NllLossModule_backward1DSumWeight(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ([], torch.float32, True), - ]) - def forward(self, grad_output, input, target, weight, total_weight): - return torch.ops.aten.nll_loss_backward(grad_output, - input, - target=target, - weight=weight, - reduction=2, - ignore_index=1, - total_weight=total_weight) + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ([], torch.float32, True), + ] + ) + def forward(self, grad_output, input, target, weight, total_weight): + return torch.ops.aten.nll_loss_backward( + grad_output, + input, + target=target, + weight=weight, + reduction=2, + ignore_index=1, + total_weight=total_weight, + ) @register_test_case(module_factory=lambda: NllLossModule_backward1DSumWeight()) def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils): - module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), - tu.rand(3), torch.tensor(3.)) + module.forward( + tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0) + ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py index 56821fb69..f4c9e39d1 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py @@ -11,6 +11,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== + class BatchNorm1DModule(torch.nn.Module): def __init__(self): super().__init__() @@ -18,15 +19,16 @@ class BatchNorm1DModule(torch.nn.Module): self.bn1d.eval() self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6]) self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0]) - self.bn1d.weight = torch.nn.Parameter( - torch.tensor([3.0, 2.0, 4.0, 5.0])) + self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0])) self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6])) @export - @annotate_args([ - None, - ([10, 4, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([10, 4, 3], torch.float32, True), + ] + ) def forward(self, x): return self.bn1d(x) @@ -35,8 +37,10 @@ class BatchNorm1DModule(torch.nn.Module): def BatchNorm1DModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 4, 3)) + # ============================================================================== + class BatchNorm1DWith2DInputModule(torch.nn.Module): def __init__(self): super().__init__() @@ -44,15 +48,16 @@ class BatchNorm1DWith2DInputModule(torch.nn.Module): self.bn1d.eval() self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6]) self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0]) - self.bn1d.weight = torch.nn.Parameter( - torch.tensor([3.0, 2.0, 4.0, 5.0])) + self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0])) self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6])) @export - @annotate_args([ - None, - ([10, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([10, 4], torch.float32, True), + ] + ) def forward(self, x): return self.bn1d(x) @@ -61,8 +66,10 @@ class BatchNorm1DWith2DInputModule(torch.nn.Module): def BatchNorm1DWith2DInputModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 4)) + # ============================================================================== + class BatchNorm2DModule(torch.nn.Module): def __init__(self): super().__init__() @@ -74,10 +81,12 @@ class BatchNorm2DModule(torch.nn.Module): self.bn2d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4])) @export - @annotate_args([ - None, - ([10, 2, 3, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([10, 2, 3, 3], torch.float32, True), + ] + ) def forward(self, x): return self.bn2d(x) @@ -86,8 +95,10 @@ class BatchNorm2DModule(torch.nn.Module): def BatchNorm2DModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 2, 3, 3)) + # ============================================================================== + class BatchNorm3DModule(torch.nn.Module): def __init__(self): super().__init__() @@ -95,16 +106,16 @@ class BatchNorm3DModule(torch.nn.Module): self.bn3d.eval() self.bn3d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]) self.bn3d.running_var = torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]) - self.bn3d.weight = torch.nn.Parameter( - torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])) - self.bn3d.bias = torch.nn.Parameter( - torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])) + self.bn3d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])) + self.bn3d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])) @export - @annotate_args([ - None, - ([2, 5, 3, 6, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 5, 3, 6, 4], torch.float32, True), + ] + ) def forward(self, x): return self.bn3d(x) @@ -113,274 +124,361 @@ class BatchNorm3DModule(torch.nn.Module): def BatchNorm3DModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 5, 3, 6, 4)) + # ============================================================================== + class BatchNorm1DStaticShapeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 5], torch.float32, True), - ([5], torch.float32, True), - ([5], torch.float32, True), - ([5], torch.float32, True), - ([5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 5], torch.float32, True), + ([5], torch.float32, True), + ([5], torch.float32, True), + ([5], torch.float32, True), + ([5], torch.float32, True), + ] + ) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.batch_norm( - x, weight, bias, running_mean, running_var, training=False, - momentum=0.1, eps=0.00001, cudnn_enabled=False) + x, + weight, + bias, + running_mean, + running_var, + training=False, + momentum=0.1, + eps=0.00001, + cudnn_enabled=False, + ) @register_test_case(module_factory=lambda: BatchNorm1DStaticShapeModule()) def BatchNorm1DStaticShapeModule_basic(module, tu: TestUtils): - module.forward( - tu.rand(2, 5), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5)) + module.forward(tu.rand(2, 5), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5)) + # ============================================================================== + class NativeBatchNorm1DModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, weight, bias, running_mean, running_var, training=False, - momentum=0.1, eps=0.00001) + x, + weight, + bias, + running_mean, + running_var, + training=False, + momentum=0.1, + eps=0.00001, + ) @register_test_case(module_factory=lambda: NativeBatchNorm1DModule()) def NativeBatchNorm1DModule_basic(module, tu: TestUtils): - module.forward( - tu.rand(2, 5, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5)) + module.forward(tu.rand(2, 5, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5)) + # ============================================================================== + class NativeBatchNorm2DModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, weight, bias, running_mean, running_var, training=False, - momentum=0.1, eps=0.00001) + x, + weight, + bias, + running_mean, + running_var, + training=False, + momentum=0.1, + eps=0.00001, + ) @register_test_case(module_factory=lambda: NativeBatchNorm2DModule()) def NativeBatchNorm2DModule_basic(module, tu: TestUtils): - module.forward( - tu.rand(2, 5, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5)) + module.forward(tu.rand(2, 5, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5)) + # ============================================================================== + class NativeBatchNorm3DModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, weight, bias, running_mean, running_var, training=False, - momentum=0.1, eps=0.00001) + x, + weight, + bias, + running_mean, + running_var, + training=False, + momentum=0.1, + eps=0.00001, + ) @register_test_case(module_factory=lambda: NativeBatchNorm3DModule()) def NativeBatchNorm3DModule_basic(module, tu: TestUtils): module.forward( - tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5)) + tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5) + ) + # ============================================================================== + class NativeBatchNormNoneWeightModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, x, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, None, bias, running_mean, running_var, training=False, - momentum=0.1, eps=0.00001) + x, + None, + bias, + running_mean, + running_var, + training=False, + momentum=0.1, + eps=0.00001, + ) @register_test_case(module_factory=lambda: NativeBatchNormNoneWeightModule()) def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5)) + # ============================================================================== + class GroupNormModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 4, 6, 7], torch.float32, True), - ([4], torch.float32, True), - ([4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 4, 6, 7], torch.float32, True), + ([4], torch.float32, True), + ([4], torch.float32, True), + ] + ) def forward(self, x, weight, bias): - return torch.ops.aten.group_norm(x, 2, weight, bias, 1.0000000000000001e-05, False) + return torch.ops.aten.group_norm( + x, 2, weight, bias, 1.0000000000000001e-05, False + ) + @register_test_case(module_factory=lambda: GroupNormModule()) def GroupNormModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 6, 7), tu.rand(4), tu.rand(4)) + class GroupNormNoWeightAndBiasModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 4, 6, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 4, 6, 7], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.group_norm(x, 2, None, None, 1.0000000000000001e-05, False) + return torch.ops.aten.group_norm( + x, 2, None, None, 1.0000000000000001e-05, False + ) + @register_test_case(module_factory=lambda: GroupNormNoWeightAndBiasModule()) def GroupNormNoWeightAndBiasModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 6, 7)) + # ============================================================================== + class NativeGroupNormModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 6, 2, 2], torch.float32, True), - ([6], torch.float32, True), - ([6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 6, 2, 2], torch.float32, True), + ([6], torch.float32, True), + ([6], torch.float32, True), + ] + ) def forward(self, x, weight, bias): - return torch.ops.aten.native_group_norm( - x, weight, bias, - 2, 6, 4, 3, 0.000001) + return torch.ops.aten.native_group_norm(x, weight, bias, 2, 6, 4, 3, 0.000001) @register_test_case(module_factory=lambda: NativeGroupNormModule()) def NativeGroupNormModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 6, 2, 2), tu.rand(6), tu.rand(6)) + # ============================================================================== + class NativeGroupNormBackwardModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 6, 2, 2], torch.float32, True), - ([2, 6, 2, 2], torch.float32, True), - ([2, 3], torch.float32, True), - ([2, 3], torch.float32, True), - ([6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 6, 2, 2], torch.float32, True), + ([2, 6, 2, 2], torch.float32, True), + ([2, 3], torch.float32, True), + ([2, 3], torch.float32, True), + ([6], torch.float32, True), + ] + ) def forward(self, grad_out, x, mean, rstd, weight): return torch.ops.aten.native_group_norm_backward( - grad_out, x, mean, rstd, weight, - 2, 6, 4, 3, [True, True, True]) + grad_out, x, mean, rstd, weight, 2, 6, 4, 3, [True, True, True] + ) @register_test_case(module_factory=lambda: NativeGroupNormBackwardModule()) def NativeGroupNormBackwardModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 6, 2, 2), tu.rand(2, 6, 2, 2), tu.rand(2, 3), - tu.rand(2, 3), tu.rand(6)) + module.forward( + tu.rand(2, 6, 2, 2), + tu.rand(2, 6, 2, 2), + tu.rand(2, 3), + tu.rand(2, 3), + tu.rand(6), + ) + # ============================================================================== + class NativeLayerNormModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 5, 2, 2, 3], torch.float32, True), - ([2, 2, 3], torch.float32, True), - ([2, 2, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 5, 2, 2, 3], torch.float32, True), + ([2, 2, 3], torch.float32, True), + ([2, 2, 3], torch.float32, True), + ] + ) def forward(self, x, weight, bias): list = [2, 2, 3] - return torch.ops.aten.native_layer_norm( - x, list, weight, bias, eps=0.5) + return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5) @register_test_case(module_factory=lambda: NativeLayerNormModule()) def NativeLayerNormModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3)) + class NativeLayerNormDynamicModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, weight, bias): list = [2, 2, 3] - return torch.ops.aten.native_layer_norm( - x, list, weight, bias, eps=0.5) + return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5) @register_test_case(module_factory=lambda: NativeLayerNormDynamicModule()) def NativeLayerNormDynamicModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3)) + # ============================================================================== + class NormalizeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 3], torch.float32, True), + ] + ) def forward(self, x): return torch.nn.functional.normalize(x) @@ -389,48 +487,59 @@ class NormalizeModule(torch.nn.Module): def NormalizeModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3)) + # ============================================================================== + class NativeLayerNormModule4D(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, 2, 2, 3], torch.float32, True), - ([2, 2, 3], torch.float32, True), - ([2, 2, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 2, 2, 3], torch.float32, True), + ([2, 2, 3], torch.float32, True), + ([2, 2, 3], torch.float32, True), + ] + ) def forward(self, x, weight, bias): list = [2, 2, 3] - return torch.ops.aten.native_layer_norm( - x, list, weight, bias, eps=0.5)[0] + return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)[0] @register_test_case(module_factory=lambda: NativeLayerNormModule4D()) def NativeLayerNormModule4D_basic(module, tu: TestUtils): module.forward(tu.rand(5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3)) + # ============================================================================== + class LayerNormModule(torch.nn.Module): def __init__(self): super().__init__() self.ly = torch.nn.LayerNorm([2, 2, 3]) self.ly.eval() self.ly.weight = torch.nn.Parameter( - torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]], - [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]])) + torch.tensor( + [[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]], [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]] + ) + ) self.ly.bias = torch.nn.Parameter( - torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]], - [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]])) + torch.tensor( + [[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]], [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]] + ) + ) @export - @annotate_args([ - None, - ([2, 5, 2, 2, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 5, 2, 2, 3], torch.float32, True), + ] + ) def forward(self, x): return self.ly(x) @@ -439,8 +548,10 @@ class LayerNormModule(torch.nn.Module): def LayerNormModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 5, 2, 2, 3)) + # ============================================================================== + class LayerNormLastDimModule(torch.nn.Module): def __init__(self): super().__init__() @@ -450,10 +561,12 @@ class LayerNormLastDimModule(torch.nn.Module): self.ly.bias = torch.nn.Parameter(torch.tensor([0.2, 0.4, 0.3])) @export - @annotate_args([ - None, - ([2, 5, 2, 2, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 5, 2, 2, 3], torch.float32, True), + ] + ) def forward(self, x): return self.ly(x) @@ -462,25 +575,33 @@ class LayerNormLastDimModule(torch.nn.Module): def LayerNormLastDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 5, 2, 2, 3)) + # ============================================================================== + class LayerNormNormalizeOverAllDimsModule(torch.nn.Module): def __init__(self): super().__init__() self.ly = torch.nn.LayerNorm([2, 2, 3]) self.ly.eval() self.ly.weight = torch.nn.Parameter( - torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]], - [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]])) + torch.tensor( + [[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]], [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]] + ) + ) self.ly.bias = torch.nn.Parameter( - torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]], - [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]])) + torch.tensor( + [[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]], [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]] + ) + ) @export - @annotate_args([ - None, - ([2, 2, 3], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 2, 3], torch.float32, True), + ] + ) def forward(self, x): return self.ly(x) @@ -489,20 +610,25 @@ class LayerNormNormalizeOverAllDimsModule(torch.nn.Module): def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2, 3)) + class AtenInstanceNormModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 2, 1, 3], torch.float32, True), - ([2], torch.float32, True), - ([2], torch.float32, True) - ]) + @annotate_args( + [ + None, + ([1, 2, 1, 3], torch.float32, True), + ([2], torch.float32, True), + ([2], torch.float32, True), + ] + ) def forward(self, x, w, b): - return torch.ops.aten.instance_norm(x, w, b, None, - None, True, 0.0, 1e-05, False) + return torch.ops.aten.instance_norm( + x, w, b, None, None, True, 0.0, 1e-05, False + ) + @register_test_case(module_factory=lambda: AtenInstanceNormModule()) def AtenInstanceNormModule_basic(module, tu: TestUtils): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py index 59961fedc..a97d7f09e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py @@ -12,98 +12,112 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== -class ReflectionPad2dModule(torch.nn.Module): +class ReflectionPad2dModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 20, 20], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 20, 20], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.reflection_pad2d(x, (10,10,10,10)) + return torch.ops.aten.reflection_pad2d(x, (10, 10, 10, 10)) @register_test_case(module_factory=lambda: ReflectionPad2dModule()) def ReflectionPad2dModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 20, 20, low=-1)) + # ============================================================================== -class ReflectionPad2dModuleTop(torch.nn.Module): +class ReflectionPad2dModuleTop(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 3, 4], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.reflection_pad2d(x, (0,0,2,0)) + return torch.ops.aten.reflection_pad2d(x, (0, 0, 2, 0)) @register_test_case(module_factory=lambda: ReflectionPad2dModuleTop()) def ReflectionPad2dModule_Top(module, tu: TestUtils): module.forward(tu.rand(1, 3, 4)) + # ============================================================================== -class ReflectionPad2dModuleBottom(torch.nn.Module): +class ReflectionPad2dModuleBottom(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 10, 10], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3, 10, 10], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.reflection_pad2d(x, (0,0,0,5)) + return torch.ops.aten.reflection_pad2d(x, (0, 0, 0, 5)) @register_test_case(module_factory=lambda: ReflectionPad2dModuleBottom()) def ReflectionPad2dModule_Bottom(module, tu: TestUtils): module.forward(tu.rand(2, 3, 10, 10)) + # ============================================================================== -class ReflectionPad2dModuleLeft(torch.nn.Module): +class ReflectionPad2dModuleLeft(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 20, 20], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3, 20, 20], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.reflection_pad2d(x, (15,0,0,0)) + return torch.ops.aten.reflection_pad2d(x, (15, 0, 0, 0)) @register_test_case(module_factory=lambda: ReflectionPad2dModuleLeft()) def ReflectionPad2dModule_Left(module, tu: TestUtils): module.forward(tu.rand(2, 3, 20, 20)) + # ============================================================================== -class ReflectionPad2dModuleRight(torch.nn.Module): +class ReflectionPad2dModuleRight(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 20, 20], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3, 20, 20], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.reflection_pad2d(x, (0,11,0,0)) + return torch.ops.aten.reflection_pad2d(x, (0, 11, 0, 0)) @register_test_case(module_factory=lambda: ReflectionPad2dModuleRight()) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index a54023254..50711afed 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -11,130 +11,145 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== -class AdaptiveAvgPool2dNonUnitOutputSizeStaticModule(torch.nn.Module): +class AdaptiveAvgPool2dNonUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): super().__init__() self.aap2d = torch.nn.AdaptiveAvgPool2d((7, 7)) @export - @annotate_args([ - None, - ([1, 512, 7, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 512, 7, 7], torch.float32, True), + ] + ) def forward(self, x): return self.aap2d(x) @register_test_case( - module_factory=lambda: AdaptiveAvgPool2dNonUnitOutputSizeStaticModule()) -def AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic( - module, tu: TestUtils): + module_factory=lambda: AdaptiveAvgPool2dNonUnitOutputSizeStaticModule() +) +def AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 7, 7)) class AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule(torch.nn.Module): - def __init__(self): super().__init__() self.aap2d = torch.nn.AdaptiveAvgPool2d((7, 7)) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.aap2d(x) @register_test_case( - module_factory=lambda: AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule()) -def AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic( - module, tu: TestUtils): + module_factory=lambda: AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule() +) +def AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 7, 7)) class AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule(torch.nn.Module): - def __init__(self): super().__init__() - self.aap2d = torch.nn.AdaptiveAvgPool2d((5, 7)) - + self.aap2d = torch.nn.AdaptiveAvgPool2d((5, 7)) + @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.aap2d(x) @register_test_case( - module_factory=lambda: AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule()) -def AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic(module, tu: TestUtils): + module_factory=lambda: AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule() +) +def AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic( + module, tu: TestUtils +): module.forward(tu.rand(1, 512, 15, 28)) class AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule(torch.nn.Module): - def __init__(self): super().__init__() - self.aap2d = torch.nn.AdaptiveAvgPool2d((3, 7)) - + self.aap2d = torch.nn.AdaptiveAvgPool2d((3, 7)) + @export - @annotate_args([ - None, - ([1, 512, 15, 14], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 512, 15, 14], torch.float32, True), + ] + ) def forward(self, x): return self.aap2d(x) @register_test_case( - module_factory=lambda: AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule()) -def AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic(module, tu: TestUtils): + module_factory=lambda: AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule() +) +def AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic( + module, tu: TestUtils +): module.forward(tu.rand(1, 512, 15, 14)) class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module): - def __init__(self): super().__init__() self.aap2d = torch.nn.AdaptiveAvgPool2d((1, 1)) @export - @annotate_args([ - None, - ([1, 512, 7, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 512, 7, 7], torch.float32, True), + ] + ) def forward(self, x): return self.aap2d(x) @register_test_case( - module_factory=lambda: AdaptiveAvgPool2dUnitOutputSizeStaticModule()) + module_factory=lambda: AdaptiveAvgPool2dUnitOutputSizeStaticModule() +) def AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 7, 7)) class AdaptiveAvgPool2dUnitOutputSizeDynamicModule(torch.nn.Module): - def __init__(self): super().__init__() self.aap2d = torch.nn.AdaptiveAvgPool2d((1, 1)) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.aap2d(x) @register_test_case( - module_factory=lambda: AdaptiveAvgPool2dUnitOutputSizeDynamicModule()) + module_factory=lambda: AdaptiveAvgPool2dUnitOutputSizeDynamicModule() +) def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 7, 7)) @@ -143,19 +158,19 @@ def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils): class MaxPool2dModule(torch.nn.Module): - def __init__(self): super().__init__() - self.mp2d = torch.nn.MaxPool2d(kernel_size=[6, 8], - stride=[2, 2], - padding=[3, 4], - dilation=2) + self.mp2d = torch.nn.MaxPool2d( + kernel_size=[6, 8], stride=[2, 2], padding=[3, 4], dilation=2 + ) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.mp2d(x) @@ -166,15 +181,16 @@ def MaxPool2dModule_basic(module, tu: TestUtils): class MaxPool2dEmptyStrideStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 20, 20], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 20, 20], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.max_pool2d(x, kernel_size=2, stride=[]) @@ -185,19 +201,19 @@ def MaxPool2dEmptyStrideStaticModule_basic(module, tu: TestUtils): class MaxPool2dStaticModule(torch.nn.Module): - def __init__(self): super().__init__() - self.mp2d = torch.nn.MaxPool2d(kernel_size=[3, 3], - stride=[2, 2], - padding=[1, 1], - dilation=[1, 1]) + self.mp2d = torch.nn.MaxPool2d( + kernel_size=[3, 3], stride=[2, 2], padding=[1, 1], dilation=[1, 1] + ) @export - @annotate_args([ - None, - ([1, 64, 112, 112], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 64, 112, 112], torch.float32, True), + ] + ) def forward(self, x): return self.mp2d(x) @@ -206,21 +222,25 @@ class MaxPool2dStaticModule(torch.nn.Module): def MaxPool2dStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 64, 112, 112)) -class MaxPool2dStaticCeilModeTrueModule(torch.nn.Module): +class MaxPool2dStaticCeilModeTrueModule(torch.nn.Module): def __init__(self): super().__init__() - self.mp2d = torch.nn.MaxPool2d(kernel_size=[3, 3], - stride=[2, 2], - padding=[1, 1], - dilation=[1, 1], - ceil_mode=True) + self.mp2d = torch.nn.MaxPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + ceil_mode=True, + ) @export - @annotate_args([ - None, - ([1, 64, 112, 112], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 64, 112, 112], torch.float32, True), + ] + ) def forward(self, x): return self.mp2d(x) @@ -231,20 +251,23 @@ def MaxPool2dStaticCeilModeTrueModule_basic(module, tu: TestUtils): class MaxPool2dCeilModeTrueModule(torch.nn.Module): - def __init__(self): super().__init__() - self.mp2d = torch.nn.MaxPool2d(kernel_size=[6, 8], - stride=[2, 2], - padding=[3, 4], - dilation=2, - ceil_mode=True) + self.mp2d = torch.nn.MaxPool2d( + kernel_size=[6, 8], + stride=[2, 2], + padding=[3, 4], + dilation=2, + ceil_mode=True, + ) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.mp2d(x) @@ -256,41 +279,44 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils): # ============================================================================== -class MaxPool3dModule(torch.nn.Module): +class MaxPool3dModule(torch.nn.Module): def __init__(self): super().__init__() - self.mp3d = torch.nn.MaxPool3d(kernel_size=[4, 4, 4], - stride=[2, 2, 2], - padding=[1, 1, 1], - dilation=1) + self.mp3d = torch.nn.MaxPool3d( + kernel_size=[4, 4, 4], stride=[2, 2, 2], padding=[1, 1, 1], dilation=1 + ) @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.mp3d(x) @register_test_case(module_factory=lambda: MaxPool3dModule()) def MaxPool3dModule_basic(module, tu: TestUtils): - module.forward(torch.arange(8*8*8).view(1, 1, 8, 8, 8).float()) + module.forward(torch.arange(8 * 8 * 8).view(1, 1, 8, 8, 8).float()) + class MaxPool3dRandomSimpleModule(torch.nn.Module): def __init__(self): super().__init__() - self.mp3d = torch.nn.MaxPool3d(kernel_size=[4, 4, 4], - stride=[2, 2, 2], - padding=[1, 1, 1], - dilation=1) + self.mp3d = torch.nn.MaxPool3d( + kernel_size=[4, 4, 4], stride=[2, 2, 2], padding=[1, 1, 1], dilation=1 + ) @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.mp3d(x) @@ -299,20 +325,21 @@ class MaxPool3dRandomSimpleModule(torch.nn.Module): def MaxPool3dModuleRandomSimple_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, 20, low=-1)) -class MaxPool3dLargeDataModule(torch.nn.Module): +class MaxPool3dLargeDataModule(torch.nn.Module): def __init__(self): super().__init__() - self.mp3d = torch.nn.MaxPool3d(kernel_size=[6, 8, 8], - stride=[2, 2, 2], - padding=[3, 4, 4], - dilation=2) + self.mp3d = torch.nn.MaxPool3d( + kernel_size=[6, 8, 8], stride=[2, 2, 2], padding=[3, 4, 4], dilation=2 + ) @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.mp3d(x) @@ -321,14 +348,18 @@ class MaxPool3dLargeDataModule(torch.nn.Module): def MaxPool3dLargeDatadModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, 20, low=-1)) + class MaxPool3dEmptyStrideStaticModule(torch.nn.Module): def __init__(self): super().__init__() + @export - @annotate_args([ - None, - ([1, 1, 20, 20, 20], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 20, 20, 20], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.max_pool3d(x, kernel_size=2, stride=[]) @@ -341,15 +372,20 @@ def MaxPool3dEmptyStrideStaticModule_basic(module, tu: TestUtils): class MaxPool3dStaticModule(torch.nn.Module): def __init__(self): super().__init__() - self.mp3d = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], - stride=[2, 2, 2], - padding=[1, 1, 1], - dilation=[1, 1, 1]) + self.mp3d = torch.nn.MaxPool3d( + kernel_size=[3, 3, 3], + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1], + ) + @export - @annotate_args([ - None, - ([1, 64, 112, 112, 112], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 64, 112, 112, 112], torch.float32, True), + ] + ) def forward(self, x): return self.mp3d(x) @@ -358,20 +394,25 @@ class MaxPool3dStaticModule(torch.nn.Module): def MaxPool3dStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 64, 112, 112, 112)) + class MaxPool3dStaticCeilModeTrueModule(torch.nn.Module): def __init__(self): super().__init__() - self.mp3d = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], - stride=[2, 2, 2], - padding=[1, 1, 1], - dilation=[1, 1, 1], - ceil_mode=True) + self.mp3d = torch.nn.MaxPool3d( + kernel_size=[3, 3, 3], + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1], + ceil_mode=True, + ) @export - @annotate_args([ - None, - ([1, 64, 112, 112, 112], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 64, 112, 112, 112], torch.float32, True), + ] + ) def forward(self, x): return self.mp3d(x) @@ -384,19 +425,25 @@ def MaxPool3dStaticCeilModeTrueModule_basic(module, tu: TestUtils): class MaxPool3dCeilModeTrueModule(torch.nn.Module): def __init__(self): super().__init__() - self.mp3d = torch.nn.MaxPool3d(kernel_size=[6, 8, 8], - stride=[2, 2, 2], - padding=[3, 4, 4], - dilation=2, - ceil_mode=True) + self.mp3d = torch.nn.MaxPool3d( + kernel_size=[6, 8, 8], + stride=[2, 2, 2], + padding=[3, 4, 4], + dilation=2, + ceil_mode=True, + ) + @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.mp3d(x) + @register_test_case(module_factory=lambda: MaxPool3dCeilModeTrueModule()) def MaxPool3dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, 20, low=0.5, high=1.0)) @@ -406,21 +453,20 @@ def MaxPool3dCeilModeTrueModule_basic(module, tu: TestUtils): class MaxPool2dWithIndicesModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.max_pool2d_with_indices(x, - kernel_size=[2, 2], - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1]) + return torch.ops.aten.max_pool2d_with_indices( + x, kernel_size=[2, 2], stride=[1, 1], padding=[0, 0], dilation=[1, 1] + ) @register_test_case(module_factory=lambda: MaxPool2dWithIndicesModule()) @@ -429,165 +475,158 @@ def MaxPool2dWithIndicesModule_basic(module, tu: TestUtils): class MaxPool2dWithIndicesFullSizeKernelModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.max_pool2d_with_indices(x, - kernel_size=[4, 4], - stride=1, - padding=0, - dilation=1) + return torch.ops.aten.max_pool2d_with_indices( + x, kernel_size=[4, 4], stride=1, padding=0, dilation=1 + ) -@register_test_case( - module_factory=lambda: MaxPool2dWithIndicesFullSizeKernelModule()) +@register_test_case(module_factory=lambda: MaxPool2dWithIndicesFullSizeKernelModule()) def MaxPool2dWithIndicesFullSizeKernelModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 4, low=0.5, high=1.0)) class MaxPool2dWithIndicesNonDefaultPaddingModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.max_pool2d_with_indices(x, - kernel_size=[4, 8], - stride=[1, 1], - padding=[2, 4], - dilation=1) + return torch.ops.aten.max_pool2d_with_indices( + x, kernel_size=[4, 8], stride=[1, 1], padding=[2, 4], dilation=1 + ) @register_test_case( - module_factory=lambda: MaxPool2dWithIndicesNonDefaultPaddingModule()) + module_factory=lambda: MaxPool2dWithIndicesNonDefaultPaddingModule() +) def MaxPool2dWithIndicesNonDefaultPaddingModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 16, 16, low=-1.5, high=1.0)) class MaxPool2dWithIndicesNonDefaultStrideModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.max_pool2d_with_indices(x, - kernel_size=[4, 4], - stride=[1, 2], - padding=0, - dilation=1) + return torch.ops.aten.max_pool2d_with_indices( + x, kernel_size=[4, 4], stride=[1, 2], padding=0, dilation=1 + ) -@register_test_case( - module_factory=lambda: MaxPool2dWithIndicesNonDefaultStrideModule()) +@register_test_case(module_factory=lambda: MaxPool2dWithIndicesNonDefaultStrideModule()) def MaxPool2dWithIndicesNonDefaultStrideModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 4, 16, 80, low=0.5, high=2.0)) class MaxPool2dWithIndicesNonDefaultDilationModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.max_pool2d_with_indices(x, - kernel_size=[4, 4], - stride=[1, 1], - padding=0, - dilation=[2, 2]) + return torch.ops.aten.max_pool2d_with_indices( + x, kernel_size=[4, 4], stride=[1, 1], padding=0, dilation=[2, 2] + ) @register_test_case( - module_factory=lambda: MaxPool2dWithIndicesNonDefaultDilationModule()) + module_factory=lambda: MaxPool2dWithIndicesNonDefaultDilationModule() +) def MaxPool2dWithIndicesNonDefaultDilationModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 4, 16, 80, low=0.5, high=2.0)) class MaxPool2dWithIndicesNonDefaultParamsModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.max_pool2d_with_indices(x, - kernel_size=[8, 4], - stride=[2, 2], - padding=[1, 2], - dilation=[2, 2]) + return torch.ops.aten.max_pool2d_with_indices( + x, kernel_size=[8, 4], stride=[2, 2], padding=[1, 2], dilation=[2, 2] + ) -@register_test_case( - module_factory=lambda: MaxPool2dWithIndicesNonDefaultParamsModule()) +@register_test_case(module_factory=lambda: MaxPool2dWithIndicesNonDefaultParamsModule()) def MaxPool2dWithIndicesNonDefaultParamsModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 4, 16, 80, low=-0.5, high=4.0)) class MaxPool2dWithIndicesAllNegativeValuesModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.max_pool2d_with_indices(x, - kernel_size=[4, 8], - stride=[1, 1], - padding=[2, 4], - dilation=1) + return torch.ops.aten.max_pool2d_with_indices( + x, kernel_size=[4, 8], stride=[1, 1], padding=[2, 4], dilation=1 + ) @register_test_case( - module_factory=lambda: MaxPool2dWithIndicesAllNegativeValuesModule()) + module_factory=lambda: MaxPool2dWithIndicesAllNegativeValuesModule() +) def MaxPool2dWithIndicesAllNegativeValuesModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 16, 16, low=-4.5, high=-1.0)) class MaxPool2dWithIndicesStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 4, 16, 16], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 4, 16, 16], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.max_pool2d_with_indices(x, - kernel_size=[4, 8], - stride=[1, 1], - padding=[2, 4], - dilation=1) + return torch.ops.aten.max_pool2d_with_indices( + x, kernel_size=[4, 8], stride=[1, 1], padding=[2, 4], dilation=1 + ) @register_test_case(module_factory=lambda: MaxPool2dWithIndicesStaticModule()) @@ -596,21 +635,20 @@ def MaxPool2dWithIndicesStaticModule_basic(module, tu: TestUtils): class MaxPool2dWithIndicesAllOnesModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.max_pool2d_with_indices(x, - kernel_size=[2, 2], - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1]) + return torch.ops.aten.max_pool2d_with_indices( + x, kernel_size=[2, 2], stride=[1, 1], padding=[0, 0], dilation=[1, 1] + ) @register_test_case(module_factory=lambda: MaxPool2dWithIndicesAllOnesModule()) @@ -619,26 +657,28 @@ def MaxPool2dWithIndicesAllOnesModule_basic(module, tu: TestUtils): class MaxPool2dWithIndicesCeilModeTrueModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.max_pool2d_with_indices(x, - kernel_size=[2, 2], - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1], - ceil_mode=True) + return torch.ops.aten.max_pool2d_with_indices( + x, + kernel_size=[2, 2], + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + ceil_mode=True, + ) -@register_test_case( - module_factory=lambda: MaxPool2dWithIndicesCeilModeTrueModule()) +@register_test_case(module_factory=lambda: MaxPool2dWithIndicesCeilModeTrueModule()) def MaxPool2dWithIndicesCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 8, 8, low=0.5, high=1.0)) @@ -647,17 +687,18 @@ def MaxPool2dWithIndicesCeilModeTrueModule_basic(module, tu: TestUtils): class MaxPool2dWithIndicesBackwardStatic4DModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 4, 7, 6], torch.float32, True), - ([2, 4, 6, 5], torch.float32, True), - ([2, 4, 7, 6], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([2, 4, 7, 6], torch.float32, True), + ([2, 4, 6, 5], torch.float32, True), + ([2, 4, 7, 6], torch.int64, True), + ] + ) def forward(self, output, input, indices): kernel_size = [2, 2] stride = [1, 1] @@ -665,29 +706,30 @@ class MaxPool2dWithIndicesBackwardStatic4DModule(torch.nn.Module): dilation = [1, 1] ceil_mode = False return torch.ops.aten.max_pool2d_with_indices_backward( - output, input, kernel_size, stride, padding, dilation, ceil_mode, - indices) + output, input, kernel_size, stride, padding, dilation, ceil_mode, indices + ) -@register_test_case( - module_factory=lambda: MaxPool2dWithIndicesBackwardStatic4DModule()) +@register_test_case(module_factory=lambda: MaxPool2dWithIndicesBackwardStatic4DModule()) def MaxPool2dWithIndicesBackwardStatic4DModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 4, 7, 6), tu.rand(2, 4, 6, 5), - tu.randint(2, 4, 7, 6, high=16)) + module.forward( + tu.rand(2, 4, 7, 6), tu.rand(2, 4, 6, 5), tu.randint(2, 4, 7, 6, high=16) + ) class MaxPool2dWithIndicesBackwardStatic3DModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([4, 7, 6], torch.float32, True), - ([4, 6, 5], torch.float32, True), - ([4, 7, 6], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([4, 7, 6], torch.float32, True), + ([4, 6, 5], torch.float32, True), + ([4, 7, 6], torch.int64, True), + ] + ) def forward(self, output, input, indices): kernel_size = [2, 2] stride = [1, 1] @@ -695,29 +737,28 @@ class MaxPool2dWithIndicesBackwardStatic3DModule(torch.nn.Module): dilation = [1, 1] ceil_mode = False return torch.ops.aten.max_pool2d_with_indices_backward( - output, input, kernel_size, stride, padding, dilation, ceil_mode, - indices) + output, input, kernel_size, stride, padding, dilation, ceil_mode, indices + ) -@register_test_case( - module_factory=lambda: MaxPool2dWithIndicesBackwardStatic3DModule()) +@register_test_case(module_factory=lambda: MaxPool2dWithIndicesBackwardStatic3DModule()) def MaxPool2dWithIndicesBackwardStatic3DModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 7, 6), tu.rand(4, 6, 5), - tu.randint(4, 7, 6, high=16)) + module.forward(tu.rand(4, 7, 6), tu.rand(4, 6, 5), tu.randint(4, 7, 6, high=16)) class MaxPool2dWithIndicesBackwardDynamic4DModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.int64, True), + ] + ) def forward(self, output, input, indices): kernel_size = [2, 2] stride = [1, 1] @@ -725,29 +766,32 @@ class MaxPool2dWithIndicesBackwardDynamic4DModule(torch.nn.Module): dilation = [1, 1] ceil_mode = False return torch.ops.aten.max_pool2d_with_indices_backward( - output, input, kernel_size, stride, padding, dilation, ceil_mode, - indices) + output, input, kernel_size, stride, padding, dilation, ceil_mode, indices + ) @register_test_case( - module_factory=lambda: MaxPool2dWithIndicesBackwardDynamic4DModule()) + module_factory=lambda: MaxPool2dWithIndicesBackwardDynamic4DModule() +) def MaxPool2dWithIndicesBackwardDynamic4DModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 4, 7, 6), tu.rand(2, 4, 6, 5), - tu.randint(2, 4, 7, 6, high=16)) + module.forward( + tu.rand(2, 4, 7, 6), tu.rand(2, 4, 6, 5), tu.randint(2, 4, 7, 6, high=16) + ) class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, output, input, indices): kernel_size = [2, 2] stride = [1, 1] @@ -755,59 +799,67 @@ class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module): dilation = [1, 1] ceil_mode = False return torch.ops.aten.max_pool2d_with_indices_backward( - output, input, kernel_size, stride, padding, dilation, ceil_mode, - indices) + output, input, kernel_size, stride, padding, dilation, ceil_mode, indices + ) @register_test_case( - module_factory=lambda: MaxPool2dWithIndicesBackwardDynamic3DModule()) + module_factory=lambda: MaxPool2dWithIndicesBackwardDynamic3DModule() +) def MaxPool2dWithIndicesBackwardDynamic3DModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 7, 6), tu.rand(2, 6, 5), - tu.randint(2, 7, 6, high=16)) + module.forward(tu.rand(2, 7, 6), tu.rand(2, 6, 5), tu.randint(2, 7, 6, high=16)) # ============================================================================== class AvgPool2dFloatModule(torch.nn.Module): - def __init__(self): super().__init__() - self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8], - stride=[2, 2], - padding=[3, 4], - ceil_mode=False, - count_include_pad=True, - divisor_override=None) + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[6, 8], + stride=[2, 2], + padding=[3, 4], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.ap2d(x) + @register_test_case(module_factory=lambda: AvgPool2dFloatModule()) def AvgPool2dFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 20, 20, low=-1)) -class AvgPool2dIntModule(torch.nn.Module): +class AvgPool2dIntModule(torch.nn.Module): def __init__(self): super().__init__() - self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8], - stride=[2, 2], - padding=[3, 4], - ceil_mode=False, - count_include_pad=True, - divisor_override=None) + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[6, 8], + stride=[2, 2], + padding=[3, 4], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.int64, True), + ] + ) def forward(self, x): return self.ap2d(x) @@ -818,21 +870,24 @@ def AvgPool2dIntModule_basic(module, tu: TestUtils): class AvgPool2dStaticModule(torch.nn.Module): - def __init__(self): super().__init__() - self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8], - stride=[2, 2], - padding=[3, 4], - ceil_mode=False, - count_include_pad=True, - divisor_override=None) + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[6, 8], + stride=[2, 2], + padding=[3, 4], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) @export - @annotate_args([ - None, - ([2, 2, 10, 20], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 2, 10, 20], torch.float32, True), + ] + ) def forward(self, x): return self.ap2d(x) @@ -843,21 +898,24 @@ def AvgPool2dStaticModule_basic(module, tu: TestUtils): class AvgPool2dDivisorOverrideModule(torch.nn.Module): - def __init__(self): super().__init__() - self.ap2d = torch.nn.AvgPool2d(kernel_size=[4, 8], - stride=[2, 3], - padding=[2, 4], - ceil_mode=False, - count_include_pad=True, - divisor_override=22) + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[4, 8], + stride=[2, 3], + padding=[2, 4], + ceil_mode=False, + count_include_pad=True, + divisor_override=22, + ) @export - @annotate_args([ - None, - ([4, 4, 20, 20], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4, 4, 20, 20], torch.float32, True), + ] + ) def forward(self, x): return self.ap2d(x) @@ -868,118 +926,128 @@ def AvgPool2dDivisorOverrideModule_basic(module, tu: TestUtils): class AvgPool2dCeilModeTrueModule(torch.nn.Module): - def __init__(self): super().__init__() - self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8], - stride=[2, 2], - padding=[3, 4], - ceil_mode=False, - count_include_pad=True, - divisor_override=None) + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[6, 8], + stride=[2, 2], + padding=[3, 4], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.ap2d(x) + @register_test_case(module_factory=lambda: AvgPool2dCeilModeTrueModule()) def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) -class AvgPool2dWithoutPadModule(torch.nn.Module): +class AvgPool2dWithoutPadModule(torch.nn.Module): def __init__(self): super().__init__() - self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8], - stride=[2, 2], - padding=[0, 0], - ceil_mode=False, - count_include_pad=False, - divisor_override=None) + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[6, 8], + stride=[2, 2], + padding=[0, 0], + ceil_mode=False, + count_include_pad=False, + divisor_override=None, + ) @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.ap2d(x) + @register_test_case(module_factory=lambda: AvgPool2dWithoutPadModule()) def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) + # ============================================================================== class AvgPool1dFloatModule(torch.nn.Module): - def __init__(self): super().__init__() - self.ap1d = torch.nn.AvgPool1d(kernel_size=6, - stride=2, - padding=3, - ceil_mode=False, - count_include_pad=True) + self.ap1d = torch.nn.AvgPool1d( + kernel_size=6, stride=2, padding=3, ceil_mode=False, count_include_pad=True + ) @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.ap1d(x) + @register_test_case(module_factory=lambda: AvgPool1dFloatModule()) def AvgPool1dFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 20, low=-1)) class AvgPool1dIntModule(torch.nn.Module): - def __init__(self): super().__init__() - self.ap1d = torch.nn.AvgPool1d(kernel_size=6, - stride=2, - padding=3, - ceil_mode=False, - count_include_pad=True) + self.ap1d = torch.nn.AvgPool1d( + kernel_size=6, stride=2, padding=3, ceil_mode=False, count_include_pad=True + ) @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, x): return self.ap1d(x) + @register_test_case(module_factory=lambda: AvgPool1dIntModule()) def AvgPool1dIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(2, 4, 20, high=100)) class AvgPool1dStaticModule(torch.nn.Module): - def __init__(self): super().__init__() - self.ap1d = torch.nn.AvgPool1d(kernel_size=6, - stride=2, - padding=3, - ceil_mode=False, - count_include_pad=True) + self.ap1d = torch.nn.AvgPool1d( + kernel_size=6, stride=2, padding=3, ceil_mode=False, count_include_pad=True + ) @export - @annotate_args([ - None, - ([2, 4, 20], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([2, 4, 20], torch.int64, True), + ] + ) def forward(self, x): return self.ap1d(x) + @register_test_case(module_factory=lambda: AvgPool1dStaticModule()) def AvgPool1dStaticModule_basic(module, tu: TestUtils): module.forward(tu.randint(2, 4, 20, high=100)) @@ -987,515 +1055,465 @@ def AvgPool1dStaticModule_basic(module, tu: TestUtils): # ============================================================================== -class AdaptiveAvgPool1dStaticLargerOutput(torch.nn.Module): +class AdaptiveAvgPool1dStaticLargerOutput(torch.nn.Module): def __init__(self): super().__init__() self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=13) @export - @annotate_args([ - None, - ([5, 512, 7], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([5, 512, 7], torch.float32, True)]) + def forward(self, x): return self.aap1d(x) -@register_test_case( - module_factory=lambda: AdaptiveAvgPool1dStaticLargerOutput()) -def AdaptiveAvgPool1dStaticLargerOutput_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveAvgPool1dStaticLargerOutput()) +def AdaptiveAvgPool1dStaticLargerOutput_basic(module, tu: TestUtils): module.forward(tu.rand(5, 512, 7)) -class AdaptiveAvgPool1dStaticEvenMultiple(torch.nn.Module): +class AdaptiveAvgPool1dStaticEvenMultiple(torch.nn.Module): def __init__(self): super().__init__() self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) @export - @annotate_args([ - None, - ([5, 512, 147], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([5, 512, 147], torch.float32, True)]) + def forward(self, x): return self.aap1d(x) -@register_test_case( - module_factory=lambda: AdaptiveAvgPool1dStaticEvenMultiple()) -def AdaptiveAvgPool1dStaticEvenMultiple_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveAvgPool1dStaticEvenMultiple()) +def AdaptiveAvgPool1dStaticEvenMultiple_basic(module, tu: TestUtils): module.forward(tu.rand(5, 512, 147)) -class AdaptiveAvgPool1dGeneralDynamic(torch.nn.Module): +class AdaptiveAvgPool1dGeneralDynamic(torch.nn.Module): def __init__(self): super().__init__() self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) @export - @annotate_args([ - None, - ([-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.aap1d(x) -@register_test_case( - module_factory=lambda: AdaptiveAvgPool1dGeneralDynamic()) -def AdaptiveAvgPool1dGeneralDynamic_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveAvgPool1dGeneralDynamic()) +def AdaptiveAvgPool1dGeneralDynamic_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10)) + class AdaptiveAvgPool1dGeneralDynamicNoBatches(torch.nn.Module): - def __init__(self): super().__init__() self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) @export - @annotate_args([ - None, - ([-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1], torch.float32, True)]) + def forward(self, x): return self.aap1d(x) -@register_test_case( - module_factory=lambda: AdaptiveAvgPool1dGeneralDynamicNoBatches()) -def AdaptiveAvgPool1dGeneralDynamicNoBatches_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveAvgPool1dGeneralDynamicNoBatches()) +def AdaptiveAvgPool1dGeneralDynamicNoBatches_basic(module, tu: TestUtils): module.forward(tu.rand(512, 10)) -class AdaptiveAvgPool1dNonUnitOutputSizeStaticModule(torch.nn.Module): +class AdaptiveAvgPool1dNonUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): super().__init__() self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) @export - @annotate_args([ - None, - ([1, 512, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 512, 7], torch.float32, True), + ] + ) def forward(self, x): return self.aap1d(x) + @register_test_case( - module_factory=lambda: AdaptiveAvgPool1dNonUnitOutputSizeStaticModule()) -def AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic( - module, tu: TestUtils): + module_factory=lambda: AdaptiveAvgPool1dNonUnitOutputSizeStaticModule() +) +def AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 7)) + class AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule(torch.nn.Module): - def __init__(self): super().__init__() self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.aap1d(x) + @register_test_case( - module_factory=lambda: AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule()) -def AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic( - module, tu: TestUtils): + module_factory=lambda: AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule() +) +def AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 7)) + class AdaptiveAvgPool1dUnitOutputSizeStaticModule(torch.nn.Module): - def __init__(self): super().__init__() self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=1) @export - @annotate_args([ - None, - ([1, 512, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 512, 7], torch.float32, True), + ] + ) def forward(self, x): return self.aap1d(x) + @register_test_case( - module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeStaticModule()) -def AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic( - module, tu: TestUtils): + module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeStaticModule() +) +def AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 7)) + class AdaptiveAvgPool1dUnitOutputSizeDynamicModule(torch.nn.Module): - def __init__(self): super().__init__() self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=1) @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return self.aap1d(x) + @register_test_case( - module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeDynamicModule()) -def AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic( - module, tu: TestUtils): + module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeDynamicModule() +) +def AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 7)) + # AdaptiveAvgPool2d - + class AdaptiveAvgPool2dDynamic(torch.nn.Module): - def __init__(self): super().__init__() - self.aap2d = torch.nn.AdaptiveAvgPool2d(output_size=(7,13)) + self.aap2d = torch.nn.AdaptiveAvgPool2d(output_size=(7, 13)) @export - @annotate_args([ - None, - ([-1,-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.aap2d(x) -@register_test_case( - module_factory=lambda: AdaptiveAvgPool2dDynamic()) -def AdaptiveAvgPool2dDynamic_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveAvgPool2dDynamic()) +def AdaptiveAvgPool2dDynamic_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16)) + class AdaptiveAvgPool2dDynamicNoBatch(torch.nn.Module): - def __init__(self): super().__init__() - self.aap2d = torch.nn.AdaptiveAvgPool2d(output_size=(7,13)) + self.aap2d = torch.nn.AdaptiveAvgPool2d(output_size=(7, 13)) @export - @annotate_args([ - None, - ([-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.aap2d(x) -@register_test_case( - module_factory=lambda: AdaptiveAvgPool2dDynamicNoBatch()) -def AdaptiveAvgPool2dDynamicNoBatch_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveAvgPool2dDynamicNoBatch()) +def AdaptiveAvgPool2dDynamicNoBatch_basic(module, tu: TestUtils): module.forward(tu.rand(512, 10, 16)) + # AdaptiveAvgPool3d + class AdaptiveAvgPool3dDynamic(torch.nn.Module): - def __init__(self): super().__init__() - self.aap3d = torch.nn.AdaptiveAvgPool3d(output_size=(7,13,15)) + self.aap3d = torch.nn.AdaptiveAvgPool3d(output_size=(7, 13, 15)) @export - @annotate_args([ - None, - ([-1,-1,-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.aap3d(x) -@register_test_case( - module_factory=lambda: AdaptiveAvgPool3dDynamic()) -def AdaptiveAvgPool3dDynamic_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveAvgPool3dDynamic()) +def AdaptiveAvgPool3dDynamic_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16, 17)) + class AdaptiveAvgPool3dDynamicNoBatch(torch.nn.Module): - def __init__(self): super().__init__() - self.aap3d = torch.nn.AdaptiveAvgPool3d(output_size=(7,13,15)) + self.aap3d = torch.nn.AdaptiveAvgPool3d(output_size=(7, 13, 15)) @export - @annotate_args([ - None, - ([-1,-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.aap3d(x) -@register_test_case( - module_factory=lambda: AdaptiveAvgPool3dDynamicNoBatch()) -def AdaptiveAvgPool3dDynamicNoBatch_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveAvgPool3dDynamicNoBatch()) +def AdaptiveAvgPool3dDynamicNoBatch_basic(module, tu: TestUtils): module.forward(tu.rand(512, 10, 16, 17)) - + + # AdaptiveMaxPool1d + class AdaptiveMaxPool1dDynamic(torch.nn.Module): - def __init__(self): super().__init__() self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(7), return_indices=False) @export - @annotate_args([ - None, - ([-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.amp1d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool1dDynamic()) -def AdaptiveMaxPool1dDynamic_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool1dDynamic()) +def AdaptiveMaxPool1dDynamic_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10)) + class AdaptiveMaxPool1dDynamicNoBatch(torch.nn.Module): - def __init__(self): super().__init__() self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(7), return_indices=False) @export - @annotate_args([ - None, - ([-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1], torch.float32, True)]) + def forward(self, x): return self.amp1d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool1dDynamicNoBatch()) -def AdaptiveMaxPool1dDynamicNoBatch_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool1dDynamicNoBatch()) +def AdaptiveMaxPool1dDynamicNoBatch_basic(module, tu: TestUtils): module.forward(tu.rand(512, 10)) + class AdaptiveMaxPool1dStatic(torch.nn.Module): - def __init__(self): super().__init__() self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(7), return_indices=False) @export - @annotate_args([ - None, - ([1, 512, 10], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([1, 512, 10], torch.float32, True)]) + def forward(self, x): return self.amp1d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool1dStatic()) -def AdaptiveMaxPool1dStatic_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool1dStatic()) +def AdaptiveMaxPool1dStatic_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10)) + # AdaptiveMaxPool2d -class AdaptiveMaxPool2dDynamic(torch.nn.Module): +class AdaptiveMaxPool2dDynamic(torch.nn.Module): def __init__(self): super().__init__() - self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) + self.amp2d = torch.nn.AdaptiveMaxPool2d( + output_size=(7, 13), return_indices=False + ) @export - @annotate_args([ - None, - ([-1,-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.amp2d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool2dDynamic()) -def AdaptiveMaxPool2dDynamic_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool2dDynamic()) +def AdaptiveMaxPool2dDynamic_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16)) + class AdaptiveMaxPool2dDynamicNoBatch(torch.nn.Module): - def __init__(self): super().__init__() - self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) + self.amp2d = torch.nn.AdaptiveMaxPool2d( + output_size=(7, 13), return_indices=False + ) @export - @annotate_args([ - None, - ([-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.amp2d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool2dDynamicNoBatch()) -def AdaptiveMaxPool2dDynamicNoBatch_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool2dDynamicNoBatch()) +def AdaptiveMaxPool2dDynamicNoBatch_basic(module, tu: TestUtils): module.forward(tu.rand(512, 10, 16)) -class AdaptiveMaxPool2dDynamicWithIndices(torch.nn.Module): +class AdaptiveMaxPool2dDynamicWithIndices(torch.nn.Module): def __init__(self): super().__init__() - self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=True) + self.amp2d = torch.nn.AdaptiveMaxPool2d( + output_size=(7, 13), return_indices=True + ) @export - @annotate_args([ - None, - ([-1,-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.amp2d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool2dDynamicWithIndices()) -def AdaptiveMaxPool2dDynamicWithIndices_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool2dDynamicWithIndices()) +def AdaptiveMaxPool2dDynamicWithIndices_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16)) class AdaptiveMaxPool2dStatic(torch.nn.Module): - def __init__(self): super().__init__() - self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) + self.amp2d = torch.nn.AdaptiveMaxPool2d( + output_size=(7, 13), return_indices=False + ) @export - @annotate_args([ - None, - ([1, 512, 10, 9], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([1, 512, 10, 9], torch.float32, True)]) + def forward(self, x): return self.amp2d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool2dStatic()) -def AdaptiveMaxPool2dStatic_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool2dStatic()) +def AdaptiveMaxPool2dStatic_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 9)) -class AdaptiveMaxPool2dStaticWithIndices(torch.nn.Module): +class AdaptiveMaxPool2dStaticWithIndices(torch.nn.Module): def __init__(self): super().__init__() - self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=True) + self.amp2d = torch.nn.AdaptiveMaxPool2d( + output_size=(7, 13), return_indices=True + ) @export - @annotate_args([ - None, - ([1, 512, 10, 16], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([1, 512, 10, 16], torch.float32, True)]) + def forward(self, x): return self.amp2d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool2dStaticWithIndices()) -def AdaptiveMaxPool2dStaticWithIndices_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool2dStaticWithIndices()) +def AdaptiveMaxPool2dStaticWithIndices_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16)) + # AdaptiveMaxPool3d + class AdaptiveMaxPool3dDynamic(torch.nn.Module): - def __init__(self): super().__init__() - self.amp3d = torch.nn.AdaptiveMaxPool3d(output_size=(7,13,15), return_indices=False) + self.amp3d = torch.nn.AdaptiveMaxPool3d( + output_size=(7, 13, 15), return_indices=False + ) @export - @annotate_args([ - None, - ([-1,-1,-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.amp3d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool3dDynamic()) -def AdaptiveMaxPool3dDynamic_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool3dDynamic()) +def AdaptiveMaxPool3dDynamic_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16, 17)) + class AdaptiveMaxPool3dDynamicNoBatch(torch.nn.Module): - def __init__(self): super().__init__() - self.amp3d = torch.nn.AdaptiveMaxPool3d(output_size=(7,13,15), return_indices=False) + self.amp3d = torch.nn.AdaptiveMaxPool3d( + output_size=(7, 13, 15), return_indices=False + ) @export - @annotate_args([ - None, - ([-1,-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.amp3d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool3dDynamicNoBatch()) -def AdaptiveMaxPool3dDynamicNoBatch_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool3dDynamicNoBatch()) +def AdaptiveMaxPool3dDynamicNoBatch_basic(module, tu: TestUtils): module.forward(tu.rand(512, 10, 16, 17)) + class AdaptiveMaxPool3dDynamicWithIndices(torch.nn.Module): - def __init__(self): super().__init__() - self.amp3d = torch.nn.AdaptiveMaxPool3d(output_size=(7,13,15), return_indices=True) + self.amp3d = torch.nn.AdaptiveMaxPool3d( + output_size=(7, 13, 15), return_indices=True + ) @export - @annotate_args([ - None, - ([-1,-1,-1,-1,-1], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([-1, -1, -1, -1, -1], torch.float32, True)]) + def forward(self, x): return self.amp3d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool3dDynamicWithIndices()) -def AdaptiveMaxPool3dDynamicWithIndices_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool3dDynamicWithIndices()) +def AdaptiveMaxPool3dDynamicWithIndices_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16, 17)) - + class AdaptiveMaxPool3dStatic(torch.nn.Module): - def __init__(self): super().__init__() - self.amp3d = torch.nn.AdaptiveMaxPool3d(output_size=(7,13,15), return_indices=False) + self.amp3d = torch.nn.AdaptiveMaxPool3d( + output_size=(7, 13, 15), return_indices=False + ) @export - @annotate_args([ - None, - ([1, 512, 10, 9, 5], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([1, 512, 10, 9, 5], torch.float32, True)]) + def forward(self, x): return self.amp3d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool3dStatic()) -def AdaptiveMaxPool3dStatic_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool3dStatic()) +def AdaptiveMaxPool3dStatic_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 9, 5)) + class AdaptiveMaxPool3dStaticWithIndices(torch.nn.Module): - def __init__(self): super().__init__() - self.amp3d = torch.nn.AdaptiveMaxPool3d(output_size=(7,13,15), return_indices=True) + self.amp3d = torch.nn.AdaptiveMaxPool3d( + output_size=(7, 13, 15), return_indices=True + ) @export - @annotate_args([ - None, - ([1, 512, 10, 16, 17], torch.float32, True) - ]) - def forward(self,x): + @annotate_args([None, ([1, 512, 10, 16, 17], torch.float32, True)]) + def forward(self, x): return self.amp3d(x) -@register_test_case( - module_factory=lambda: AdaptiveMaxPool3dStaticWithIndices()) -def AdaptiveMaxPool3dStaticWithIndices_basic( - module, tu: TestUtils): + +@register_test_case(module_factory=lambda: AdaptiveMaxPool3dStaticWithIndices()) +def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16, 17)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py index 47e8adffd..5114f78d5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py @@ -12,12 +12,15 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== + def get_quant_model_input(): return 2 * torch.rand((1, 16)) - 1 + def get_batched_quant_model_input(): return 2 * torch.rand((1, 2, 16)) - 1 + class QuantizedNoLayer(nn.Module): def __init__(self): super().__init__() @@ -26,15 +29,18 @@ class QuantizedNoLayer(nn.Module): self.dequantize = torch.quantization.DeQuantStub() @export - @annotate_args([ - None, - ([1, 16], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 16], torch.float32, True), + ] + ) def forward(self, x): x = self.quantize(x) x = self.dequantize(x) return x + def get_quantized_no_layer(): model = QuantizedNoLayer() model.eval() @@ -46,10 +52,12 @@ def get_quantized_no_layer(): torch.quantization.convert(model, inplace=True) return model + @register_test_case(module_factory=get_quantized_no_layer) def QuantizedNoLayer_basic(module, tu: TestUtils): module.forward(get_quant_model_input()) + class QuantizedSingleLayer(nn.Module): def __init__(self): super().__init__() @@ -61,16 +69,19 @@ class QuantizedSingleLayer(nn.Module): self.dequantize = torch.quantization.DeQuantStub() @export - @annotate_args([ - None, - ([1, 16], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 16], torch.float32, True), + ] + ) def forward(self, x): x = self.quantize(x) x = self.layers(x) x = self.dequantize(x) return x + def get_quantized_single_layer(): model = QuantizedSingleLayer() model.eval() @@ -82,10 +93,12 @@ def get_quantized_single_layer(): torch.quantization.convert(model, inplace=True) return model + @register_test_case(module_factory=get_quantized_single_layer) def QuantizedSingleLayer_basic(module, tu: TestUtils): module.forward(get_quant_model_input()) + class QuantizedBatchedInputSingleLayer(nn.Module): def __init__(self): super().__init__() @@ -97,16 +110,19 @@ class QuantizedBatchedInputSingleLayer(nn.Module): self.dequantize = torch.quantization.DeQuantStub() @export - @annotate_args([ - None, - ([1, 2, 16], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 2, 16], torch.float32, True), + ] + ) def forward(self, x): x = self.quantize(x) x = self.layers(x) x = self.dequantize(x) return x + def get_batched_quantized_single_layer(): model = QuantizedBatchedInputSingleLayer() model.eval() @@ -118,10 +134,12 @@ def get_batched_quantized_single_layer(): torch.quantization.convert(model, inplace=True) return model + @register_test_case(module_factory=get_batched_quantized_single_layer) def QuantizedBatchedInputSingleLayer_basic(module, tu: TestUtils): module.forward(get_batched_quant_model_input()) + class QuantizedMLP(nn.Module): def __init__(self): super().__init__() @@ -135,16 +153,19 @@ class QuantizedMLP(nn.Module): self.dequantize = torch.quantization.DeQuantStub() @export - @annotate_args([ - None, - ([1, 16], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 16], torch.float32, True), + ] + ) def forward(self, x): x = self.quantize(x) x = self.layers(x) x = self.dequantize(x) return x + def get_quantized_mlp(): model = QuantizedMLP() model.eval() @@ -156,6 +177,7 @@ def get_quantized_mlp(): torch.quantization.convert(model, inplace=True) return model + @register_test_case(module_factory=get_quantized_mlp) def QuantizedMLP_basic(module, tu: TestUtils): module.forward(get_quant_model_input()) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 076dd4e45..9e0869dd9 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -11,15 +11,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== + class ReduceSumFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.sum(a) @@ -28,17 +31,21 @@ class ReduceSumFloatModule(torch.nn.Module): def ReduceSumFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceSumDtypeFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, a): return torch.sum(a, dtype=torch.float32) @@ -47,17 +54,21 @@ class ReduceSumDtypeFloatModule(torch.nn.Module): def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float64)) + # ============================================================================== + class ReduceSumElementTypeBoolModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.bool, True), + ] + ) def forward(self, a): return torch.sum(a) @@ -66,17 +77,21 @@ class ReduceSumElementTypeBoolModule(torch.nn.Module): def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool)) + # ============================================================================== + class ReduceProdFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.prod(a) @@ -85,54 +100,67 @@ class ReduceProdFloatModule(torch.nn.Module): def ReduceProdFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceProdDtypeFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, a): return torch.prod(a, dtype=torch.float32) - + + @register_test_case(module_factory=lambda: ReduceProdDtypeFloatModule()) def ReduceProdDtypeFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float64)) + # ============================================================================== + class ReduceProdElementTypeBoolModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.bool, True), + ] + ) def forward(self, a): return torch.prod(a) - + @register_test_case(module_factory=lambda: ReduceProdElementTypeBoolModule()) def ReduceProdElementTypeBoolModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool)) + # ============================================================================== + class ReduceAllFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.all(a) @@ -141,17 +169,21 @@ class ReduceAllFloatModule(torch.nn.Module): def ReduceAllFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceAllIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.ops.aten.all(a) @@ -160,17 +192,21 @@ class ReduceAllIntModule(torch.nn.Module): def ReduceAllIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32)) + # ============================================================================== + class ReduceAllBoolModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ] + ) def forward(self, a): return torch.ops.aten.all(a) @@ -179,17 +215,21 @@ class ReduceAllBoolModule(torch.nn.Module): def ReduceAllBoolModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, high=2).to(torch.bool)) + # ============================================================================== + class ReduceAnyFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.any(a) @@ -198,17 +238,21 @@ class ReduceAnyFloatModule(torch.nn.Module): def ReduceAnyFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceAnyIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.ops.aten.any(a) @@ -217,17 +261,21 @@ class ReduceAnyIntModule(torch.nn.Module): def ReduceAnyIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32)) + # ============================================================================== + class ReduceAnyBoolModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ] + ) def forward(self, a): return torch.ops.aten.any(a) @@ -236,17 +284,21 @@ class ReduceAnyBoolModule(torch.nn.Module): def ReduceAnyBoolModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, high=2).to(torch.bool)) + # ============================================================================== + class ReduceSumDimIntListFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.sum(a, (0, 1)) @@ -255,17 +307,21 @@ class ReduceSumDimIntListFloatModule(torch.nn.Module): def ReduceSumDimIntListFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceSumDimIntListDtypeFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, a): return torch.sum(a, (0, 1), dtype=torch.float32) @@ -274,17 +330,21 @@ class ReduceSumDimIntListDtypeFloatModule(torch.nn.Module): def ReduceSumDimIntListDtypeFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float64)) + # ============================================================================== + class ReduceSumDimIntListKeepDimFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.sum(a, (1, 2), keepdim=True) @@ -293,36 +353,46 @@ class ReduceSumDimIntListKeepDimFloatModule(torch.nn.Module): def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceSumDimIntListKeepDimNegativeDimStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 12, 7, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 12, 7, 7], torch.float32, True), + ] + ) def forward(self, a): return torch.sum(a, dim=(-1), keepdim=True) -@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimNegativeDimStaticModule()) +@register_test_case( + module_factory=lambda: ReduceSumDimIntListKeepDimNegativeDimStaticModule() +) def ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 12, 7, 7)) + # ============================================================================== + class ReduceSumDimIntListEmptyDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.sum(a, dim=[]) @@ -331,17 +401,21 @@ class ReduceSumDimIntListEmptyDimModule(torch.nn.Module): def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ] + ) def forward(self, a): return torch.sum(a, dim=(-1), keepdim=False) @@ -350,17 +424,21 @@ class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module): def ReduceSumDimIntListElementTypeBoolModule_basic(module, tu: TestUtils): module.forward(tu.randint(1, 128, high=2).to(dtype=torch.bool)) + # ============================================================================== + class ReduceSumUnsignedIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.sum(a) @@ -369,17 +447,21 @@ class ReduceSumUnsignedIntModule(torch.nn.Module): def ReduceSumUnsignedIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, low=0, high=100)) + # ============================================================================== + class ReduceSumSignedIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.sum(a) @@ -388,17 +470,21 @@ class ReduceSumSignedIntModule(torch.nn.Module): def ReduceSumSignedIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, low=-100, high=100)) + # ============================================================================== + class ReduceSumDtypeIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.sum(a, dtype=torch.int64) @@ -407,17 +493,21 @@ class ReduceSumDtypeIntModule(torch.nn.Module): def ReduceSumDtypeIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32)) + # ============================================================================== + class ReduceProdUnsignedIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.prod(a) @@ -426,36 +516,44 @@ class ReduceProdUnsignedIntModule(torch.nn.Module): def ReduceProdUnsignedIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, low=0, high=100)) + # ============================================================================== + class ReduceProdSignedIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.prod(a) - + @register_test_case(module_factory=lambda: ReduceProdSignedIntModule()) def ReduceProdSignedIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, low=-100, high=100)) + # ============================================================================== + class ReduceProdDtypeIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.prod(a, dtype=torch.int64) @@ -464,17 +562,21 @@ class ReduceProdDtypeIntModule(torch.nn.Module): def ReduceProdDtypeIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32)) + # ============================================================================== + class ReduceSumDimIntListIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.sum(a, (0, 1)) @@ -483,17 +585,21 @@ class ReduceSumDimIntListIntModule(torch.nn.Module): def ReduceSumDimIntListIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=100)) + # ============================================================================== + class ReduceSumDimIntListDtypeIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.sum(a, (0, 1), dtype=torch.int64) @@ -502,17 +608,21 @@ class ReduceSumDimIntListDtypeIntModule(torch.nn.Module): def ReduceSumDimIntListDtypeIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32)) + # ============================================================================== + class ReduceSumDimIntListKeepDimIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.sum(a, (1, 2), keepdim=True) @@ -524,15 +634,18 @@ def ReduceSumDimIntListKeepDimIntModule_basic(module, tu: TestUtils): # ============================================================================== + class ReduceProdDimIntFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.prod(a, 1, dtype=torch.float32) @@ -541,89 +654,113 @@ class ReduceProdDimIntFloatModule(torch.nn.Module): def ReduceProdDimIntFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float32)) + # ============================================================================== + class ReduceAllDimEmpty(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.all(a, dim=0, keepdim=False) + @register_test_case(module_factory=lambda: ReduceAllDimEmpty()) def ReduceAllDimEmpty_basic(module, tu: TestUtils): module.forward(torch.tensor([])) + # ============================================================================== + class ReduceAllDimFloat(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1,-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.all(a, dim=1, keepdim=True) + @register_test_case(module_factory=lambda: ReduceAllDimFloat()) def ReduceAllDimFloat_basic(module, tu: TestUtils): - module.forward(torch.tensor([[5.0,1e-6,-5.0],[0,5.0,0]])) + module.forward(torch.tensor([[5.0, 1e-6, -5.0], [0, 5.0, 0]])) + # ============================================================================== + class ReduceAllDimInt(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1,-1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) def forward(self, a): return torch.ops.aten.all(a, dim=1, keepdim=True) + @register_test_case(module_factory=lambda: ReduceAllDimInt()) def ReduceAllDimInt_basic(module, tu: TestUtils): - module.forward(torch.tensor([[5,-5,0],[5,1e10,5]]).to(torch.int32)) + module.forward(torch.tensor([[5, -5, 0], [5, 1e10, 5]]).to(torch.int32)) + # ============================================================================== + class ReduceAllDimBool(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1,-1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ] + ) def forward(self, a): return torch.ops.aten.all(a, dim=1, keepdim=False) + @register_test_case(module_factory=lambda: ReduceAllDimBool()) def ReduceAllDimBool_basic(module, tu: TestUtils): module.forward(torch.tensor([[True, False, True], [True, True, True]])) + # ============================================================================== + class ReduceMaxAlongDim(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, a): return torch.ops.aten.max(a, 1)[0] @@ -632,17 +769,21 @@ class ReduceMaxAlongDim(torch.nn.Module): def ReduceMaxAlongDim_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float64)) + # ============================================================================== + class ReduceMinAlongDim(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, a): return torch.ops.aten.min(a, 1)[0] @@ -651,15 +792,18 @@ class ReduceMinAlongDim(torch.nn.Module): def ReduceMinAlongDim_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float64)) + class ReduceMinAlongDimSignedInt(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.min(a, 1) @@ -668,17 +812,21 @@ class ReduceMinAlongDimSignedInt(torch.nn.Module): def ReduceMinAlongDimSignedInt_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, low=-100, high=100)) + # ============================================================================== + class ReduceMinAlongDimUnsignedInt(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.uint8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.uint8, True), + ] + ) def forward(self, a): return torch.ops.aten.min(a, 1) @@ -687,17 +835,21 @@ class ReduceMinAlongDimUnsignedInt(torch.nn.Module): def ReduceMinAlongDimUnsignedInt_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, low=-100, high=100).to(torch.uint8)) + # ============================================================================== + class ReduceMinAlongDimNegative(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, a): return torch.ops.aten.min(a, 1)[0] @@ -706,17 +858,21 @@ class ReduceMinAlongDimNegative(torch.nn.Module): def ReduceMinAlongDimNegative_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5, low=-10, high=10).to(torch.float64)) + # ============================================================================== + class ReduceMinKeepDim(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, a): return torch.ops.aten.min(a, 1, keepdim=True)[1] @@ -725,35 +881,44 @@ class ReduceMinKeepDim(torch.nn.Module): def ReduceMinKeepDim_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float64)) + # ============================================================================== + class ReduceMinKeepDimReturnBoth(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.min(a, 1, keepdim=True) + @register_test_case(module_factory=lambda: ReduceMinKeepDimReturnBoth()) def ReduceMinKeepDimReturnBoth_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5, low=-10, high=-5)) + # ============================================================================== + class ReduceMaxAlongDimSignedInt(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.max(a, 1) @@ -762,17 +927,21 @@ class ReduceMaxAlongDimSignedInt(torch.nn.Module): def ReduceMaxAlongDimSignedInt_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, low=-100, high=100)) + # ============================================================================== + class ReduceMaxAlongDimUnsignedInt(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.uint8, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.uint8, True), + ] + ) def forward(self, a): return torch.ops.aten.max(a, 1) @@ -781,17 +950,21 @@ class ReduceMaxAlongDimUnsignedInt(torch.nn.Module): def ReduceMaxAlongDimUnsignedInt_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, low=-100, high=100).to(torch.uint8)) + # ============================================================================== + class ReduceMaxAlongDimNegative(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, a): return torch.ops.aten.max(a, 1)[0] @@ -800,17 +973,21 @@ class ReduceMaxAlongDimNegative(torch.nn.Module): def ReduceMaxAlongDimNegative_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5, low=-10, high=10).to(torch.float64)) + # ============================================================================== + class ReduceMaxKeepDim(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, a): return torch.ops.aten.max(a, 1, keepdim=True)[1] @@ -819,252 +996,320 @@ class ReduceMaxKeepDim(torch.nn.Module): def ReduceMaxKeepDim_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float64)) + # ============================================================================== + class ReduceMaxKeepDimReturnBoth(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.max(a, 1, keepdim=True) + @register_test_case(module_factory=lambda: ReduceMaxKeepDimReturnBoth()) def ReduceMaxKeepDimReturnBoth_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5, low=-10, high=-5)) + # ============================================================================== + class ReduceMaxAllDims(torch.nn.Module): + def __init__(self): + super().__init__() - def __init__(self): - super().__init__() + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.max(a) - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) - def forward(self, a): - return torch.ops.aten.max(a) @register_test_case(module_factory=lambda: ReduceMaxAllDims()) def ReduceMaxAllDims_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5, low=-10, high=-5)) + # ============================================================================== + class ReduceMaxNegativeDim(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.max(a, -1, keepdim=True) + @register_test_case(module_factory=lambda: ReduceMaxNegativeDim()) def ReduceMaxNegativeDim_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceMaxFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.max(a) + @register_test_case(module_factory=lambda: ReduceMaxFloatModule()) def ReduceMaxFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceMaxSignedIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.max(a) + @register_test_case(module_factory=lambda: ReduceMaxSignedIntModule()) def ReduceMaxSignedIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, low=-100, high=100)) + # ============================================================================== + class ReduceMaxUnsignedIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.max(a) + @register_test_case(module_factory=lambda: ReduceMaxUnsignedIntModule()) def ReduceMaxUnsignedIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=100)) + # ============================================================================== + class ReduceAmaxSingleDim(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.amax(a, 1) + @register_test_case(module_factory=lambda: ReduceAmaxSingleDim()) def ReduceAmaxSingleDim_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5, high=100)) + # ============================================================================== + class ReduceAmaxMultiDim(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.amax(a, (0, 2)) + @register_test_case(module_factory=lambda: ReduceAmaxMultiDim()) def ReduceAmaxMultiDim_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5, high=100)) + # ============================================================================== + class ReduceAmaxOutOfOrderDim(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.amax(a, (2, 1, 3)) + @register_test_case(module_factory=lambda: ReduceAmaxOutOfOrderDim()) def ReduceAmaxOutOfOrderDim_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5, 6, high=100)) + # ============================================================================== + class ReduceAmaxKeepDim(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.amax(a, (0, 2), keepdim=True) + @register_test_case(module_factory=lambda: ReduceAmaxKeepDim()) def ReduceAmaxKeepDim_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5, high=100)) + # ============================================================================== + class ReduceMinFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.min(a) + + @register_test_case(module_factory=lambda: ReduceMinFloatModule()) def ReduceMinFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceMinSignedIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.min(a) + @register_test_case(module_factory=lambda: ReduceMinSignedIntModule()) def ReduceMinSignedIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, low=-100, high=100)) + # ============================================================================== + class ReduceMinUnsignedIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.min(a) + @register_test_case(module_factory=lambda: ReduceMinUnsignedIntModule()) def ReduceMinUnsignedIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=100)) + # ============================================================================== + class ArgminModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.argmin(a) @@ -1073,18 +1318,21 @@ class ArgminModule(torch.nn.Module): def ArgminModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + # ============================================================================== + class ArgminIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.argmin(a) @@ -1093,60 +1341,74 @@ class ArgminIntModule(torch.nn.Module): def ArgminIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-100, high=100)) + @register_test_case(module_factory=lambda: ArgminIntModule()) def ArgminIntModule_multiple_mins(module, tu: TestUtils): # To cover the special case that the minimal value occurs more than once. # The pytorch convention is here to consider the first occurence as the argmin. - module.forward(torch.full((3,4), tu.randint(1).item(), dtype=torch.int64)) + module.forward(torch.full((3, 4), tu.randint(1).item(), dtype=torch.int64)) + # ============================================================================== + class ArgminWithDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.argmin(a, dim=1) + @register_test_case(module_factory=lambda: ArgminWithDimModule()) def ArgminModule_with_dim(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ArgminKeepDimsModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.argmin(a, 0, True) + @register_test_case(module_factory=lambda: ArgminKeepDimsModule()) def ArgminModule_keepDim(module, tu: TestUtils): module.forward(tu.rand(4, 6)) + # ============================================================================== + class ArgmaxModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.argmax(a) @@ -1155,18 +1417,21 @@ class ArgmaxModule(torch.nn.Module): def ArgmaxModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + # ============================================================================== + class ArgmaxIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.argmax(a) @@ -1175,540 +1440,693 @@ class ArgmaxIntModule(torch.nn.Module): def ArgmaxIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-100, high=100)) + @register_test_case(module_factory=lambda: ArgmaxIntModule()) def ArgmaxIntModule_multiple_maxs(module, tu: TestUtils): # To cover the special case that the maximal value occurs more than once. # The pytorch convention is here to consider the first occurence as the argmax. - module.forward(torch.full((3,4), tu.randint(1).item(), dtype=torch.int64)) + module.forward(torch.full((3, 4), tu.randint(1).item(), dtype=torch.int64)) + # ============================================================================== + class ArgmaxWithDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.argmax(a, dim=1) + @register_test_case(module_factory=lambda: ArgmaxWithDimModule()) def ArgmaxModule_with_dim(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ArgmaxKeepDimsModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.argmax(a, 0, True) + @register_test_case(module_factory=lambda: ArgmaxKeepDimsModule()) def ArgmaxModule_keepDim(module, tu: TestUtils): module.forward(tu.rand(4, 6)) + # ============================================================================== + class ReduceL1NormModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.linalg.vector_norm(a, dim=0, ord=1) + @register_test_case(module_factory=lambda: ReduceL1NormModule()) def ReduceL1NormModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceL1NormWithDTypeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.linalg.vector_norm(a, dim=0, ord=1, dtype=torch.float64) + @register_test_case(module_factory=lambda: ReduceL1NormWithDTypeModule()) def ReduceL1NormWithDTypeModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float32)) + # ============================================================================== - + + class ReduceL1NormComplexModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.cfloat, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.cfloat, True), + ] + ) def forward(self, a): return torch.linalg.vector_norm(a, dim=0, ord=1) + @register_test_case(module_factory=lambda: ReduceL1NormComplexModule()) def ReduceL1NormComplexModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.cfloat)) + # ============================================================================== + class ReduceL2NormModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.linalg.vector_norm(a, dim=0) + @register_test_case(module_factory=lambda: ReduceL2NormModule()) def ReduceL2NormModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== - + + class ReduceL2NormComplexModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.cdouble, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.cdouble, True), + ] + ) def forward(self, a): return torch.linalg.vector_norm(a, dim=0) + @register_test_case(module_factory=lambda: ReduceL2NormComplexModule()) def ReduceL2NormComplexModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.cdouble)) + # ============================================================================== + class ReduceLN3NormModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.linalg.vector_norm(a, dim=0, ord=-3) + @register_test_case(module_factory=lambda: ReduceLN3NormModule()) def ReduceLN3NormModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceL3NormAllDimsModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.linalg.vector_norm(a, dim=None, ord=3) + @register_test_case(module_factory=lambda: ReduceL3NormAllDimsModule()) def ReduceL3NormAllDimsModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceL3NormKeepDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.linalg.vector_norm(a, keepdim=True, ord=3) + @register_test_case(module_factory=lambda: ReduceL3NormKeepDimModule()) def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== - + + class ReduceL3NormKeepDimComplexModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.complex128, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.complex128, True), + ] + ) def forward(self, a): return torch.linalg.vector_norm(a, keepdim=True, ord=3) + @register_test_case(module_factory=lambda: ReduceL3NormKeepDimComplexModule()) def ReduceL3NormKeepDimComplexModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.complex128)) + # ============================================================================== + class NormScalarModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.p = 3.0 @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.norm(a, self.p) + @register_test_case(module_factory=lambda: NormScalarModule()) def NormScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== - + + class NormScalarComplexModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.p = 3.0 @export - @annotate_args([ - None, - ([-1, -1, -1], torch.complex64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.complex64, True), + ] + ) def forward(self, a): return torch.ops.aten.norm(a, self.p) + @register_test_case(module_factory=lambda: NormScalarComplexModule()) def NormScalarComplexModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.complex64)) + # ============================================================================== + class NormScalarOptDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.p = 3.0 @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.norm(a, self.p, dim=[0, 1], keepdim=False) + @register_test_case(module_factory=lambda: NormScalarOptDimModule()) def NormScalarOptDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class NormScalarOptDimKeepDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.p = 3.0 @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.norm(a, self.p, dim=[0, 1], keepdim=True) + @register_test_case(module_factory=lambda: NormScalarOptDimKeepDimModule()) def NormScalarOptDimKeepDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== - + + class NormScalarOptDimKeepDimComplexModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.p = 3.0 @export - @annotate_args([ - None, - ([-1, -1, -1], torch.cfloat, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.cfloat, True), + ] + ) def forward(self, a): return torch.ops.aten.norm(a, self.p, dim=[0, 1], keepdim=True) + @register_test_case(module_factory=lambda: NormScalarOptDimKeepDimComplexModule()) def NormScalarOptDimKeepDimComplexModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.cfloat)) + # ============================================================================== + class ReduceFrobeniusNormModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=False) + @register_test_case(module_factory=lambda: ReduceFrobeniusNormModule()) def ReduceFrobeniusNormModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ReduceFrobeniusNormKeepDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=True) + @register_test_case(module_factory=lambda: ReduceFrobeniusNormKeepDimModule()) def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== - + + class ReduceFrobeniusNormComplexModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.cdouble, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.cdouble, True), + ] + ) def forward(self, a): return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=False) + @register_test_case(module_factory=lambda: ReduceFrobeniusNormComplexModule()) def ReduceFrobeniusNormComplexModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.cdouble)) + # ============================================================================== + class LinalgVectorNormModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.linalg_vector_norm(a, ord=3.0, dim=[0, 1], keepdim=False) + @register_test_case(module_factory=lambda: LinalgVectorNormModule()) def LinalgVectorNormModule_basic(module, tu: TestUtils): module.forward(torch.rand(3, 4, 5)) + # ============================================================================== + class LinalgVectorNormKeepDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.linalg_vector_norm(a, ord=3.0, dim=[0, 1], keepdim=True) + @register_test_case(module_factory=lambda: LinalgVectorNormKeepDimModule()) def LinalgVectorNormKeepDimModule_basic(module, tu: TestUtils): module.forward(torch.rand(3, 4, 5)) + # ============================================================================== - + + class LinalgVectorNormComplexModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.complex128, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.complex128, True), + ] + ) def forward(self, a): return torch.ops.aten.linalg_vector_norm(a, ord=3.0, dim=[0, 1], keepdim=False) + @register_test_case(module_factory=lambda: LinalgVectorNormComplexModule()) def LinalgVectorNormComplexModule_basic(module, tu: TestUtils): module.forward(torch.rand(3, 4, 5).to(torch.complex128)) + # ============================================================================== + class LinalgNormModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.linalg_norm(a, ord=None, dim=[0], keepdim=False) + @register_test_case(module_factory=lambda: LinalgNormModule()) def LinalgNormModule_basic(module, tu: TestUtils): module.forward(torch.rand(3, 4, 5)) + # ============================================================================== + class LinalgNormKeepDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.linalg_norm(a, ord=None, dim=[0], keepdim=True) + @register_test_case(module_factory=lambda: LinalgNormKeepDimModule()) def LinalgNormKeepDimModule_basic(module, tu: TestUtils): module.forward(torch.rand(3, 4, 5)) + # ============================================================================== - + + class LinalgNormKeepDimComplexModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.complex64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.complex64, True), + ] + ) def forward(self, a): return torch.ops.aten.linalg_norm(a, ord=None, dim=[0], keepdim=True) + @register_test_case(module_factory=lambda: LinalgNormKeepDimComplexModule()) def LinalgNormKeepDimComplexModule_basic(module, tu: TestUtils): module.forward(torch.rand(3, 4, 5).to(torch.complex64)) + # ============================================================================== + class MseLossNoReductionModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1 , -1], torch.float32, True), - ([-1 , -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.mse_loss(x, y, reduction=0) + @register_test_case(module_factory=lambda: MseLossNoReductionModule()) def MseLossNoReductionModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4), tu.rand(2, 4)) + # ============================================================================== + class MseLossMeanReductionModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1 , -1], torch.float32, True), - ([-1 , -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.mse_loss(x, y, reduction=1) + @register_test_case(module_factory=lambda: MseLossMeanReductionModule()) def MseLossMeanReductionModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4), tu.rand(2, 4)) + # ============================================================================== + class MseLossSumReductionWithDifferentElemTypeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1 , -1], torch.float32, True), - ([-1 , -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.mse_loss(x, y, reduction=2) -@register_test_case(module_factory=lambda: MseLossSumReductionWithDifferentElemTypeModule()) + +@register_test_case( + module_factory=lambda: MseLossSumReductionWithDifferentElemTypeModule() +) def MseLossSumReductionWithDifferentElemTypeModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4), tu.rand(2, 4).to(torch.float64)) + # ============================================================================== + class CrossEntropyLossModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1 , -1], torch.float32, True), - ([-1, ], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ( + [ + -1, + ], + torch.int64, + True, + ), + ] + ) def forward(self, input, target): return torch.ops.aten.cross_entropy_loss(input, target) + @register_test_case(module_factory=lambda: CrossEntropyLossModule()) def CrossEntropyLossModule_basic(module, tu: TestUtils): module.forward(tu.rand(8, 2), tu.randint(8, high=2)) @@ -1719,68 +2137,89 @@ class CrossEntropyLossNoReductionModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1 , -1], torch.float32, True), - ([-1, ], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ( + [ + -1, + ], + torch.int64, + True, + ), + ] + ) def forward(self, input, target): return torch.ops.aten.cross_entropy_loss(input, target, reduction=0) + @register_test_case(module_factory=lambda: CrossEntropyLossNoReductionModule()) def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils): module.forward(tu.rand(8, 2), tu.randint(8, high=2)) + # ============================================================================== + class TraceModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten.trace(a) + @register_test_case(module_factory=lambda: TraceModule()) def TraceModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3)) + @register_test_case(module_factory=lambda: TraceModule()) def TraceModule_nonsquare(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + @register_test_case(module_factory=lambda: TraceModule()) def TraceModule_empty(module, tu: TestUtils): - module.forward(torch.empty(0,0)) + module.forward(torch.empty(0, 0)) + # ============================================================================== + class TraceIntModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, a): return torch.ops.aten.trace(a) + @register_test_case(module_factory=lambda: TraceIntModule()) def TraceSignedIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(2, 2, low=-10, high=10)) + @register_test_case(module_factory=lambda: TraceIntModule()) def TraceUnsignedIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(2, 2, low=0, high=10)) + @register_test_case(module_factory=lambda: TraceIntModule()) def TraceUnsignedIntModule_empty(module, tu: TestUtils): module.forward(tu.randint(0, 0, low=0, high=10)) - diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 8aa3e2c1f..7b569529b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -10,114 +10,136 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== + class ViewExpandModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([6, 4], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([6, 4], torch.float32, True), + ] + ) def forward(self, a): return a.view(2, 3, 4) + @register_test_case(module_factory=lambda: ViewExpandModule()) def ViewExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 4)) + # ============================================================================== + class ViewExpandOnesModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([1], torch.float32, True), + ] + ) def forward(self, a): return a.view(1, 1, 1, 1, 1) + @register_test_case(module_factory=lambda: ViewExpandOnesModule()) def ViewExpandOnesModule_basic(module, tu: TestUtils): module.forward(tu.rand(1)) + # ============================================================================== + class ViewExpandOnesBeforeAndAfterModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 1, 16, 1, 1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([2, 1, 16, 1, 1], torch.float32, True), + ] + ) def forward(self, a): return a.view(1, 2, 1, 16, 1, 1, 1, 1) + @register_test_case(module_factory=lambda: ViewExpandOnesBeforeAndAfterModule()) def ViewExpandOnesBeforeAndAfterModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 1, 16, 1, 1)) + # ============================================================================== + class ViewExpandOnesMiddleModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 1, 2], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([3, 1, 2], torch.float32, True), + ] + ) def forward(self, a): return a.view(3, 1, 1, 1, 1, 2) + @register_test_case(module_factory=lambda: ViewExpandOnesMiddleModule()) def ViewExpandOnesMiddleModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 2)) # ============================================================================== + class ViewCollapseOnesMiddleModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 1, 1, 1, 1, 2], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([3, 1, 1, 1, 1, 2], torch.float32, True), + ] + ) def forward(self, a): return a.view(3, 1, 2) + @register_test_case(module_factory=lambda: ViewCollapseOnesMiddleModule()) def ViewCollapseOnesMiddleModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 1, 1, 1, 2)) + # ============================================================================== + class ViewDynamicExpandModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, 30, 384], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, 30, 384], torch.float32, True), + ] + ) def forward(self, a): return a.view(2, 4, 5, 6, 12, 32) + @register_test_case(module_factory=lambda: ViewDynamicExpandModule()) def ViewDynamicExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 30, 384)) @@ -131,35 +153,29 @@ class SplitDimStaticModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([12], torch.float32, True)]) - + @annotate_args([None, ([12], torch.float32, True)]) def forward(self, a): return torch.ops.prims.split_dim(a, 0, 4) -@register_test_case( - module_factory=lambda: SplitDimStaticModule()) + +@register_test_case(module_factory=lambda: SplitDimStaticModule()) def SplitDimStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(12)) + class SplitDimDynamicModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True)]) - + @annotate_args([None, ([-1, -1], torch.float32, True)]) def forward(self, a): return torch.ops.prims.split_dim(a, 0, 3) -@register_test_case( - module_factory=lambda: SplitDimDynamicModule()) -def SplitDimDynamicModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,5)) +@register_test_case(module_factory=lambda: SplitDimDynamicModule()) +def SplitDimDynamicModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 5)) # ============================================================================== @@ -169,18 +185,15 @@ class CollapseAllDimensionsModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([2,2,2,2], torch.float32, True)]) - + @annotate_args([None, ([2, 2, 2, 2], torch.float32, True)]) def forward(self, a): return torch.ops.prims.collapse(a, 0, 3) -@register_test_case( - module_factory=lambda: CollapseAllDimensionsModule()) +@register_test_case(module_factory=lambda: CollapseAllDimensionsModule()) def CollapseAllDimensionsModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2,2,2,2)) + module.forward(tu.rand(2, 2, 2, 2)) + # ============================================================================== # @@ -189,18 +202,16 @@ class CollapseRank1DynamicModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True)]) - + @annotate_args([None, ([-1], torch.float32, True)]) def forward(self, a): return torch.ops.prims.collapse(a, 0, 0) -@register_test_case( - module_factory=lambda: CollapseRank1DynamicModule()) + +@register_test_case(module_factory=lambda: CollapseRank1DynamicModule()) def CollapseRank1DynamicModule_basic(module, tu: TestUtils): module.forward(tu.rand(5)) + # ============================================================================== # class CollapseStaticModule(torch.nn.Module): @@ -208,18 +219,15 @@ class CollapseStaticModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([2,3,4], torch.float32, True)]) - + @annotate_args([None, ([2, 3, 4], torch.float32, True)]) def forward(self, a): return torch.ops.prims.collapse(a, 1, 2) -@register_test_case( - module_factory=lambda: CollapseStaticModule()) +@register_test_case(module_factory=lambda: CollapseStaticModule()) def CollapseStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2,3,4)) + module.forward(tu.rand(2, 3, 4)) + # ============================================================================== # @@ -228,572 +236,684 @@ class CollapsePartialDynamicModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1,-1,4,5], torch.float32, True)]) - + @annotate_args([None, ([-1, -1, 4, 5], torch.float32, True)]) def forward(self, a): return torch.ops.prims.collapse(a, 1, 2) -@register_test_case( - module_factory=lambda: CollapsePartialDynamicModule()) +@register_test_case(module_factory=lambda: CollapsePartialDynamicModule()) def CollapsePartialDynamicModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2,3,4,5)) + module.forward(tu.rand(2, 3, 4, 5)) + class CollapseFullDynamicModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1,-1,-1], torch.float32, True)]) - + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) def forward(self, a): - return torch.ops.prims.collapse(a, 0,1) + return torch.ops.prims.collapse(a, 0, 1) -@register_test_case( - module_factory=lambda: CollapseFullDynamicModule()) +@register_test_case(module_factory=lambda: CollapseFullDynamicModule()) def CollapseFullDynamicModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2,3,5)) - + module.forward(tu.rand(2, 3, 5)) # ============================================================================== + class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return a.view(a.size(0), a.size(1), 12, 32) + @register_test_case(module_factory=lambda: ViewDynamicExpandWithAtenSizeIntModule()) def ViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 384)) + # ============================================================================== + class ViewCollapseModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return a.view(8) + @register_test_case(module_factory=lambda: ViewCollapseModule()) def ViewCollapseModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4)) + # ============================================================================== + class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1, -1], torch.float32, True), - ([], torch.int64, True), - ([], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1, -1], torch.float32, True), + ([], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, a, b, c): return a.view(a.size(0), int(b), int(c), a.size(3), 384) + @register_test_case(module_factory=lambda: ViewCollapseDynamicWithAtenSizeIntModule()) def ViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5)) + # ============================================================================== + class ViewExpandCollapseWithOnesModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 4, 8, 8], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([2, 4, 8, 8], torch.float32, True), + ] + ) def forward(self, a): return a.view(2, 1, 1, 4, 64) + @register_test_case(module_factory=lambda: ViewExpandCollapseWithOnesModule()) def ViewExpandCollapseWithOnesModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 8, 8)) + # ============================================================================== + class ViewExpandCollapseModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 4, 8, 16, 4], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([2, 4, 8, 16, 4], torch.float32, True), + ] + ) def forward(self, a): return a.view(8, 2, 4, 16, 2, 2) + @register_test_case(module_factory=lambda: ViewExpandCollapseModule()) def ViewExpandCollapseModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 8, 16, 4)) + # ============================================================================== + class ViewDynamicExpandCollapseModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, 4, -1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, 4, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return a.view(2, 1, 4, 64) + @register_test_case(module_factory=lambda: ViewDynamicExpandCollapseModule()) def ViewDynamicExpandCollapseModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 8, 8)) + # ============================================================================== + class ViewDynamicExpandCollapseWithParallelUnknownDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, -1, 5], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([2, 3, -1, 5], torch.float32, True), + ] + ) def forward(self, a): return a.view(2, -1, 6) -@register_test_case(module_factory=lambda: ViewDynamicExpandCollapseWithParallelUnknownDimModule()) + +@register_test_case( + module_factory=lambda: ViewDynamicExpandCollapseWithParallelUnknownDimModule() +) def ViewDynamicExpandCollapseWithParallelUnknownDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 5)) + # ============================================================================== + class ViewDynamicExpandCollapseWithAtenIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return a.view(2, 1, a.size(1), 64) + @register_test_case(module_factory=lambda: ViewDynamicExpandCollapseWithAtenIntModule()) def ViewDynamicExpandCollapseWithAtenIntModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 8, 8)) + # ============================================================================== + class ViewTwoToThreeStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 2], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([3, 2], torch.float32, True), + ] + ) def forward(self, a): return a.view(2, 3) + @register_test_case(module_factory=lambda: ViewTwoToThreeStaticModule()) def ViewTwoToThreeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2)) + # ============================================================================== + class ViewTwoFiveThreeStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 5, 2], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([3, 5, 2], torch.float32, True), + ] + ) def forward(self, a): return a.view(2, 5, 3) + @register_test_case(module_factory=lambda: ViewTwoFiveThreeStaticModule()) def ViewTwoFiveThreeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5, 2)) + # ============================================================================== + class ViewOffsetTestStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 2, 2, 5, 6], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([2, 3, 2, 2, 5, 6], torch.float32, True), + ] + ) def forward(self, a): return a.view(2, 3, 4, 6, 5) + @register_test_case(module_factory=lambda: ViewOffsetTestStaticModule()) def ViewOffsetTestStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 2, 2, 5, 6)) + # ============================================================================== + class ViewOffsetBackwardTestStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 4, 5, 6], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([2, 3, 4, 5, 6], torch.float32, True), + ] + ) def forward(self, a): return a.view(2, 3, 2, 2, 6, 5) + @register_test_case(module_factory=lambda: ViewOffsetBackwardTestStaticModule()) def ViewOffsetBackwardTestStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 5, 6)) + # ============================================================================== + class View1DFoldModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, a): return a.view(-1) + @register_test_case(module_factory=lambda: View1DFoldModule()) def View1DFoldModule_basic(module, tu: TestUtils): module.forward(tu.rand(32)) + # ============================================================================== + class ViewCollapseInferredDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 4], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([2, 3, 4], torch.float32, True), + ] + ) def forward(self, a): return a.view(-1, 4) + @register_test_case(module_factory=lambda: ViewCollapseInferredDimModule()) def ViewCollapseInferredDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + # ============================================================================== + class ViewExpandInferredDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 6], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([2, 6], torch.float32, True), + ] + ) def forward(self, a): return a.view(3, -1, 2) + @register_test_case(module_factory=lambda: ViewExpandInferredDimModule()) def ViewExpandInferredDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 6)) + # ============================================================================== + class ViewExpandDynamicDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, -1, 128], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([1, -1, 128], torch.float32, True), + ] + ) def forward(self, a): return a.view(16, 1, 128) + @register_test_case(module_factory=lambda: ViewExpandDynamicDimModule()) def ViewExpandDynamicDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 16, 128)) + # ============================================================================== + class ViewFlattenAndExpandModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return a.view(a.size(0), a.size(1)) + @register_test_case(module_factory=lambda: ViewFlattenAndExpandModule()) def ViewFlattenAndExpandModule_basic(module, tu: TestUtils): - module.forward(tu.rand(64,128)) + module.forward(tu.rand(64, 128)) + # ============================================================================== + class ViewSizeFromOtherTensor(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, -1], torch.float32, True), - ([1, -1, 10], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([1, -1], torch.float32, True), + ([1, -1, 10], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.view(y, (torch.ops.aten.size(x, 1), 10)) + @register_test_case(module_factory=lambda: ViewSizeFromOtherTensor()) def ViewSizeFromOtherTensor_basic(module, tu: TestUtils): module.forward(tu.rand(1, 7), tu.rand(1, 7, 10)) + # ============================================================================== + class UnsafeViewExpandModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([6, 4], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([6, 4], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten._unsafe_view(a, [2, 3, 4]) + @register_test_case(module_factory=lambda: UnsafeViewExpandModule()) def UnsafeViewExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 4)) + # ============================================================================== + class UnsafeViewDynamicExpandModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, 30, 384], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, 30, 384], torch.float32, True), + ] + ) def forward(self, a): - return torch.ops.aten._unsafe_view(a,[2, 4, 5, 6, 12, 32]) + return torch.ops.aten._unsafe_view(a, [2, 4, 5, 6, 12, 32]) + @register_test_case(module_factory=lambda: UnsafeViewDynamicExpandModule()) def UnsafeViewDynamicExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 30, 384)) + # ============================================================================== + class UnsafeViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten._unsafe_view(a, [a.size(0), a.size(1), 12, 32]) -@register_test_case(module_factory=lambda: UnsafeViewDynamicExpandWithAtenSizeIntModule()) + +@register_test_case( + module_factory=lambda: UnsafeViewDynamicExpandWithAtenSizeIntModule() +) def UnsafeViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 384)) + # ============================================================================== + class UnsafeViewCollapseModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): - return torch.ops.aten._unsafe_view(a,[8]) + return torch.ops.aten._unsafe_view(a, [8]) + @register_test_case(module_factory=lambda: UnsafeViewCollapseModule()) def UnsafeViewCollapseModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4)) + # ============================================================================== + class UnsafeViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1, -1, -1], torch.float32, True), - ([], torch.int64, True), - ([], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1, -1], torch.float32, True), + ([], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, a, b, c): - return torch.ops.aten._unsafe_view(a, [a.size(0), int(b), int(c), a.size(3), 384]) + return torch.ops.aten._unsafe_view( + a, [a.size(0), int(b), int(c), a.size(3), 384] + ) -@register_test_case(module_factory=lambda: UnsafeViewCollapseDynamicWithAtenSizeIntModule()) + +@register_test_case( + module_factory=lambda: UnsafeViewCollapseDynamicWithAtenSizeIntModule() +) def UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5)) + # ============================================================================== + class UnsafeView1DFoldModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten._unsafe_view(a, [-1]) + @register_test_case(module_factory=lambda: UnsafeView1DFoldModule()) def UnsafeView1DFoldModule_basic(module, tu: TestUtils): module.forward(tu.rand(32)) + # ============================================================================== + class ReshapeAsModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @export - @annotate_args([ - None, - ([4, 3], torch.float32, True), - ([2, 6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4, 3], torch.float32, True), + ([2, 6], torch.float32, True), + ] + ) def forward(self, a, b): return torch.ops.aten.reshape_as(a, b) + @register_test_case(module_factory=lambda: ReshapeAsModule()) def ReshapeAsModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 3), tu.rand(2, 6)) + # ============================================================================== + class ReshapeExpandModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, a): return a.reshape(12, 32) + @register_test_case(module_factory=lambda: ReshapeExpandModule()) def ReshapeExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(384)) + # ============================================================================== + class ReshapeCollapseModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.reshape(a, (-1,)) + @register_test_case(module_factory=lambda: ReshapeCollapseModule()) def ReshapeCollapseModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4)) + # ============================================================================== + class ViewNoChange1dModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, a): return a.view(6) + @register_test_case(module_factory=lambda: ViewNoChange1dModule()) def ViewNoChange1dModule_basic(module, tu: TestUtils): module.forward(tu.rand(6)) @@ -804,14 +924,16 @@ class ViewNoChange2dModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return a.view(5, 6) + @register_test_case(module_factory=lambda: ViewNoChange2dModule()) def ViewNoChange2dModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 6)) @@ -822,14 +944,16 @@ class ViewNoChange3dModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, a): return a.view(4, 5, 6) + @register_test_case(module_factory=lambda: ViewNoChange3dModule()) def ViewNoChange3dModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6)) @@ -840,139 +964,168 @@ class ViewNoChangeStaticModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([4, 5, 6], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ] + ) def forward(self, a): return a.view(4, 5, 6) + @register_test_case(module_factory=lambda: ViewNoChangeStaticModule()) def ViewNoChangeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6)) + class ViewNegativeStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 128], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([1, 128], torch.float32, True), + ] + ) def forward(self, a): return a.view(-1, 128) + @register_test_case(module_factory=lambda: ViewNegativeStaticModule()) def ViewNegativeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 128)) + class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, a): return a.view(a.size(0), 1, 1, 1) + @register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule()) def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils): module.forward(tu.rand(128)) + class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, 1, 1, 1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, 1, 1, 1], torch.float32, True), + ] + ) def forward(self, a): return a.view(a.size(0)) + @register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule()) def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils): module.forward(tu.rand(128, 1, 1, 1)) + class ViewSizeDimLedByExpandedOnesModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, a): return a.view(1, 1, 1, a.size(0)) + @register_test_case(module_factory=lambda: ViewSizeDimLedByExpandedOnesModule()) def ViewSizeDimLedByExpandedOnesModule_basic(module, tu: TestUtils): module.forward(tu.rand(128)) + class ViewSizeDimLedByCollapsedOnesModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([1, 1, 1, -1], torch.float32, True), + ] + ) def forward(self, a): return a.view(a.size(3)) + @register_test_case(module_factory=lambda: ViewSizeDimLedByCollapsedOnesModule()) def ViewSizeDimLedByCollapsedOnesModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 1, 128)) + class ViewSizeDimLedAndFollowedByExpandedOnesModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, a): return a.view(1, 1, 1, a.size(0), 1, 1, 1) -@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByExpandedOnesModule()) + +@register_test_case( + module_factory=lambda: ViewSizeDimLedAndFollowedByExpandedOnesModule() +) def ViewSizeDimLedAndFollowedByExpandedOnesModule_basic(module, tu: TestUtils): module.forward(tu.rand(128)) + class ViewSizeDimLedAndFollowedByCollapsedOnesModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 1, -1, 1, 1, 1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([1, 1, 1, -1, 1, 1, 1], torch.float32, True), + ] + ) def forward(self, a): return a.view(a.size(3)) -@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByCollapsedOnesModule()) + +@register_test_case( + module_factory=lambda: ViewSizeDimLedAndFollowedByCollapsedOnesModule() +) def ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 1, 128, 1, 1, 1)) + # ============================================================================== + class ReshapeAliasExpandModule(torch.nn.Module): def __init__(self): super().__init__() @@ -980,14 +1133,16 @@ class ReshapeAliasExpandModule(torch.nn.Module): self.reshape_alias = torch.ops.aten._reshape_alias @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten._reshape_alias(a, size=(12, 32), stride=(32, 1)) + @register_test_case(module_factory=lambda: ReshapeAliasExpandModule()) def ReshapeAliasExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(384)) @@ -995,112 +1150,132 @@ def ReshapeAliasExpandModule_basic(module, tu: TestUtils): # ============================================================================== + class ReshapeDynamicModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return a.view(a.size(1), a.size(0)) + @register_test_case(module_factory=lambda: ReshapeDynamicModule()) def ReshapeDynamicModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3,4)) - + module.forward(tu.rand(3, 4)) # ============================================================================== + class ReshapeAliasCollapseModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.aten._reshape_alias(a, (8,), (1,)) + @register_test_case(module_factory=lambda: ReshapeAliasCollapseModule()) def ReshapeAliasCollapseModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4)) + # ============================================================================== + class UnflattenIntStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 24, 5], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([3, 24, 5], torch.float32, True), + ] + ) def forward(self, inputs): return torch.ops.aten.unflatten(inputs, 1, [2, 4, 3]) + @register_test_case(module_factory=lambda: UnflattenIntStaticModule()) def UnflattenIntStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 24, 5)) + class UnflattenIntNegativeOneDimStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, 12, 3], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([5, 12, 3], torch.float32, True), + ] + ) def forward(self, inputs): return torch.ops.aten.unflatten(inputs, -2, [2, 2, 3, 1, 1]) + @register_test_case(module_factory=lambda: UnflattenIntNegativeOneDimStaticModule()) def UnflattenIntNegativeOneDimStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 12, 3)) + class UnflattenIntNegativeOneSizeStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, 12, 3], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([5, 12, 3], torch.float32, True), + ] + ) def forward(self, inputs): return torch.ops.aten.unflatten(inputs, -2, [2, -1, 3, 1, 1]) + @register_test_case(module_factory=lambda: UnflattenIntNegativeOneSizeStaticModule()) def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 12, 3)) + # ============================================================================== + class EinsumStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 2, 4], torch.float32, True), - ([5, 4, 6], torch.float32, True), - ([3, 7, 6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 2, 4], torch.float32, True), + ([5, 4, 6], torch.float32, True), + ([3, 7, 6], torch.float32, True), + ] + ) def forward(self, tensor1, tensor2, tensor3): - return torch.ops.aten.einsum('bqe,ked,btd->bqtk', [tensor1, tensor2, tensor3]) + return torch.ops.aten.einsum("bqe,ked,btd->bqtk", [tensor1, tensor2, tensor3]) + @register_test_case(module_factory=lambda: EinsumStaticModule()) def EinsumStaticModule_basic(module, tu: TestUtils): @@ -1112,13 +1287,16 @@ class EinsumStaticFourDimensionModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 5, 6], torch.float32, True), - ([3, 7, 5, 6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 5, 6], torch.float32, True), + ([3, 7, 5, 6], torch.float32, True), + ] + ) def forward(self, tensor1, tensor2): - return torch.ops.aten.einsum('blhd,bshd->blhs', [tensor1, tensor2]) + return torch.ops.aten.einsum("blhd,bshd->blhs", [tensor1, tensor2]) + @register_test_case(module_factory=lambda: EinsumStaticFourDimensionModule()) def EinsumStaticFourDimensionModule_basic(module, tu: TestUtils): @@ -1130,49 +1308,62 @@ class EinsumStaticContractRhsModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([3, 4, 5], torch.float32, True), - ([4, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 5], torch.float32, True), + ([4, 5], torch.float32, True), + ] + ) def forward(self, tensor1, tensor2): - return torch.ops.aten.einsum('abc,bc->a', [tensor1, tensor2]) + return torch.ops.aten.einsum("abc,bc->a", [tensor1, tensor2]) + @register_test_case(module_factory=lambda: EinsumStaticContractRhsModule()) def EinsumStaticContractRhsModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(4, 5)) + class EinsumStaticWithEllipsisSlicingModule(torch.nn.Module): def __init__(self): super().__init__() - + @export - @annotate_args([ - None, - ([3, 4, 6], torch.float32, True), - ([3, 6, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4, 6], torch.float32, True), + ([3, 6, 5], torch.float32, True), + ] + ) def forward(self, tensor1, tensor2): - return torch.ops.aten.einsum('...mn,...nd->...md', [tensor1, tensor2]) - + return torch.ops.aten.einsum("...mn,...nd->...md", [tensor1, tensor2]) + + @register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingModule()) def EinsumStaticWithEllipsisSlicingModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 6), tu.rand(3, 6, 5)) + class EinsumStaticWithEllipsisSlicingAndBroadcastModule(torch.nn.Module): def __init__(self): super().__init__() - + @export - @annotate_args([ - None, - ([2, 6, 4, 5], torch.float32, True), - ([6, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 6, 4, 5], torch.float32, True), + ([6, 5], torch.float32, True), + ] + ) def forward(self, tensor1, tensor2): # should be abnd,bd -> abn - return torch.ops.aten.einsum('...nd,...d->...n', [tensor1, tensor2]) - -@register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingAndBroadcastModule()) + return torch.ops.aten.einsum("...nd,...d->...n", [tensor1, tensor2]) + + +@register_test_case( + module_factory=lambda: EinsumStaticWithEllipsisSlicingAndBroadcastModule() +) def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/return_types.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/return_types.py index 7cbfe45c8..b61d845af 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/return_types.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/return_types.py @@ -13,19 +13,20 @@ from torch_mlir_e2e_test.annotations import annotate_args, export class TestMultipleTensorReturn(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float64, True), - ([-1, -1], torch.int32, True), - ([-1, -1], torch.int64, True), - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float64, True), + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int64, True), + ([-1, -1], torch.bool, True), + ] + ) def forward(self, a, b, c, d, e): return a, b, c, d, e @@ -37,54 +38,56 @@ def TestMultipleTensorReturn_basic(module, tu: TestUtils): tu.rand(2, 3).to(torch.float64), tu.rand(2, 3).to(torch.int32), tu.rand(2, 3).to(torch.int64), - tu.rand(2, 3).to(torch.bool)) + tu.rand(2, 3).to(torch.bool), + ) class TestMultipleTensorAndPrimitiveTypesReturn(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int32, True), - ([-1, -1], torch.float64, True), - ([-1, -1], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.float64, True), + ([-1, -1], torch.bool, True), + ] + ) def forward(self, a, b, c): d = 1 e = 2.3 return a, b, c, d, e -@register_test_case( - module_factory=lambda: TestMultipleTensorAndPrimitiveTypesReturn()) +@register_test_case(module_factory=lambda: TestMultipleTensorAndPrimitiveTypesReturn()) def TestMultipleTensorAndPrimitiveTypesReturn_basic(module, tu: TestUtils): module.forward( tu.rand(3, 4).to(torch.int32), tu.rand(2, 3).to(torch.float64), - tu.rand(2, 3).to(torch.bool)) + tu.rand(2, 3).to(torch.bool), + ) + class TestF16Return(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float16, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float16, True), + ] + ) def forward(self, a): return a -@register_test_case( - module_factory=lambda: TestF16Return()) +@register_test_case(module_factory=lambda: TestF16Return()) def TestF16Return_basic(module, tu: TestUtils): - module.forward( - tu.rand(3, 4).to(torch.float16)) + module.forward(tu.rand(3, 4).to(torch.float16)) # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py index 2b8e186ff..b2d41a422 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py @@ -6,16 +6,13 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== -class RandModule(torch.nn.Module): +class RandModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1024, 512], torch.float, True) - ]) + @annotate_args([None, ([1024, 512], torch.float, True)]) def forward(self, x): size = x.size() a = torch.rand(size) @@ -26,34 +23,41 @@ class RandModule(torch.nn.Module): def RandModule_basic(module, tu: TestUtils): module.forward(tu.rand(1024, 512)) + # ============================================================================== -class UniformModule(torch.nn.Module): +class UniformModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, x, y, z): a = torch.ops.aten.uniform_(x, 1.0, 10.0) b = torch.ops.aten.uniform_(y, -20.0, -5.0) c = torch.ops.aten.uniform_(z, -15.0, 3.0) - std = torch.cat([ - torch.flatten(torch.std(a)), - torch.flatten(torch.std(b)), - torch.flatten(torch.std(c)) - ]) - mean = torch.cat([ - torch.flatten(torch.mean(a)), - torch.flatten(torch.mean(b)), - torch.flatten(torch.mean(c)) - ]) + std = torch.cat( + [ + torch.flatten(torch.std(a)), + torch.flatten(torch.std(b)), + torch.flatten(torch.std(c)), + ] + ) + mean = torch.cat( + [ + torch.flatten(torch.mean(a)), + torch.flatten(torch.mean(b)), + torch.flatten(torch.mean(c)), + ] + ) return std, mean @@ -62,36 +66,44 @@ def UniformModule_basic(module, tu: TestUtils): module.forward( tu.rand(256, 512, 12).double(), tu.rand(512, 1024, 12).double(), - tu.rand(512, 256, 12).double()) + tu.rand(512, 256, 12).double(), + ) + # ============================================================================== -class UniformStaticShapeModule(torch.nn.Module): +class UniformStaticShapeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([256, 512, 12], torch.float64, True), - ([512, 1024, 12], torch.float64, True), - ([512, 256, 12], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([256, 512, 12], torch.float64, True), + ([512, 1024, 12], torch.float64, True), + ([512, 256, 12], torch.float64, True), + ] + ) def forward(self, x, y, z): a = torch.ops.aten.uniform_(x, 1.0, 10.0) b = torch.ops.aten.uniform_(y, -20.0, -5.0) c = torch.ops.aten.uniform_(z, -15.0, 3.0) - std = torch.cat([ - torch.flatten(torch.std(a)), - torch.flatten(torch.std(b)), - torch.flatten(torch.std(c)) - ]) - mean = torch.cat([ - torch.flatten(torch.mean(a)), - torch.flatten(torch.mean(b)), - torch.flatten(torch.mean(c)) - ]) + std = torch.cat( + [ + torch.flatten(torch.std(a)), + torch.flatten(torch.std(b)), + torch.flatten(torch.std(c)), + ] + ) + mean = torch.cat( + [ + torch.flatten(torch.mean(a)), + torch.flatten(torch.mean(b)), + torch.flatten(torch.mean(c)), + ] + ) return std, mean @@ -100,12 +112,14 @@ def UniformStaticShapeModule_basic(module, tu: TestUtils): module.forward( tu.rand(256, 512, 12).double(), tu.rand(512, 1024, 12).double(), - tu.rand(512, 256, 12).double()) + tu.rand(512, 256, 12).double(), + ) + # ============================================================================== -class UniformNoCorrelationModule(torch.nn.Module): +class UniformNoCorrelationModule(torch.nn.Module): def __init__(self): super().__init__() @@ -119,10 +133,12 @@ class UniformNoCorrelationModule(torch.nn.Module): return cov[0, 1] / torch.sqrt(cov[0, 0] * cov[1, 1]) @export - @annotate_args([ - None, - ([1000], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([1000], torch.float64, True), + ] + ) def forward(self, x): # Correlation of two independent uniforms a = torch.ops.aten.uniform(x) @@ -145,73 +161,84 @@ class UniformNoCorrelationModule(torch.nn.Module): # than `atol + rtol * correlation = 1E-6`, which is too strict. # Instead, the correlations are explicitly required to be less than # 0.001. - return torch.where(torch.abs(corr_a_b) < 0.001, 1, 2), \ - torch.where(torch.abs(corr_major) < 0.001, 1, 2), \ - torch.where(torch.abs(corr_minor) < 0.001, 1, 2) + return ( + torch.where(torch.abs(corr_a_b) < 0.001, 1, 2), + torch.where(torch.abs(corr_major) < 0.001, 1, 2), + torch.where(torch.abs(corr_minor) < 0.001, 1, 2), + ) @register_test_case(module_factory=lambda: UniformNoCorrelationModule()) def UniformNoCorrelationModule_basic(module, tu: TestUtils): - module.forward( - tu.rand(1000).double()) + module.forward(tu.rand(1000).double()) + # ============================================================================== + class ExponentialModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, x): a = torch.ops.aten.exponential(x, 3.0) mean = torch.mean(a) std = torch.std(a) - return mean, std + return mean, std @register_test_case(module_factory=lambda: ExponentialModule()) def ExponentialModule_basic(module, tu: TestUtils): - module.forward( - tu.rand(512, 512, 16).double()) + module.forward(tu.rand(512, 512, 16).double()) + # ============================================================================== + class BernoulliModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, x): a = torch.bernoulli(x) mean = torch.mean(a) std = torch.std(a) - return mean, std + return mean, std @register_test_case(module_factory=lambda: BernoulliModule()) def BernoulliModule_basic(module, tu: TestUtils): - module.forward( - tu.rand(512, 512, 16).double()) + module.forward(tu.rand(512, 512, 16).double()) + # ============================================================================== + class BernoulliZerosModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float64, True), + ] + ) def forward(self, x): return torch.bernoulli(x) @@ -220,17 +247,21 @@ class BernoulliZerosModule(torch.nn.Module): def BernoulliZerosModule_basic(module, tu: TestUtils): module.forward(torch.zeros(4, 8).double()) + # ============================================================================== + class BernoulliOnesModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float64, True), + ] + ) def forward(self, x): return torch.bernoulli(x) @@ -239,106 +270,124 @@ class BernoulliOnesModule(torch.nn.Module): def BernoulliOnesModule_basic(module, tu: TestUtils): module.forward(torch.ones(4, 8).double()) + # ============================================================================== + class BernoulliFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, x, y): a = torch.ops.aten.bernoulli_(x, 0.4) b = torch.ops.aten.bernoulli_(y, 0.7) - mean = torch.cat([ - torch.flatten(torch.mean(a)), - torch.flatten(torch.mean(b)), - ]) - std = torch.cat([ - torch.flatten(torch.std(a)), - torch.flatten(torch.std(b)), - ]) - return mean, std + mean = torch.cat( + [ + torch.flatten(torch.mean(a)), + torch.flatten(torch.mean(b)), + ] + ) + std = torch.cat( + [ + torch.flatten(torch.std(a)), + torch.flatten(torch.std(b)), + ] + ) + return mean, std @register_test_case(module_factory=lambda: BernoulliFloatModule()) def BernoulliFloatModule_basic(module, tu: TestUtils): - module.forward( - tu.rand(512, 512, 16).double(), - tu.rand(512, 512, 16).double()) + module.forward(tu.rand(512, 512, 16).double(), tu.rand(512, 512, 16).double()) + # ============================================================================== + class BernoulliTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float64, True), - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float64, True), + ([-1, -1], torch.float64, True), + ] + ) def forward(self, x, px): a = torch.ops.aten.bernoulli_(x, px) mean = torch.mean(a) std = torch.std(a) - return mean, std + return mean, std @register_test_case(module_factory=lambda: BernoulliTensorModule()) def BernoulliTensorModule_basic(module, tu: TestUtils): - module.forward( - tu.rand(512, 512).double(), - tu.rand(512, 512).double()) + module.forward(tu.rand(512, 512).double(), tu.rand(512, 512).double()) + # ============================================================================== + class BernoulliPModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, x, y): a = torch.ops.aten.bernoulli(x, 0.4) b = torch.ops.aten.bernoulli(y, 0.7) - mean = torch.cat([ - torch.flatten(torch.mean(a)), - torch.flatten(torch.mean(b)), - ]) - std = torch.cat([ - torch.flatten(torch.std(a)), - torch.flatten(torch.std(b)), - ]) - return mean, std + mean = torch.cat( + [ + torch.flatten(torch.mean(a)), + torch.flatten(torch.mean(b)), + ] + ) + std = torch.cat( + [ + torch.flatten(torch.std(a)), + torch.flatten(torch.std(b)), + ] + ) + return mean, std @register_test_case(module_factory=lambda: BernoulliPModule()) def BernoulliPModule_basic(module, tu: TestUtils): - module.forward( - tu.rand(512, 512, 16).double(), - tu.rand(512, 512, 16).double()) + module.forward(tu.rand(512, 512, 16).double(), tu.rand(512, 512, 16).double()) + # ============================================================================== + class RandLikeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float64, True), + ] + ) def forward(self, x): a = torch.ops.aten.rand_like(x) mean = torch.mean(a) @@ -349,17 +398,21 @@ class RandLikeModule(torch.nn.Module): def RandLikeModule_basic(module, tu: TestUtils): module.forward(tu.rand(1024, 1024).double()) + # ============================================================================== + class RandLikeDtypeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float64, True), + ] + ) def forward(self, x): a = torch.ops.aten.rand_like(x, dtype=torch.float32) mean = torch.mean(a) @@ -370,16 +423,20 @@ class RandLikeDtypeModule(torch.nn.Module): def RandLikeDtypeModule_basic(module, tu: TestUtils): module.forward(tu.rand(1024, 1024).double()) + # ============================================================================== + class RandIntLowModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): a = torch.ops.aten.randint(low=1, high=1000, size=[1024, 1024]) mean = torch.mean(a.to(torch.float32)) @@ -390,18 +447,24 @@ class RandIntLowModule(torch.nn.Module): def RandIntLowModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class RandIntLowDtypeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): - a = torch.ops.aten.randint(low=1, high=1000, size=[128, 256, 512], dtype=torch.float64) + a = torch.ops.aten.randint( + low=1, high=1000, size=[128, 256, 512], dtype=torch.float64 + ) mean = torch.mean(a) return mean @@ -410,16 +473,20 @@ class RandIntLowDtypeModule(torch.nn.Module): def RandIntLowDtypeModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class RandIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): a = torch.ops.aten.randint(high=1000, size=[1024, 1024]) mean = torch.mean(a.to(torch.float32)) @@ -430,16 +497,20 @@ class RandIntModule(torch.nn.Module): def RandIntModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class RandIntDtypeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], dtype=torch.float64) mean = torch.mean(a.to(torch.float32)) @@ -450,16 +521,20 @@ class RandIntDtypeModule(torch.nn.Module): def RandIntDtypeModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class RandIntPinMemoryModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], pin_memory=False) mean = torch.mean(a.to(torch.float32)) @@ -470,17 +545,20 @@ class RandIntPinMemoryModule(torch.nn.Module): def RandIntPinMemoryModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== -class RandnModule(torch.nn.Module): +class RandnModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): a = torch.ops.aten.randn([4, 512, 1024]) std = torch.std(a.to(dtype=torch.float64)) @@ -496,18 +574,19 @@ def RandnModule_basic(module, tu: TestUtils): class RandnDtypeDeviceModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): - a = torch.ops.aten.randn([4, 512, 1024], - dtype=torch.float64, - device=torch.device("cpu")) + a = torch.ops.aten.randn( + [4, 512, 1024], dtype=torch.float64, device=torch.device("cpu") + ) std = torch.std(a) return std @@ -521,14 +600,15 @@ def RandnDtypeDeviceModule_basic(module, tu: TestUtils): class RandnGeneratorModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): a = torch.ops.aten.randn([4, 512, 1024], generator=None) std = torch.std(a.to(dtype=torch.float64)) @@ -544,14 +624,15 @@ def RandnGeneratorModule_basic(module, tu: TestUtils): class RandnGeneratorF64Module(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): a = torch.ops.aten.randn([4, 512, 1024], generator=None, dtype=torch.float64) std = torch.std(a) @@ -571,10 +652,12 @@ class RandnLikeModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, x): a = torch.ops.aten.randn_like(x) std = torch.std(a) @@ -585,17 +668,21 @@ class RandnLikeModule(torch.nn.Module): def RandnLikeModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 512, 1024).double()) + # ============================================================================== + class RandnLikeDtypeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float64, True), + ] + ) def forward(self, x): a = torch.ops.aten.randn_like(x, dtype=torch.float32) std = torch.std(a) @@ -605,17 +692,22 @@ class RandnLikeDtypeModule(torch.nn.Module): @register_test_case(module_factory=lambda: RandnLikeDtypeModule()) def RandnLikeDtypeModule_basic(module, tu: TestUtils): module.forward(tu.rand(256, 1024).double()) + + # ============================================================================== + class NormalFunctionalModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float64, True), + ] + ) def forward(self, x): a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0) mean = torch.mean(a) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py index 51b9fb993..5576e850a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py @@ -13,16 +13,17 @@ from torch_mlir_e2e_test.annotations import annotate_args, export class AddIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return int(lhs) + int(rhs) @@ -36,16 +37,17 @@ def AddIntModule_basic(module, tu: TestUtils): class SubIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return int(lhs) - int(rhs) @@ -59,16 +61,17 @@ def SubIntModule_basic(module, tu: TestUtils): class SubFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ([], torch.float64, True), + ] + ) def forward(self, lhs, rhs): return float(lhs) - float(rhs) @@ -80,17 +83,19 @@ def SubFloatModule_basic(module, tu: TestUtils): # ============================================================================== -class MulFloatModule(torch.nn.Module): +class MulFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ([], torch.float64, True), + ] + ) def forward(self, lhs, rhs): return float(lhs) * float(rhs) @@ -104,16 +109,17 @@ def MulFloatModule_basic(module, tu: TestUtils): class MulIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return int(lhs) * int(rhs) @@ -127,16 +133,17 @@ def MulIntModule_basic(module, tu: TestUtils): class DivIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, lhs, rhs): # Cast the result to float to make e2e test baseline result to be a float. # Without the cast, baseline result is a Tensor which is unexpected. @@ -152,16 +159,17 @@ def DivIntModule_basic(module, tu: TestUtils): class DivFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ([], torch.float64, True), + ] + ) def forward(self, lhs, rhs): return float(lhs) / float(rhs) @@ -175,16 +183,17 @@ def DivFloatModule_basic(module, tu: TestUtils): class CeilFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ([], torch.float64, True), + ] + ) def forward(self, lhs, rhs): sub = float(lhs) - float(rhs) # Cast the result to int to make e2e test baseline result to be an int. @@ -204,15 +213,16 @@ def CeilFloatModule_basic(module, tu: TestUtils): class SqrtIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ] + ) def forward(self, a): return float(torch.ops.aten.sqrt(int(a))) @@ -223,14 +233,15 @@ def SqrtIntModule_basic(module, tu: TestUtils): class SqrtIntConstantModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return float(torch.ops.aten.sqrt(5)) @@ -244,15 +255,16 @@ def SqrtIntConstantModule_basic(module, tu: TestUtils): class BoolFloatFalseModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ] + ) def forward(self, a): sub = float(a) - float(a) return bool(torch.ops.aten.Bool(float(sub))) @@ -264,15 +276,16 @@ def BoolFloatFalseModule_basic(module, tu: TestUtils): class BoolFloatTrueModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ] + ) def forward(self, a): return bool(torch.ops.aten.Bool(float(a))) @@ -283,14 +296,15 @@ def BoolFloatTrueModule_basic(module, tu: TestUtils): class BoolFloatConstantModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return bool(torch.ops.aten.Bool(5.0)) @@ -304,15 +318,16 @@ def BoolFloatConstantModule_basic(module, tu: TestUtils): class BoolIntFalseModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ] + ) def forward(self, a): sub = int(a) - int(a) return bool(torch.ops.aten.Bool(int(sub))) @@ -324,15 +339,16 @@ def BoolIntFalseModule_basic(module, tu: TestUtils): class BoolIntTrueModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ] + ) def forward(self, a): return bool(torch.ops.aten.Bool(int(a))) @@ -343,14 +359,15 @@ def BoolIntTrueModule_basic(module, tu: TestUtils): class BoolIntConstantModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return bool(torch.ops.aten.Bool(5)) @@ -359,17 +376,21 @@ class BoolIntConstantModule(torch.nn.Module): def BoolIntConstantModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class AtenIntBoolOpModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.bool, True), - ]) + @annotate_args( + [ + None, + ([], torch.bool, True), + ] + ) def forward(self, x): return int(torch.ops.aten.Int(x)) @@ -384,9 +405,11 @@ class AtenIntBoolOpConstTrueModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return int(torch.ops.aten.Int(True)) @@ -401,9 +424,11 @@ class AtenIntBoolOpConstFalseModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ]) + @annotate_args( + [ + None, + ] + ) def forward(self): return int(torch.ops.aten.Int(False)) @@ -412,21 +437,25 @@ class AtenIntBoolOpConstFalseModule(torch.nn.Module): def AtenIntBoolOpConstFalseModule_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class AtenIntTensorByteDtypeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.uint8, True), - ]) - + @annotate_args( + [ + None, + ([], torch.uint8, True), + ] + ) def forward(self, val): return int(val) + @register_test_case(module_factory=lambda: AtenIntTensorByteDtypeModule()) def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils): module.forward(tu.randint(low=-100, high=100).to(dtype=torch.uint8)) @@ -434,57 +463,68 @@ def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils): # ============================================================================== + class AtenIntTensorCharDtypeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int8, True), - ]) - + @annotate_args( + [ + None, + ([], torch.int8, True), + ] + ) def forward(self, val): return int(val) + @register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule()) def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils): module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8)) + # ============================================================================== + class AtenItemIntOpModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int8, True), - ]) - + @annotate_args( + [ + None, + ([], torch.int8, True), + ] + ) def forward(self, val): return int(val) + @register_test_case(module_factory=lambda: AtenItemIntOpModule()) def AtenItemIntOpModule_basic(module, tu: TestUtils): module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8)) + # ============================================================================== + class AtenItemFpOpModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float, True), - ]) - + @annotate_args( + [ + None, + ([], torch.float, True), + ] + ) def forward(self, val): return float(val) + @register_test_case(module_factory=lambda: AtenItemFpOpModule()) def AtenItemFpOpModule_basic(module, tu: TestUtils): module.forward(tu.rand(1)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py index 25f73e349..3614adb3e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py @@ -13,16 +13,17 @@ from torch_mlir_e2e_test.annotations import annotate_args, export class NeIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return int(lhs) != int(rhs) @@ -36,16 +37,17 @@ def NeIntModule_basic(module, tu: TestUtils): class EqIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return int(lhs) == int(rhs) @@ -59,16 +61,17 @@ def EqIntModule_basic(module, tu: TestUtils): class GtIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return int(lhs) > int(rhs) @@ -82,16 +85,17 @@ def GtIntModule_basic(module, tu: TestUtils): class GeIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.int64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return torch.ops.aten.ge(int(lhs), int(rhs)) @@ -105,16 +109,17 @@ def GeIntModule_basic(module, tu: TestUtils): class GeFloatModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ([], torch.float64, True), + ] + ) def forward(self, lhs, rhs): return float(lhs) >= float(rhs) @@ -128,16 +133,17 @@ def GeFloatModule_basic(module, tu: TestUtils): class GeFloatIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ([], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return float(lhs) >= int(rhs) @@ -151,16 +157,17 @@ def GeFloatIntModule_basic(module, tu: TestUtils): class NeFloatIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ([], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return float(lhs) != int(rhs) @@ -174,16 +181,17 @@ def NeFloatIntModule_basic(module, tu: TestUtils): class GtFloatIntModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([], torch.float64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([], torch.float64, True), + ([], torch.int64, True), + ] + ) def forward(self, lhs, rhs): return float(lhs) > int(rhs) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py index 985beb4e0..cc4970573 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -13,435 +13,428 @@ from torch_mlir_e2e_test.annotations import annotate_args, export class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._index_put_impl_(input, (index, ), - value, - accumulate=False, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input, (index,), value, accumulate=False, unsafe=False + ) -@register_test_case( - module_factory=lambda: IndexPutImpl1DFloatNonAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutImpl1DFloatNonAccumulateModule()) def IndexPutImpl1DFloatNonAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250)) class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._index_put_impl_(input, (index, ), - value, - accumulate=False, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input, (index,), value, accumulate=False, unsafe=False + ) -@register_test_case( - module_factory=lambda: IndexPutImpl2DFloatNonAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutImpl2DFloatNonAccumulateModule()) def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) -class IndexPutImpl2DImplicitModule(torch.nn.Module): +class IndexPutImpl2DImplicitModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([10, 8], torch.float32, True), - ([1], torch.int64, True), - ([8], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([10, 8], torch.float32, True), + ([1], torch.int64, True), + ([8], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._index_put_impl_(input, (index, ), - value, - accumulate=False, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input, (index,), value, accumulate=False, unsafe=False + ) -@register_test_case( - module_factory=lambda: IndexPutImpl2DImplicitModule()) +@register_test_case(module_factory=lambda: IndexPutImpl2DImplicitModule()) def IndexPutImpl2DImplicitModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 8), tu.randint(1, high=4), tu.rand(8)) -class IndexPutImpl2DNoneIndexStaticModule(torch.nn.Module): +class IndexPutImpl2DNoneIndexStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([1, 4], torch.int64, True), - ([3], torch.int64, True), - ([1, 3], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([1, 4], torch.int64, True), + ([3], torch.int64, True), + ([1, 3], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._index_put_impl_(input, (None, index), - value, - accumulate=False, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input, (None, index), value, accumulate=False, unsafe=False + ) -@register_test_case( - module_factory=lambda: IndexPutImpl2DNoneIndexStaticModule()) +@register_test_case(module_factory=lambda: IndexPutImpl2DNoneIndexStaticModule()) def IndexPutImpl2DNoneIndexStaticModule_basic(module, tu: TestUtils): - module.forward(tu.randint(1, 4, high=3), tu.randint(3, high=3), tu.randint(1, 3, high=1)) + module.forward( + tu.randint(1, 4, high=3), tu.randint(3, high=3), tu.randint(1, 3, high=1) + ) class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._index_put_impl_(input, (index, ), - value, - accumulate=False, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input, (index,), value, accumulate=False, unsafe=False + ) -@register_test_case( - module_factory=lambda: IndexPutImpl3DFloatNonAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutImpl3DFloatNonAccumulateModule()) def IndexPutImpl3DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), - tu.rand(5, 8, 6)) + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) # ============================================================================== class IndexPutImpl1DIntNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._index_put_impl_(input, (index, ), - value, - accumulate=False, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input, (index,), value, accumulate=False, unsafe=False + ) -@register_test_case( - module_factory=lambda: IndexPutImpl1DIntNonAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutImpl1DIntNonAccumulateModule()) def IndexPutImpl1DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(200, high=1000), tu.randint(300, high=100), - tu.randint(300, high=10000)) + module.forward( + tu.randint(200, high=1000), + tu.randint(300, high=100), + tu.randint(300, high=10000), + ) # ============================================================================== class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._index_put_impl_(input, (index, ), - value, - accumulate=True, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input, (index,), value, accumulate=True, unsafe=False + ) -@register_test_case( - module_factory=lambda: IndexPutImpl1DFloatAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutImpl1DFloatAccumulateModule()) def IndexPutImpl1DFloatAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(1000), tu.randint(500, high=10), tu.rand(500)) class IndexPutImpl2DFloatAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._index_put_impl_(input.clone(), (index, ), - value, - accumulate=True, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input.clone(), (index,), value, accumulate=True, unsafe=False + ) -@register_test_case( - module_factory=lambda: IndexPutImpl2DFloatAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutImpl2DFloatAccumulateModule()) def IndexPutImpl2DFloatAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._index_put_impl_(input.clone(), (index, ), - value, - accumulate=True, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input.clone(), (index,), value, accumulate=True, unsafe=False + ) -@register_test_case( - module_factory=lambda: IndexPutImpl3DFloatAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutImpl3DFloatAccumulateModule()) def IndexPutImpl3DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), - tu.rand(5, 8, 6)) + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) # ============================================================================== class IndexPutImpl1DIntAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._index_put_impl_(input, (index, ), - value, - accumulate=True, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input, (index,), value, accumulate=True, unsafe=False + ) @register_test_case(module_factory=lambda: IndexPutImpl1DIntAccumulateModule()) def IndexPutImpl1DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(10, high=100), tu.randint(10, high=10), - tu.randint(10, high=1000)) + module.forward( + tu.randint(10, high=100), tu.randint(10, high=10), tu.randint(10, high=1000) + ) # ============================================================================== class IndexPut1DFloatNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=False) + return torch.ops.aten.index_put(input, (index,), value, accumulate=False) -@register_test_case( - module_factory=lambda: IndexPut1DFloatNonAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPut1DFloatNonAccumulateModule()) def IndexPut1DFloatNonAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250)) class IndexPut2DFloatNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=False) + return torch.ops.aten.index_put(input, (index,), value, accumulate=False) -@register_test_case( - module_factory=lambda: IndexPut2DFloatNonAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPut2DFloatNonAccumulateModule()) def IndexPut2DFloatNonAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPut3DFloatNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=False) + return torch.ops.aten.index_put(input, (index,), value, accumulate=False) -@register_test_case( - module_factory=lambda: IndexPut3DFloatNonAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPut3DFloatNonAccumulateModule()) def IndexPut3DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), - tu.rand(5, 8, 6)) + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) # ============================================================================== class IndexPut1DIntNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=False) + return torch.ops.aten.index_put(input, (index,), value, accumulate=False) @register_test_case(module_factory=lambda: IndexPut1DIntNonAccumulateModule()) def IndexPut1DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(200, high=1000), tu.randint(300, high=100), - tu.randint(300, high=10000)) + module.forward( + tu.randint(200, high=1000), + tu.randint(300, high=100), + tu.randint(300, high=10000), + ) class IndexPut2DIntNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=False) + return torch.ops.aten.index_put(input, (index,), value, accumulate=False) @register_test_case(module_factory=lambda: IndexPut2DIntNonAccumulateModule()) def IndexPut2DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), - tu.randint(5, 8, high=1000)) + module.forward( + tu.randint(10, 8, high=1000), tu.randint(5, high=4), tu.randint(5, 8, high=1000) + ) class IndexPut3DIntNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ([-1], torch.int64, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ([-1], torch.int64, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=False) + return torch.ops.aten.index_put(input, (index,), value, accumulate=False) @register_test_case(module_factory=lambda: IndexPut3DIntNonAccumulateModule()) def IndexPut3DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), - tu.randint(5, 8, 6, high=1000)) + module.forward( + tu.randint(10, 8, 6, high=1000), + tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000), + ) # ============================================================================== class IndexPut1DFloatAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=True) + return torch.ops.aten.index_put(input, (index,), value, accumulate=True) @register_test_case(module_factory=lambda: IndexPut1DFloatAccumulateModule()) @@ -450,21 +443,20 @@ def IndexPut1DFloatAccumulateModule_basic(module, tu: TestUtils): class IndexPut2DFloatAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=True) + return torch.ops.aten.index_put(input, (index,), value, accumulate=True) @register_test_case(module_factory=lambda: IndexPut2DFloatAccumulateModule()) @@ -473,102 +465,102 @@ def IndexPut2DFloatAccumulateModule_basic(module, tu: TestUtils): class IndexPut3DFloatAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=True) + return torch.ops.aten.index_put(input, (index,), value, accumulate=True) @register_test_case(module_factory=lambda: IndexPut3DFloatAccumulateModule()) def IndexPut3DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), - tu.rand(5, 8, 6)) + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) # ============================================================================== class IndexPut1DIntAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=True) + return torch.ops.aten.index_put(input, (index,), value, accumulate=True) @register_test_case(module_factory=lambda: IndexPut1DIntAccumulateModule()) def IndexPut1DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(10, high=100), tu.randint(10, high=10), - tu.randint(10, high=1000)) + module.forward( + tu.randint(10, high=100), tu.randint(10, high=10), tu.randint(10, high=1000) + ) class IndexPut2DIntAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=True) + return torch.ops.aten.index_put(input, (index,), value, accumulate=True) @register_test_case(module_factory=lambda: IndexPut2DIntAccumulateModule()) def IndexPut2DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), - tu.randint(5, 8, high=1000)) + module.forward( + tu.randint(10, 8, high=1000), tu.randint(5, high=4), tu.randint(5, 8, high=1000) + ) class IndexPut3DIntAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ([-1], torch.int64, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ([-1], torch.int64, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, (index, ), - value, - accumulate=True) + return torch.ops.aten.index_put(input, (index,), value, accumulate=True) @register_test_case(module_factory=lambda: IndexPut3DIntAccumulateModule()) def IndexPut3DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), - tu.randint(5, 8, 6, high=1000)) + module.forward( + tu.randint(10, 8, 6, high=1000), + tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000), + ) # ============================================================================== @@ -576,296 +568,300 @@ def IndexPut3DIntAccumulateModule_basic(module, tu: TestUtils): class IndexPutHackedTwin1DFloatNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, [index], - value, - accumulate=False) + return torch.ops.aten.index_put(input, [index], value, accumulate=False) @register_test_case( - module_factory=lambda: IndexPutHackedTwin1DFloatNonAccumulateModule()) + module_factory=lambda: IndexPutHackedTwin1DFloatNonAccumulateModule() +) def IndexPutHackedTwin1DFloatNonAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250)) class IndexPutHackedTwin2DFloatNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, [index], - value, - accumulate=False) + return torch.ops.aten.index_put(input, [index], value, accumulate=False) @register_test_case( - module_factory=lambda: IndexPutHackedTwin2DFloatNonAccumulateModule()) + module_factory=lambda: IndexPutHackedTwin2DFloatNonAccumulateModule() +) def IndexPutHackedTwin2DFloatNonAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPutHackedTwin3DFloatNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, [index], - value, - accumulate=False) + return torch.ops.aten.index_put(input, [index], value, accumulate=False) @register_test_case( - module_factory=lambda: IndexPutHackedTwin3DFloatNonAccumulateModule()) + module_factory=lambda: IndexPutHackedTwin3DFloatNonAccumulateModule() +) def IndexPutHackedTwin3DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), - tu.rand(5, 8, 6)) + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) # ============================================================================== class IndexPutHackedTwin1DIntNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, [index], - value, - accumulate=False) + return torch.ops.aten.index_put(input, [index], value, accumulate=False) -@register_test_case( - module_factory=lambda: IndexPutHackedTwin1DIntNonAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutHackedTwin1DIntNonAccumulateModule()) def IndexPutHackedTwin1DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(200, high=1000), tu.randint(300, high=100), - tu.randint(300, high=10000)) + module.forward( + tu.randint(200, high=1000), + tu.randint(300, high=100), + tu.randint(300, high=10000), + ) class IndexPutHackedTwin2DIntNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, [index], - value, - accumulate=False) + return torch.ops.aten.index_put(input, [index], value, accumulate=False) -@register_test_case( - module_factory=lambda: IndexPutHackedTwin2DIntNonAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutHackedTwin2DIntNonAccumulateModule()) def IndexPutHackedTwin2DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), - tu.randint(5, 8, high=1000)) + module.forward( + tu.randint(10, 8, high=1000), tu.randint(5, high=4), tu.randint(5, 8, high=1000) + ) class IndexPutHackedTwin3DIntNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ([-1], torch.int64, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ([-1], torch.int64, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten.index_put(input, [index], - value, - accumulate=False) + return torch.ops.aten.index_put(input, [index], value, accumulate=False) -@register_test_case( - module_factory=lambda: IndexPutHackedTwin3DIntNonAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutHackedTwin3DIntNonAccumulateModule()) def IndexPutHackedTwin3DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), - tu.randint(5, 8, 6, high=1000)) + module.forward( + tu.randint(10, 8, 6, high=1000), + tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000), + ) # ============================================================================== class IndexPutHackedTwin1DFloatAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ] + ) def forward(self, input, index, value): return torch.ops.aten.index_put(input, [index], value, accumulate=True) -@register_test_case( - module_factory=lambda: IndexPutHackedTwin1DFloatAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutHackedTwin1DFloatAccumulateModule()) def IndexPutHackedTwin1DFloatAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(1000), tu.randint(500, high=10), tu.rand(500)) class IndexPutHackedTwin2DFloatAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): return torch.ops.aten.index_put(input, [index], value, accumulate=True) -@register_test_case( - module_factory=lambda: IndexPutHackedTwin2DFloatAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutHackedTwin2DFloatAccumulateModule()) def IndexPutHackedTwin2DFloatAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPutHackedTwin3DFloatAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1], torch.int64, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): return torch.ops.aten.index_put(input, [index], value, accumulate=True) -@register_test_case( - module_factory=lambda: IndexPutHackedTwin3DFloatAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutHackedTwin3DFloatAccumulateModule()) def IndexPutHackedTwin3DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), - tu.rand(5, 8, 6)) + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) # ============================================================================== class IndexPutHackedTwin1DIntAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, input, index, value): return torch.ops.aten.index_put(input, [index], value, accumulate=True) -@register_test_case( - module_factory=lambda: IndexPutHackedTwin1DIntAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutHackedTwin1DIntAccumulateModule()) def IndexPutHackedTwin1DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(10, high=100), tu.randint(10, high=10), - tu.randint(10, high=1000)) + module.forward( + tu.randint(10, high=100), tu.randint(10, high=10), tu.randint(10, high=1000) + ) class IndexPutHackedTwin2DIntAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, input, index, value): return torch.ops.aten.index_put(input, [index], value, accumulate=True) -@register_test_case( - module_factory=lambda: IndexPutHackedTwin2DIntAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutHackedTwin2DIntAccumulateModule()) def IndexPutHackedTwin2DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), - tu.randint(5, 8, high=1000)) + module.forward( + tu.randint(10, 8, high=1000), tu.randint(5, high=4), tu.randint(5, 8, high=1000) + ) class IndexPutHackedTwin3DIntAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ([-1], torch.int64, True), - ([-1, -1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ([-1], torch.int64, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, input, index, value): return torch.ops.aten.index_put(input, [index], value, accumulate=True) -@register_test_case( - module_factory=lambda: IndexPutHackedTwin3DIntAccumulateModule()) +@register_test_case(module_factory=lambda: IndexPutHackedTwin3DIntAccumulateModule()) def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), - tu.randint(5, 8, 6, high=1000)) + module.forward( + tu.randint(10, 8, 6, high=1000), + tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000), + ) # ============================================================================== @@ -873,127 +869,132 @@ def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils): class UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._unsafe_index_put(input, [index], - value, - accumulate=False) + return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False) @register_test_case( - module_factory=lambda: UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule()) + module_factory=lambda: UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule() +) def UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250)) # ============================================================================== -class ScatterSrcStaticModule(torch.nn.Module): +class ScatterSrcStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([10, 8, 6], torch.float32, True), - ([2, 4, 3], torch.int64, True), - ([5, 8, 6], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([10, 8, 6], torch.float32, True), + ([2, 4, 3], torch.int64, True), + ([5, 8, 6], torch.float32, True), + ] + ) def forward(self, input, index, src): return torch.ops.aten.scatter(input, 0, index, src) -@register_test_case( - module_factory=lambda: ScatterSrcStaticModule()) +@register_test_case(module_factory=lambda: ScatterSrcStaticModule()) def ScatterSrcStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + # ============================================================================== -class ScatterSrcModule(torch.nn.Module): +class ScatterSrcModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.int64, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, index, src): return torch.ops.aten.scatter(input, 1, index, src) -@register_test_case( - module_factory=lambda: ScatterSrcModule()) +@register_test_case(module_factory=lambda: ScatterSrcModule()) def ScatterSrcModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(3, 4, 3)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(3, 4, 3)) + # ============================================================================== -class ScatterValueFloatModule(torch.nn.Module): +class ScatterValueFloatModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.int64, True), - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), + ([], torch.float64, True), + ] + ) def forward(self, input, index, value): return torch.ops.aten.scatter(input, 2, index, float(value)) -@register_test_case( - module_factory=lambda: ScatterValueFloatModule()) +@register_test_case(module_factory=lambda: ScatterValueFloatModule()) def ScatterValueFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand().double()) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand().double()) + # ============================================================================== -class ScatterValueIntModule(torch.nn.Module): +class ScatterValueIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.int64, True), - ([], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), + ([], torch.int64, True), + ] + ) def forward(self, input, index, value): return torch.ops.aten.scatter(input, 0, index, int(value)) -@register_test_case( - module_factory=lambda: ScatterValueIntModule()) +@register_test_case(module_factory=lambda: ScatterValueIntModule()) def ScatterValueIntModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.randint(high=10)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.randint(high=10)) + # ============================================================================== + class ScatterReduceFloatModule(torch.nn.Module): include_self: bool reduce_type: str @@ -1004,69 +1005,73 @@ class ScatterReduceFloatModule(torch.nn.Module): self.reduce_type = reduce_type @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.int64, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, index, src): - return torch.ops.aten.scatter_reduce(input, 0, index, src, self.reduce_type, include_self=self.include_self) + return torch.ops.aten.scatter_reduce( + input, 0, index, src, self.reduce_type, include_self=self.include_self + ) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("sum", False)) +@register_test_case(module_factory=lambda: ScatterReduceFloatModule("sum", False)) def ScatterReduceFloatSumModule(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("sum", True)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +@register_test_case(module_factory=lambda: ScatterReduceFloatModule("sum", True)) def ScatterReduceFloatSumModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("prod", False)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +@register_test_case(module_factory=lambda: ScatterReduceFloatModule("prod", False)) def ScatterReduceFloatProdModule(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("prod", True)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +@register_test_case(module_factory=lambda: ScatterReduceFloatModule("prod", True)) def ScatterReduceFloatProdModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("amax", False)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +@register_test_case(module_factory=lambda: ScatterReduceFloatModule("amax", False)) def ScatterReduceFloatMaxModule(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("amax", True)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +@register_test_case(module_factory=lambda: ScatterReduceFloatModule("amax", True)) def ScatterReduceFloatMaxModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("amin", False)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +@register_test_case(module_factory=lambda: ScatterReduceFloatModule("amin", False)) def ScatterReduceFloatMinModule(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("amin", True)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +@register_test_case(module_factory=lambda: ScatterReduceFloatModule("amin", True)) def ScatterReduceFloatMinModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("mean", False)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +@register_test_case(module_factory=lambda: ScatterReduceFloatModule("mean", False)) def ScatterReduceFloatMeanModule(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("mean", True)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +@register_test_case(module_factory=lambda: ScatterReduceFloatModule("mean", True)) def ScatterReduceFloatMeanModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + # ============================================================================== + class ScatterReduceIntModule(torch.nn.Module): include_self: bool reduce_type: str @@ -1077,117 +1082,165 @@ class ScatterReduceIntModule(torch.nn.Module): self.reduce_type = reduce_type @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int32, True), - ([-1, -1, -1], torch.int64, True), - ([-1, -1, -1], torch.int32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ([-1, -1, -1], torch.int64, True), + ([-1, -1, -1], torch.int32, True), + ] + ) def forward(self, input, index, src): - return torch.ops.aten.scatter_reduce(input, 0, index, src, self.reduce_type, include_self=self.include_self) + return torch.ops.aten.scatter_reduce( + input, 0, index, src, self.reduce_type, include_self=self.include_self + ) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("sum", False)) +@register_test_case(module_factory=lambda: ScatterReduceIntModule("sum", False)) def ScatterReduceIntSumModule(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("sum", True)) + module.forward( + tu.randint(10, 8, 6, dtype=torch.int32, high=10), + tu.randint(2, 4, 3, high=4), + tu.randint(5, 8, 6, dtype=torch.int32, high=10), + ) + + +@register_test_case(module_factory=lambda: ScatterReduceIntModule("sum", True)) def ScatterReduceIntSumModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("prod", False)) + module.forward( + tu.randint(10, 8, 6, dtype=torch.int32, high=10), + tu.randint(2, 4, 3, high=4), + tu.randint(5, 8, 6, dtype=torch.int32, high=10), + ) + + +@register_test_case(module_factory=lambda: ScatterReduceIntModule("prod", False)) def ScatterReduceIntProdModule(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("prod", True)) + module.forward( + tu.randint(10, 8, 6, dtype=torch.int32, high=10), + tu.randint(2, 4, 3, high=4), + tu.randint(5, 8, 6, dtype=torch.int32, high=10), + ) + + +@register_test_case(module_factory=lambda: ScatterReduceIntModule("prod", True)) def ScatterReduceIntProdModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("amax", False)) + module.forward( + tu.randint(10, 8, 6, dtype=torch.int32, high=10), + tu.randint(2, 4, 3, high=4), + tu.randint(5, 8, 6, dtype=torch.int32, high=10), + ) + + +@register_test_case(module_factory=lambda: ScatterReduceIntModule("amax", False)) def ScatterReduceIntMaxModule(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("amax", True)) + module.forward( + tu.randint(10, 8, 6, dtype=torch.int32, high=10), + tu.randint(2, 4, 3, high=4), + tu.randint(5, 8, 6, dtype=torch.int32, high=10), + ) + + +@register_test_case(module_factory=lambda: ScatterReduceIntModule("amax", True)) def ScatterReduceIntMaxModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("amin", False)) + module.forward( + tu.randint(10, 8, 6, dtype=torch.int32, high=10), + tu.randint(2, 4, 3, high=4), + tu.randint(5, 8, 6, dtype=torch.int32, high=10), + ) + + +@register_test_case(module_factory=lambda: ScatterReduceIntModule("amin", False)) def ScatterReduceIntMinModule(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("amin", True)) + module.forward( + tu.randint(10, 8, 6, dtype=torch.int32, high=10), + tu.randint(2, 4, 3, high=4), + tu.randint(5, 8, 6, dtype=torch.int32, high=10), + ) + + +@register_test_case(module_factory=lambda: ScatterReduceIntModule("amin", True)) def ScatterReduceIntMinModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("mean", False)) + module.forward( + tu.randint(10, 8, 6, dtype=torch.int32, high=10), + tu.randint(2, 4, 3, high=4), + tu.randint(5, 8, 6, dtype=torch.int32, high=10), + ) + + +@register_test_case(module_factory=lambda: ScatterReduceIntModule("mean", False)) def ScatterReduceIntMeanModule(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("mean", True)) + module.forward( + tu.randint(10, 8, 6, dtype=torch.int32, high=10), + tu.randint(2, 4, 3, high=4), + tu.randint(5, 8, 6, dtype=torch.int32, high=10), + ) + + +@register_test_case(module_factory=lambda: ScatterReduceIntModule("mean", True)) def ScatterReduceIntMeanModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) + module.forward( + tu.randint(10, 8, 6, dtype=torch.int32, high=10), + tu.randint(2, 4, 3, high=4), + tu.randint(5, 8, 6, dtype=torch.int32, high=10), + ) + # ============================================================================== -class IndexPutImpl2DIndexModule(torch.nn.Module): +class IndexPutImpl2DIndexModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.int64, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input, index, value): - return torch.ops.aten._index_put_impl_(input, (index, ), - value, - accumulate=True, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input, (index,), value, accumulate=True, unsafe=False + ) -@register_test_case( - module_factory=lambda: IndexPutImpl2DIndexModule()) +@register_test_case(module_factory=lambda: IndexPutImpl2DIndexModule()) def IndexPutImpl2DIndexModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 7), tu.randint(2, 3, high=3), tu.rand(2, 3, 7)) + # ============================================================================== -class IndexPutImplIndexWithNoneModule(torch.nn.Module): +class IndexPutImplIndexWithNoneModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 4, 5], torch.float32, True), - ([6, 1], torch.int64, True), - ([7], torch.int64, True), - ([2, 3, 6, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3, 4, 5], torch.float32, True), + ([6, 1], torch.int64, True), + ([7], torch.int64, True), + ([2, 3, 6, 7], torch.float32, True), + ] + ) def forward(self, input, index1, index2, value): - return torch.ops.aten._index_put_impl_(input, (None, None, index1, index2), - value, - accumulate=True, - unsafe=False) + return torch.ops.aten._index_put_impl_( + input, (None, None, index1, index2), value, accumulate=True, unsafe=False + ) -@register_test_case( - module_factory=lambda: IndexPutImplIndexWithNoneModule()) +@register_test_case(module_factory=lambda: IndexPutImplIndexWithNoneModule()) def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4, 5), tu.randint(6, 1, high=4), tu.randint(7, high=5), tu.rand(2, 3, 6, 7)) - + module.forward( + tu.rand(2, 3, 4, 5), + tu.randint(6, 1, high=4), + tu.randint(7, high=5), + tu.rand(2, 3, 6, 7), + ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py index 4ba5541d4..07f064de7 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -11,125 +11,147 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== + class SliceModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return x[0:5:1, 1:3:1, 2:4:1] @register_test_case(module_factory=lambda: SliceModule()) def SliceModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4,7)) + module.forward(tu.rand(6, 4, 7)) # ============================================================================== + class SliceStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([6, 4, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([6, 4, 7], torch.float32, True), + ] + ) def forward(self, x): return x[0:5:1, 1:3:1, 2:4:1] @register_test_case(module_factory=lambda: SliceStaticModule()) def SliceStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4,7)) + module.forward(tu.rand(6, 4, 7)) # ============================================================================== + class SliceOutOfUpperBoundIndexModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): # TODO: remove hacky cat tensor once refbackend supports 0 size dim - result = x[:8, :5, 8:] - cat_tensor = torch.ones((6,4,1), dtype=torch.float32) - return torch.cat((result,cat_tensor), dim=2) + result = x[:8, :5, 8:] + cat_tensor = torch.ones((6, 4, 1), dtype=torch.float32) + return torch.cat((result, cat_tensor), dim=2) @register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule()) def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4,7)) + module.forward(tu.rand(6, 4, 7)) + # ============================================================================== + class SliceOutOfUpperBoundIndexStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([6, 4, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([6, 4, 7], torch.float32, True), + ] + ) def forward(self, x): # TODO: remove hacky cat tensor once refbackend supports 0 size dim - result = x[:8, :5, 8:] - cat_tensor = torch.ones((6,4,1), dtype=torch.float32) - return torch.cat((result,cat_tensor), dim=2) + result = x[:8, :5, 8:] + cat_tensor = torch.ones((6, 4, 1), dtype=torch.float32) + return torch.cat((result, cat_tensor), dim=2) @register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexStaticModule()) def SliceOutOfUpperBoundIndexStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4,7)) + module.forward(tu.rand(6, 4, 7)) + # ============================================================================== + class SliceOutOfLowerBoundEndIndexModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return x[:-8,-7:,:] + return x[:-8, -7:, :] @register_test_case(module_factory=lambda: SliceOutOfLowerBoundEndIndexModule()) def SliceOutOfLowerBoundEndIndexModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4,7)) + module.forward(tu.rand(6, 4, 7)) + # ============================================================================== + class SliceOutOfLowerBoundStartIndexModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return x[-8:3:1, 1:3:1, 2:4:1] @register_test_case(module_factory=lambda: SliceOutOfLowerBoundStartIndexModule()) def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4,7)) + module.forward(tu.rand(6, 4, 7)) + # ============================================================================== @@ -139,20 +161,23 @@ class SliceEndSleStartModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): # TODO: remove hacky cat tensor once refbackend supports 0 size dim result = x[:, 4:3, :] - cat_tensor = torch.ones((6,1,7), dtype=torch.float32) + cat_tensor = torch.ones((6, 1, 7), dtype=torch.float32) return torch.cat((result, cat_tensor), dim=1) @register_test_case(module_factory=lambda: SliceEndSleStartModule()) def SliceEndSleStartModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4,7)) + module.forward(tu.rand(6, 4, 7)) + # ============================================================================== @@ -162,108 +187,130 @@ class SliceStartEqEndModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): # TODO: remove hacky cat tensor once refbackend supports 0 size dim result = x[5:5, :, :] - cat_tensor = torch.ones((1,4,7), dtype=torch.float32) + cat_tensor = torch.ones((1, 4, 7), dtype=torch.float32) return torch.cat((result, cat_tensor), dim=0) @register_test_case(module_factory=lambda: SliceStartEqEndModule()) def SliceStartEqEndModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4,7)) + module.forward(tu.rand(6, 4, 7)) + # ============================================================================== + class SliceSizeTwoStepModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return x[0:5:2, 0:3:2, 0:4:2] @register_test_case(module_factory=lambda: SliceSizeTwoStepModule()) def SliceSizeTwoStepModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10,5,17)) + module.forward(tu.rand(10, 5, 17)) + # ============================================================================== + class SliceNegIdxModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return x[:-1, -2:-1] @register_test_case(module_factory=lambda: SliceNegIdxModule()) def SliceNegIdxModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3,9)) + module.forward(tu.rand(3, 9)) + # ============================================================================== + class SliceSingleIdxModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return x[0] @register_test_case(module_factory=lambda: SliceSingleIdxModule()) def SliceSingleIdxModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,8)) + module.forward(tu.rand(6, 8)) + # ============================================================================== + class SliceWholeTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return x[:, :] @register_test_case(module_factory=lambda: SliceWholeTensorModule()) def SliceWholeTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,8)) + module.forward(tu.rand(6, 8)) + # ============================================================================== + class SelectIntModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x): return torch.select(x, dim=0, index=0) @@ -278,10 +325,12 @@ class SelectIntNegativeDimAndIndexStaticModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([5, 5], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([5, 5], torch.int64, True), + ] + ) def forward(self, x): return torch.select(x, dim=-1, index=-1) @@ -290,6 +339,7 @@ class SelectIntNegativeDimAndIndexStaticModule(torch.nn.Module): def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils): module.forward(tu.randint(5, 5, high=10)) + # ============================================================================== # For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1). @@ -299,168 +349,191 @@ class SliceScatterModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, src): - return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 1) + return torch.ops.aten.slice_scatter(x, src, dim=1, start=0, end=1, step=1) + @register_test_case(module_factory=lambda: SliceScatterModule()) def SliceScatterModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 8), tu.rand(6, 1)) + class SliceScatterZeroDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, src): - return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 0, end = 1, step = 1) + return torch.ops.aten.slice_scatter(x, src, dim=0, start=0, end=1, step=1) @register_test_case(module_factory=lambda: SliceScatterZeroDimModule()) def SliceScatterZeroDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 8), tu.rand(1, 8)) + class SliceScatterNegativeEndModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, src): - return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 3, end = -1, step = 1) + return torch.ops.aten.slice_scatter(x, src, dim=0, start=3, end=-1, step=1) @register_test_case(module_factory=lambda: SliceScatterNegativeEndModule()) def SliceScatterNegativeEndModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 8), tu.rand(2, 8)) -class SliceScatterNegativeDimModule(torch.nn.Module): +class SliceScatterNegativeDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, src): - return torch.ops.aten.slice_scatter(x, - src, - dim=-2, - start=0, - end=1, - step=1) + return torch.ops.aten.slice_scatter(x, src, dim=-2, start=0, end=1, step=1) @register_test_case(module_factory=lambda: SliceScatterNegativeDimModule()) def SliceScatterNegativeDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 8), tu.rand(1, 8)) + class SliceScatterStepVariationModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, src): - return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 2) + return torch.ops.aten.slice_scatter(x, src, dim=1, start=0, end=1, step=2) @register_test_case(module_factory=lambda: SliceScatterStepVariationModule()) def SliceScatterStepVariationModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 8), tu.rand(6, 1)) + class SliceScatterStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([6, 8], torch.float32, True), - ([6, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([6, 8], torch.float32, True), + ([6, 1], torch.float32, True), + ] + ) def forward(self, x, src): - return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 1) + return torch.ops.aten.slice_scatter(x, src, dim=1, start=0, end=1, step=1) @register_test_case(module_factory=lambda: SliceScatterStaticModule()) def SliceScatterStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 8), tu.rand(6, 1)) + class SelectScatterModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, src): - return torch.ops.aten.select_scatter(x, src, dim = 0, index = 0) + return torch.ops.aten.select_scatter(x, src, dim=0, index=0) @register_test_case(module_factory=lambda: SelectScatterModule()) def SelectScattertModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 8, 5), tu.rand(8, 5)) + class SelectScatterStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([6, 8, 5], torch.float32, True), - ([6, 5], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([6, 8, 5], torch.float32, True), + ([6, 5], torch.float32, True), + ] + ) def forward(self, x, src): - return torch.ops.aten.select_scatter(x, src, dim = 1, index = 0) + return torch.ops.aten.select_scatter(x, src, dim=1, index=0) @register_test_case(module_factory=lambda: SelectScatterStaticModule()) def SelectScattertStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 8, 5), tu.rand(6, 5)) + # ============================================================================== + class NarrowHorizontalTest(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.narrow(x, dim=0, start=0, length=2) @register_test_case(module_factory=lambda: NarrowHorizontalTest()) def NarrowHorizontalTest_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4,3)) + module.forward(tu.rand(6, 4, 3)) + # ============================================================================== @@ -470,36 +543,43 @@ class NarrowVerticalTest(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.narrow(x, dim=1, start=0, length=2) @register_test_case(module_factory=lambda: NarrowVerticalTest()) def NarrowVerticalTest_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4,3)) + module.forward(tu.rand(6, 4, 3)) + # ============================================================================== + class NarrowHorizontalTest2(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.narrow(x, dim=0, start=0, length=2) @register_test_case(module_factory=lambda: NarrowHorizontalTest2()) def NarrowHorizontalTest2_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4)) + module.forward(tu.rand(6, 4)) + # ============================================================================== @@ -509,66 +589,72 @@ class NarrowVerticalTest2(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.narrow(x, dim=1, start=0, length=2) @register_test_case(module_factory=lambda: NarrowVerticalTest2()) def NarrowVerticalTest2_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4)) + module.forward(tu.rand(6, 4)) + # ============================================================================== + class NarrowTensorHorizontalModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True) - ]) + @annotate_args([None, ([-1, -1], torch.float32, True)]) def forward(self, x): return torch.narrow(x, dim=1, start=torch.tensor(0), length=2) + @register_test_case(module_factory=lambda: NarrowTensorHorizontalModule()) def NarrowTensorHorizontalModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4)) + module.forward(tu.rand(6, 4)) + # ============================================================================== + class NarrowTensorVerticalModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True) - ]) + @annotate_args([None, ([-1, -1], torch.float32, True)]) def forward(self, x): return torch.narrow(x, dim=1, start=torch.tensor(1), length=2) + @register_test_case(module_factory=lambda: NarrowTensorVerticalModule()) def NarrowTensorVerticalModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6,4)) + module.forward(tu.rand(6, 4)) + # ============================================================================== + class SliceCopy_Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([10, 4, 4], torch.float32, True), - ([4, 4, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([10, 4, 4], torch.float32, True), + ([4, 4, 4], torch.float32, True), + ] + ) def forward(self, x, y): xslice = torch.ops.aten.slice(x, 0, 2, 6, 1) xslice.copy_(y) @@ -579,18 +665,22 @@ class SliceCopy_Module(torch.nn.Module): def SliceCopy_Module_basic(module, tu: TestUtils): module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4)) + # ============================================================================== + class SliceCopyNegative_Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y): xslice = torch.ops.aten.slice(x, 0, 2, -4, 1) xslice.copy_(y) @@ -610,11 +700,13 @@ class SliceCopyStartGreaterThanDimSize_Module(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y): xslice = torch.ops.aten.slice(x, 0, 100, 10, 1) xslice.copy_(y) @@ -634,11 +726,13 @@ class SliceCopyEndGreaterThanDimSize_Module(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y): xslice = torch.ops.aten.slice(x, 0, 2, 100, 1) xslice.copy_(y) @@ -658,11 +752,13 @@ class SliceCopyNonZeroDim_Module(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x, y): xslice = torch.ops.aten.slice(x, 1, 1, 3, 1) xslice.copy_(y) @@ -677,14 +773,15 @@ def SliceCopyNonZeroDim_Module_basic(module, tu: TestUtils): # ============================================================================== class PrimListUnpackNumMismatchModule(torch.nn.Module): def __init__(self): - super().__init__() - + super().__init__() @export - @annotate_args([ - None, - ([5, 4, 3, 2, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 4, 3, 2, 1], torch.float32, True), + ] + ) def forward(self, x): if len(x.shape) == 5: b0, t, c0, h0, w0 = x.shape @@ -709,33 +806,41 @@ class UnbindIntListUnpack_Module(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3, 4], torch.float32, True), + ] + ) def forward(self, x): unbind_0, unbind_1 = torch.unbind(x, 0) return torch.ops.aten.sub(unbind_0, unbind_1) + @register_test_case(module_factory=lambda: UnbindIntListUnpack_Module()) def UnbindIntListUnpack_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + # ============================================================================== + class UnbindIntGetItem_Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 3, 4], torch.float32, True), + ] + ) def forward(self, x): unbind = torch.unbind(x, 0) return torch.ops.aten.sub(unbind[0], unbind[1]) + @register_test_case(module_factory=lambda: UnbindIntGetItem_Module()) def UnbindIntGetItem_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) @@ -743,55 +848,61 @@ def UnbindIntGetItem_Module_basic(module, tu: TestUtils): # ============================================================================== + class SplitTensorGetItem_Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 3, 4], torch.float32, True), + ] + ) def forward(self, x): splits = torch.ops.aten.split(x, 2, 0) return torch.ops.aten.sub(splits[0], splits[1]) + @register_test_case(module_factory=lambda: SplitTensorGetItem_Module()) def SplitTensorGetItem_Module_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 4)) + # ============================================================================== + class SplitTensorListUnpackModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, 3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, 3, 4], torch.float32, True), + ] + ) def forward(self, x): x1, x2, x3 = torch.ops.aten.split(x, 2, 0) return x1 + x2 + x3 + @register_test_case(module_factory=lambda: SplitTensorListUnpackModule()) def SplitTensorListUnpackModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 3, 4)) + # ============================================================================== class SplitTensorLastSmallerModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([8, 10, 12], torch.float32, True) - ]) + @annotate_args([None, ([8, 10, 12], torch.float32, True)]) def forward(self, x): s0, s1, s2 = torch.ops.aten.split(x, 3, dim=0) return s2 @@ -803,19 +914,16 @@ def SplitTensorLastSmallerModule_basic(module, tu: TestUtils): # will leave the last result to have 2 elements in that dimension. module.forward(tu.rand(8, 10, 12)) + # ============================================================================== class SplitTensorNegativeDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([10, 12, 6], torch.float32, True) - ]) + @annotate_args([None, ([10, 12, 6], torch.float32, True)]) def forward(self, x): s0, s1, s2 = torch.ops.aten.split(x, 2, -1) return s1 @@ -825,124 +933,145 @@ class SplitTensorNegativeDimModule(torch.nn.Module): def SplitTensorNegativeDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 12, 6)) + # ============================================================================== -class SplitWithSizesListUnpackModule(torch.nn.Module): +class SplitWithSizesListUnpackModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([10, 12], torch.float32, True) - ]) + @annotate_args([None, ([10, 12], torch.float32, True)]) def forward(self, x): s0, s1, s2 = torch.ops.aten.split_with_sizes(x, [3, 4, 5], -1) return (s0, s1, s2) + @register_test_case(module_factory=lambda: SplitWithSizesListUnpackModule()) def SplitWithSizesListUnpackModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 12)) + # ============================================================================== + class ChunkListUnpack_Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 12, 2], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 12, 2], torch.float32, True), + ] + ) def forward(self, x): chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1) add = torch.ops.aten.add(chunk_0, chunk_1) sum = torch.ops.aten.add(add, chunk_2) return sum + @register_test_case(module_factory=lambda: ChunkListUnpack_Module()) def ChunkListUnpack_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 12, 2)) + # ============================================================================== + class ChunkListUnpackUneven_Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([2, 13, 2], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([2, 13, 2], torch.float32, True), + ] + ) def forward(self, x): chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1) return torch.ops.aten.add(chunk_0, chunk_1), chunk_2 + @register_test_case(module_factory=lambda: ChunkListUnpackUneven_Module()) def ChunkListUnpackUneven_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 13, 2)) + # ============================================================================== + class ChunkListUnpackDynamic_Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1) add = torch.ops.aten.add(chunk_0, chunk_1) sum = torch.ops.aten.add(add, chunk_2) return sum + @register_test_case(module_factory=lambda: ChunkListUnpackDynamic_Module()) def ChunkListUnpackDynamic_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 12, 2)) + # ============================================================================== + class ChunkListUnpackUnevenDynamic_Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1) return torch.ops.aten.add(chunk_0, chunk_1), chunk_2 + @register_test_case(module_factory=lambda: ChunkListUnpackUnevenDynamic_Module()) def ChunkListUnpackUnevenDynamic_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 13, 2)) + # ============================================================================== + class SplitWithSizes_Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([5, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([5, -1, -1], torch.float32, True), + ] + ) def forward(self, x): split = torch.split(x, [2, 1, 2], dim=0) return split[0], split[1], split[2] + @register_test_case(module_factory=lambda: SplitWithSizes_Module()) def SplitWithSizes_Module_basic(module, tu: TestUtils): module.forward(tu.rand(5, 2, 2)) - - - diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py index 078f3483b..79f3559c3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py @@ -17,16 +17,17 @@ class SqueezeStaticModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([1, 7, 1, 3, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 7, 1, 3, 1], torch.float32, True), + ] + ) def forward(self, a): return torch.squeeze(a) -@register_test_case( - module_factory=lambda: SqueezeStaticModule()) +@register_test_case(module_factory=lambda: SqueezeStaticModule()) def SqueezeModule_static(module, tu: TestUtils): module.forward(tu.rand(1, 7, 1, 3, 1)) @@ -39,16 +40,17 @@ class SqueezeAllUnitDimModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([1, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1], torch.float32, True), + ] + ) def forward(self, a): return torch.squeeze(a) -@register_test_case( - module_factory=lambda: SqueezeAllUnitDimModule()) +@register_test_case(module_factory=lambda: SqueezeAllUnitDimModule()) def SqueezeModule_allUnitDim(module, tu: TestUtils): module.forward(tu.rand(1, 1)) @@ -61,17 +63,18 @@ class SqueezeBroadcastModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([], torch.float32, True), + ] + ) def forward(self, a, b): return a * b.squeeze() -@register_test_case( - module_factory=lambda: SqueezeBroadcastModule()) +@register_test_case(module_factory=lambda: SqueezeBroadcastModule()) def SqueezeModule_broadcast(module, tu: TestUtils): module.forward(tu.rand(4, 3), tu.rand()) @@ -84,16 +87,17 @@ class SqueezeDimStaticModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([1, 7], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 7], torch.float32, True), + ] + ) def forward(self, a): return torch.squeeze(a, 0) -@register_test_case( - module_factory=lambda: SqueezeDimStaticModule()) +@register_test_case(module_factory=lambda: SqueezeDimStaticModule()) def SqueezeDimModule_static(module, tu: TestUtils): module.forward(tu.rand(1, 7)) @@ -106,16 +110,17 @@ class SqueezeDimDynamicModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, 1, 384, -1, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, 1, 384, -1, 1], torch.float32, True), + ] + ) def forward(self, a): return torch.squeeze(a, 4) -@register_test_case( - module_factory=lambda: SqueezeDimDynamicModule()) +@register_test_case(module_factory=lambda: SqueezeDimDynamicModule()) def SqueezeDimModule_dynamic(module, tu: TestUtils): module.forward(tu.rand(8, 1, 384, 12, 1)) @@ -128,16 +133,17 @@ class SqueezeDimNegDimModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([1, -1, 1, 384, -1, 1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, -1, 1, 384, -1, 1], torch.float32, True), + ] + ) def forward(self, a): return torch.squeeze(a, -6) -@register_test_case( - module_factory=lambda: SqueezeDimNegDimModule()) +@register_test_case(module_factory=lambda: SqueezeDimNegDimModule()) def SqueezeDimModule_negDim(module, tu: TestUtils): module.forward(tu.rand(1, 8, 1, 384, 12, 1)) @@ -150,16 +156,17 @@ class SqueezeDimIdentityModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([4, 1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([4, 1, -1], torch.float32, True), + ] + ) def forward(self, a): return torch.squeeze(a, 0) -@register_test_case( - module_factory=lambda: SqueezeDimIdentityModule()) +@register_test_case(module_factory=lambda: SqueezeDimIdentityModule()) def SqueezeDimModule_identity(module, tu: TestUtils): module.forward(tu.rand(4, 1, 3)) @@ -172,16 +179,17 @@ class SqueezeDimUnitDimModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1], torch.float32, True), + ] + ) def forward(self, a): return torch.squeeze(a, 0) -@register_test_case( - module_factory=lambda: SqueezeDimUnitDimModule()) +@register_test_case(module_factory=lambda: SqueezeDimUnitDimModule()) def SqueezeDimModule_unitDim(module, tu: TestUtils): module.forward(tu.rand(1)) @@ -194,16 +202,17 @@ class PrimsSqueezeModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([1, 1, 2, 3, 1, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 1, 2, 3, 1, 4], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.prims.squeeze(a, dimensions=[0, 4, 1]) -@register_test_case( - module_factory=lambda: PrimsSqueezeModule()) +@register_test_case(module_factory=lambda: PrimsSqueezeModule()) def PrimsSqueezeModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 2, 3, 1, 4)) @@ -213,15 +222,16 @@ class PrimsSqueezeEmptyDimensionsModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([1, 2, 1, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 2, 1, 4], torch.float32, True), + ] + ) def forward(self, a): return torch.ops.prims.squeeze(a, dimensions=[]) -@register_test_case( - module_factory=lambda: PrimsSqueezeEmptyDimensionsModule()) +@register_test_case(module_factory=lambda: PrimsSqueezeEmptyDimensionsModule()) def PrimsSqueezeEmptyDimensionsModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 2, 1, 4)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/stats.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/stats.py index c6398b48e..8317ac4db 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/stats.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/stats.py @@ -11,15 +11,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== + class MeanModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([3, 4], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([3, 4], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.mean(x) @@ -28,17 +31,21 @@ class MeanModule(torch.nn.Module): def MeanModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + # ============================================================================== + class MeanDynamicSizesModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.mean(x) @@ -47,17 +54,21 @@ class MeanDynamicSizesModule(torch.nn.Module): def MeanDynamicSizesModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + # ============================================================================== + class MeanDtypeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, x): return torch.ops.aten.mean(x, dtype=torch.float32) @@ -66,17 +77,21 @@ class MeanDtypeModule(torch.nn.Module): def MeanDtypeModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float64)) + # ============================================================================== + class MeanDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.mean(x, (0, 2)) @@ -85,17 +100,21 @@ class MeanDimModule(torch.nn.Module): def MeanDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 7)) + # ============================================================================== + class MeanDimDtypeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, x): return torch.ops.aten.mean(x, (0,), dtype=torch.float32) @@ -104,17 +123,21 @@ class MeanDimDtypeModule(torch.nn.Module): def MeanDimDtypeModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float64)) + # ============================================================================== + class MeanDimKeepdimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.mean(x, (1, 2), keepdim=True) @@ -123,17 +146,21 @@ class MeanDimKeepdimModule(torch.nn.Module): def MeanDimKeepdimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class MeanDimAllReduceModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.mean(x, (0, 1, 2)) @@ -142,17 +169,21 @@ class MeanDimAllReduceModule(torch.nn.Module): def MeanDimAllReduceModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class MeanDimAllReduceKeepdimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.mean(x, (0, 1, 2), keepdim=True) @@ -161,17 +192,21 @@ class MeanDimAllReduceKeepdimModule(torch.nn.Module): def MeanDimAllReduceKeepdimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class MeanDimNegativeModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.mean(x, (-1, 1)) @@ -183,15 +218,18 @@ def MeanDimNegativeModule_basic(module, tu: TestUtils): # ============================================================================== + class MeanDimEmptyDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.mean(x, dim=[]) @@ -200,17 +238,21 @@ class MeanDimEmptyDimModule(torch.nn.Module): def MeanDimEmptyDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class MeanDimNoneDimModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.mean(x, dim=None) @@ -219,74 +261,94 @@ class MeanDimNoneDimModule(torch.nn.Module): def MeanDimNoneDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class VarUnbiasedModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, unbiased=True) + @register_test_case(module_factory=lambda: VarUnbiasedModule()) def VarUnbiasedModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + # ============================================================================== + class VarBiasedModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, unbiased=False) + @register_test_case(module_factory=lambda: VarBiasedModule()) def VarBiasedModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + # ============================================================================== + class StdUnbiasedModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, unbiased=True) + @register_test_case(module_factory=lambda: StdUnbiasedModule()) def StdUnbiasedModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + # ============================================================================== + class StdBiasedModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, unbiased=False) + @register_test_case(module_factory=lambda: StdBiasedModule()) def StdBiasedModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) @@ -296,15 +358,16 @@ def StdBiasedModule_basic(module, tu: TestUtils): class StdDimKeepDimFalseModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, dim=(1, 2), keepdim=False) @@ -318,15 +381,16 @@ def StdDimKeepDimFalseModule_basic(module, tu: TestUtils): class StdDimKeepDimTrueModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, dim=(0, 1, 2), keepdim=True) @@ -340,15 +404,16 @@ def StdDimKeepDimTrueModule_basic(module, tu: TestUtils): class StdDimBiasedModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, dim=(0, 2), unbiased=False) @@ -362,15 +427,16 @@ def StdDimBiasedModule_basic(module, tu: TestUtils): class StdDimEmptyDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, dim=[], keepdim=False) @@ -384,15 +450,16 @@ def StdDimEmptyDimModule_basic(module, tu: TestUtils): class StdDimNoneDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, dim=None, keepdim=False) @@ -406,15 +473,16 @@ def StdDimNoneDimModule_basic(module, tu: TestUtils): class StdCorrectionModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, dim=None, correction=2) @@ -428,15 +496,16 @@ def StdCorrectionModule_basic(module, tu: TestUtils): class StdCorrectionSingleDimReduceModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, dim=[1], correction=1) @@ -450,20 +519,18 @@ def StdCorrectionSingleDimReduceModule_basic(module, tu: TestUtils): class StdCorrectionAllDimReduceModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.std(x, - dim=[0, 1, 2], - correction=10, - keepdim=False) + return torch.ops.aten.std(x, dim=[0, 1, 2], correction=10, keepdim=False) @register_test_case(module_factory=lambda: StdCorrectionAllDimReduceModule()) @@ -475,15 +542,16 @@ def StdCorrectionAllDimReduceModule_basic(module, tu: TestUtils): class StdCorrectionKeepDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, dim=[0, 1], correction=None, keepdim=True) @@ -497,15 +565,16 @@ def StdCorrectionKeepDimModule_basic(module, tu: TestUtils): class StdCorrectionNoneModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, dim=None, correction=None) @@ -519,15 +588,16 @@ def StdCorrectionNoneModule_basic(module, tu: TestUtils): class StdCorrectionEmptyDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, dim=[], correction=2) @@ -541,15 +611,16 @@ def StdCorrectionEmptyDimModule_basic(module, tu: TestUtils): class StdCorrectionLargeInputModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.std(x, dim=[2, 3], correction=2) @@ -563,15 +634,16 @@ def StdCorrectionLargeInputModule_basic(module, tu: TestUtils): class VarDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=(0, 2), keepdim=True) @@ -585,15 +657,16 @@ def VarDimModule_basic(module, tu: TestUtils): class VarDimUnbiasedModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=(0, 2), unbiased=True, keepdim=True) @@ -607,17 +680,18 @@ def VarDimUnbiasedModule_basic(module, tu: TestUtils): class VarDimBiasedModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, x): - return torch.ops.aten.var(x, dim=(0,1), unbiased=False, keepdim=True) + return torch.ops.aten.var(x, dim=(0, 1), unbiased=False, keepdim=True) @register_test_case(module_factory=lambda: VarDimBiasedModule()) @@ -629,15 +703,16 @@ def VarDimBiasedModule_basic(module, tu: TestUtils): class VarDimSingleDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=(0,), keepdim=True) @@ -651,15 +726,16 @@ def VarDimSingleDimModule_basic(module, tu: TestUtils): class VarDimMultiDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=[0, 2], keepdim=False) @@ -673,15 +749,16 @@ def VarDimMultiDimModule_basic(module, tu: TestUtils): class VarDimAllDimReduceModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=(0, 1, 2), keepdim=True) @@ -695,15 +772,16 @@ def VarDimAllDimReduceModule_basic(module, tu: TestUtils): class VarDimNegativeModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=(-1, 1), keepdim=True) @@ -717,15 +795,16 @@ def VarDimNegativeModule_basic(module, tu: TestUtils): class VarDimEmptyDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=[], keepdim=False) @@ -739,15 +818,16 @@ def VarDimEmptyDimModule_basic(module, tu: TestUtils): class VarDimNoneDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=None, keepdim=False) @@ -761,15 +841,16 @@ def VarDimNoneDimModule_basic(module, tu: TestUtils): class VarCorrectionModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=None, correction=2) @@ -783,15 +864,16 @@ def VarCorrectionModule_basic(module, tu: TestUtils): class VarCorrectionSingleDimReduceModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=[1], correction=1) @@ -805,20 +887,18 @@ def VarCorrectionSingleDimReduceModule_basic(module, tu: TestUtils): class VarCorrectionAllDimReduceModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): - return torch.ops.aten.var(x, - dim=[0, 1, 2], - correction=10, - keepdim=False) + return torch.ops.aten.var(x, dim=[0, 1, 2], correction=10, keepdim=False) @register_test_case(module_factory=lambda: VarCorrectionAllDimReduceModule()) @@ -830,15 +910,16 @@ def VarCorrectionAllDimReduceModule_basic(module, tu: TestUtils): class VarCorrectionKeepDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=[0, 1], correction=None, keepdim=True) @@ -852,15 +933,16 @@ def VarCorrectionKeepDimModule_basic(module, tu: TestUtils): class VarCorrectionNoneModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=None, correction=None) @@ -874,15 +956,16 @@ def VarCorrectionNoneModule_basic(module, tu: TestUtils): class VarCorrectionEmptyDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=[], correction=2) @@ -896,15 +979,16 @@ def VarCorrectionEmptyDimModule_basic(module, tu: TestUtils): class VarCorrectionLargeInputModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var(x, dim=[2, 3], correction=2) @@ -918,15 +1002,16 @@ def VarCorrectionLargeInputModule_basic(module, tu: TestUtils): class VarMeanCorrectionModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var_mean(x, dim=[1, 2], correction=2, keepdim=True) @@ -940,15 +1025,16 @@ def VarMeanCorrectionModule_basic(module, tu: TestUtils): class VarMeanCorrectionNoneModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var_mean(x, dim=None, correction=None, keepdim=False) @@ -962,15 +1048,16 @@ def VarMeanCorrectionNoneModule_basic(module, tu: TestUtils): class VarMeanUnbiasedModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var_mean(x) @@ -984,15 +1071,16 @@ def VarMeanUnbiasedModule_basic(module, tu: TestUtils): class VarMeanBiasedModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var_mean(x, unbiased=False) @@ -1006,15 +1094,16 @@ def VarMeanBiasedModule_basic(module, tu: TestUtils): class VarMeanDimModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var_mean(x, dim=[1]) @@ -1025,15 +1114,16 @@ def VarMeanDimModule_basic(module, tu: TestUtils): class VarMeanDimBiasedModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, x): return torch.ops.aten.var_mean(x, dim=[1], unbiased=False, keepdim=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/threshold.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/threshold.py index 674f88e89..ac570ffc9 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/threshold.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/threshold.py @@ -17,14 +17,16 @@ class Threshold1dIntI32Module(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int32, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.int32, True), + ] + ) def forward(self, input): return torch.ops.aten.threshold(input, 1, 2) + @register_test_case(module_factory=lambda: Threshold1dIntI32Module()) def Threshold1dIntI32Module_basic(module, tu: TestUtils): module.forward(tu.randint(4, high=10).to(torch.int32)) @@ -35,14 +37,16 @@ class Threshold1dIntModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ] + ) def forward(self, input): return torch.ops.aten.threshold(input, 1, 2) + @register_test_case(module_factory=lambda: Threshold1dIntModule()) def Threshold1dIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, high=10)) @@ -53,14 +57,16 @@ class Threshold2dIntModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) def forward(self, input): return torch.ops.aten.threshold(input, 0.5, 2) + @register_test_case(module_factory=lambda: Threshold2dIntModule()) def Threshold2dIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, 5, high=10)) @@ -71,14 +77,16 @@ class Threshold3dIntModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, input): return torch.ops.aten.threshold(input, 1, 2.2) + @register_test_case(module_factory=lambda: Threshold3dIntModule()) def Threshold3dIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, 5, 6, high=10)) @@ -89,14 +97,16 @@ class Threshold1dFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) def forward(self, input): return torch.ops.aten.threshold(input, 1, 2) + @register_test_case(module_factory=lambda: Threshold1dFloatModule()) def Threshold1dFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(4)) @@ -107,14 +117,16 @@ class Threshold2dFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) def forward(self, input): return torch.ops.aten.threshold(input, 0.5, 2) + @register_test_case(module_factory=lambda: Threshold2dFloatModule()) def Threshold2dFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5)) @@ -125,14 +137,16 @@ class Threshold3dFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, input): return torch.ops.aten.threshold(input, 1.4, 2.0) + @register_test_case(module_factory=lambda: Threshold3dFloatModule()) def Threshold3dFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6)) @@ -143,15 +157,17 @@ class ThresholdBackward1dIntModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ([-1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ] + ) def forward(self, grad, input): return torch.ops.aten.threshold_backward(grad, input, 1) + @register_test_case(module_factory=lambda: ThresholdBackward1dIntModule()) def ThresholdBackward1dIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, high=10), tu.randint(4, high=8)) @@ -162,15 +178,17 @@ class ThresholdBackward2dIntModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, grad, input): return torch.ops.aten.threshold_backward(grad, input, 0.5) + @register_test_case(module_factory=lambda: ThresholdBackward2dIntModule()) def ThresholdBackward2dIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, 5, high=10), tu.randint(4, 5, high=8)) @@ -181,15 +199,17 @@ class ThresholdBackward3dIntModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ([-1, -1, -1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int64, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, grad, input): return torch.ops.aten.threshold_backward(grad, input, 1) + @register_test_case(module_factory=lambda: ThresholdBackward3dIntModule()) def ThresholdBackward3dIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, 5, 6, high=10), tu.randint(4, 5, 6, high=8)) @@ -200,15 +220,17 @@ class ThresholdBackward1dFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) def forward(self, grad, input): return torch.ops.aten.threshold_backward(grad, input, 1) + @register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule()) def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(4), tu.rand(4)) @@ -219,15 +241,17 @@ class ThresholdBackward2dFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, grad, input): return torch.ops.aten.threshold_backward(grad, input, 0.5) + @register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule()) def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5), tu.rand(4, 5)) @@ -238,15 +262,17 @@ class ThresholdBackward3dFloatModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) def forward(self, grad, input): return torch.ops.aten.threshold_backward(grad, input, 1.4) + @register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule()) def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6), tu.rand(4, 5, 6)) @@ -257,15 +283,17 @@ class ThresholdBackward1dMixedModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([-1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ] + ) def forward(self, grad, input): return torch.ops.aten.threshold_backward(grad, input, 1) + @register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule()) def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils): module.forward(tu.rand(4), tu.randint(4, high=10)) @@ -276,15 +304,17 @@ class ThresholdBackward2dMixedModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1, -1], torch.float32, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, grad, input): return torch.ops.aten.threshold_backward(grad, input, 0.5) + @register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule()) def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, 5, high=20), tu.rand(4, 5)) @@ -295,15 +325,17 @@ class ThresholdBackward3dMixedModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.int64, True), - ]) - + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), + ] + ) def forward(self, grad, input): return torch.ops.aten.threshold_backward(grad, input, 1.4) + @register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule()) def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6), tu.randint(4, 5, 6, high=10)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 9c85eb873..5d3d085d5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -13,7 +13,6 @@ from torch_mlir_e2e_test.annotations import annotate_args, export class TypeConversionF32ToF64Module(torch.nn.Module): - def __init__(self): super().__init__() @@ -29,7 +28,6 @@ def TypeConversionF32ToF64Module_basic(module, tu: TestUtils): class TypeConversionF64ToF32Module(torch.nn.Module): - def __init__(self): super().__init__() @@ -45,7 +43,6 @@ def TypeConversionF64ToF32Module_basic(module, tu: TestUtils): class TypeConversionI32ToI64Module(torch.nn.Module): - def __init__(self): super().__init__() @@ -61,7 +58,6 @@ def TypeConversionI32ToI64Module_basic(module, tu: TestUtils): class TypeConversionI64ToI32Module(torch.nn.Module): - def __init__(self): super().__init__() @@ -77,7 +73,6 @@ def TypeConversionI64ToI32Module_basic(module, tu: TestUtils): class TypeConversionI1ToI32Module(torch.nn.Module): - def __init__(self): super().__init__() @@ -94,7 +89,6 @@ def TypeConversionI1ToI32Module_basic(module, tu: TestUtils): class TypeConversionI1ToI64Module(torch.nn.Module): - def __init__(self): super().__init__() @@ -111,7 +105,6 @@ def TypeConversionI1ToI64Module_basic(module, tu: TestUtils): class TypeConversionI1ToF32Module(torch.nn.Module): - def __init__(self): super().__init__() @@ -128,7 +121,6 @@ def TypeConversionI1ToF32Module_basic(module, tu: TestUtils): class TypeConversionI1ToF64Module(torch.nn.Module): - def __init__(self): super().__init__() @@ -148,43 +140,46 @@ def TypeConversionI1ToF64Module_basic(module, tu: TestUtils): class ToDtypeLayoutNoneModule(torch.nn.Module): - def __init__(self): super().__init__() @export @annotate_args([None, ([-1, -1], torch.float32, True)]) def forward(self, x): - return torch.ops.aten.to(x, - dtype=torch.float64, - layout=None, - device=None, - pin_memory=None, - non_blocking=False, - copy=False, - memory_format=None) + return torch.ops.aten.to( + x, + dtype=torch.float64, + layout=None, + device=None, + pin_memory=None, + non_blocking=False, + copy=False, + memory_format=None, + ) @register_test_case(module_factory=lambda: ToDtypeLayoutNoneModule()) def ToDtypeLayoutNoneModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) -class ToDtypeLayoutCPUModule(torch.nn.Module): +class ToDtypeLayoutCPUModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([None, ([-1, -1], torch.float32, True)]) def forward(self, x): - return torch.ops.aten.to(x, - dtype=torch.float64, - layout=None, - device="cpu", - pin_memory=None, - non_blocking=False, - copy=False, - memory_format=None) + return torch.ops.aten.to( + x, + dtype=torch.float64, + layout=None, + device="cpu", + pin_memory=None, + non_blocking=False, + copy=False, + memory_format=None, + ) @register_test_case(module_factory=lambda: ToDtypeLayoutCPUModule()) @@ -193,21 +188,22 @@ def ToDtypeLayoutCPUModule_basic(module, tu: TestUtils): class ToDtypeLayoutStridedModule(torch.nn.Module): - def __init__(self): super().__init__() @export @annotate_args([None, ([-1, -1], torch.float32, True)]) def forward(self, x): - return torch.ops.aten.to(x, - dtype=torch.float64, - layout=torch.strided, - device=None, - pin_memory=None, - non_blocking=False, - copy=False, - memory_format=None) + return torch.ops.aten.to( + x, + dtype=torch.float64, + layout=torch.strided, + device=None, + pin_memory=None, + non_blocking=False, + copy=False, + memory_format=None, + ) @register_test_case(module_factory=lambda: ToDtypeLayoutStridedModule()) @@ -216,21 +212,22 @@ def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils): class ToDtypeBoolLayoutNoneStaticModule(torch.nn.Module): - def __init__(self): super().__init__() @export @annotate_args([None, ([3, 5], torch.int64, True)]) def forward(self, x): - return torch.ops.aten.to(x, - dtype=torch.bool, - layout=None, - device=None, - pin_memory=None, - non_blocking=False, - copy=False, - memory_format=None) + return torch.ops.aten.to( + x, + dtype=torch.bool, + layout=None, + device=None, + pin_memory=None, + non_blocking=False, + copy=False, + memory_format=None, + ) @register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneStaticModule()) @@ -239,16 +236,17 @@ def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils): class TypeAsSameModule(torch.nn.Module): - def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, x, y): return torch.ops.aten.type_as(x, y) @@ -257,17 +255,19 @@ class TypeAsSameModule(torch.nn.Module): def TypeAsSameModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5), tu.rand(3, 5)) -class TypeAsDifferentModule(torch.nn.Module): +class TypeAsDifferentModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.int, True), - ([-1, -1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.int, True), + ([-1, -1], torch.int64, True), + ] + ) def forward(self, x, y): return torch.ops.aten.type_as(x, y) @@ -276,14 +276,14 @@ class TypeAsDifferentModule(torch.nn.Module): def TypeAsDifferentModule_basic(module, tu: TestUtils): module.forward( tu.randint(3, 5, low=0, high=10, dtype=torch.int), - tu.randint(3, 5, low=0, high=10, dtype=torch.int64) + tu.randint(3, 5, low=0, high=10, dtype=torch.int64), ) + # ============================================================================== class PrimsConvertElementTypeModule(torch.nn.Module): - def __init__(self): super().__init__() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_promotion.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_promotion.py index 41c03ec18..69abaf76c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_promotion.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_promotion.py @@ -17,21 +17,22 @@ class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int32, True), - ([-1], torch.int64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int32, True), + ([-1], torch.int64, True), + ] + ) def forward(self, a, b): return torch.add(a, b, alpha=3) @register_test_case( - module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule()) + module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule() +) def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils): - module.forward( - tu.randint(4, high=10).type(torch.int32), - tu.randint(4, high=10)) + module.forward(tu.randint(4, high=10).type(torch.int32), tu.randint(4, high=10)) class TypePromotionDifferentCategoryModule(torch.nn.Module): @@ -39,17 +40,18 @@ class TypePromotionDifferentCategoryModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ([-1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ] + ) def forward(self, a, b): return torch.add(a, b, alpha=3) -@register_test_case( - module_factory=lambda: TypePromotionDifferentCategoryModule()) +@register_test_case(module_factory=lambda: TypePromotionDifferentCategoryModule()) def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, high=10), tu.rand(4)) @@ -59,17 +61,20 @@ class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([], torch.float64, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([], torch.float64, True), + ] + ) def forward(self, a, b): return torch.add(a, b, alpha=2.3) @register_test_case( - module_factory=lambda: TypePromotionSameCategoryZeroRankWiderModule()) + module_factory=lambda: TypePromotionSameCategoryZeroRankWiderModule() +) def TypePromotionSameCategoryZeroRankWider_basic(module, tu: TestUtils): module.forward(tu.rand(4), tu.rand().type(torch.float64)) @@ -79,17 +84,18 @@ class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.int64, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.int64, True), + ([], torch.float32, True), + ] + ) def forward(self, a, b): return torch.add(a, b, alpha=2) -@register_test_case( - module_factory=lambda: TypePromotionZeroRankHigherCategoryModule()) +@register_test_case(module_factory=lambda: TypePromotionZeroRankHigherCategoryModule()) def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, high=10), tu.rand()) @@ -99,11 +105,13 @@ class TypePromotionAlphaWiderModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ([], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([], torch.float32, True), + ] + ) def forward(self, a, b): return torch.add(a, b, alpha=2.3) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/vision_models.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/vision_models.py index 34f3b9c69..972aa3710 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/vision_models.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/vision_models.py @@ -22,10 +22,12 @@ class ResNet18Module(torch.nn.Module): self.train(False) @export - @annotate_args([ - None, - ([-1, 3, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, 3, -1, -1], torch.float32, True), + ] + ) def forward(self, img): return self.resnet.forward(img) @@ -44,10 +46,12 @@ class ResNet18StaticModule(torch.nn.Module): self.train(False) @export - @annotate_args([ - None, - ([1, 3, 224, 224], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([1, 3, 224, 224], torch.float32, True), + ] + ) def forward(self, img): return self.resnet.forward(img) @@ -62,11 +66,13 @@ class IouOfModule(torch.nn.Module): super().__init__() @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) def forward(self, bbox1, bbox2): area1 = (bbox1[:, 2] - bbox1[:, 0]) * (bbox1[:, 3] - bbox1[:, 1]) area2 = (bbox2[:, 2] - bbox2[:, 0]) * (bbox2[:, 3] - bbox2[:, 1]) @@ -94,10 +100,12 @@ class MobilenetV3Module(torch.nn.Module): self.train(False) @export - @annotate_args([ - None, - ([-1, 3, -1, -1], torch.float32, True), - ]) + @annotate_args( + [ + None, + ([-1, 3, -1, -1], torch.float32, True), + ] + ) def forward(self, img): return self.mobilenetv3.forward(img) diff --git a/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/abc.py b/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/abc.py index d56a86d1f..5c8d97bf2 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/abc.py +++ b/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/abc.py @@ -13,12 +13,12 @@ from torch_mlir.ir import Module # A type shared between the result of `TosaBackend.compile` and the # input to `TosaBackend.load`. Each backend will likely have a # different definition of this type. -CompiledArtifact = TypeVar('CompiledArtifact') +CompiledArtifact = TypeVar("CompiledArtifact") # A wrapper around a backend-specific loaded program representation # that uniformly translates the `x.method(...)` interface expected of # Torch modules into appropriate lower-level operations. -Invoker = TypeVar('Invoker') +Invoker = TypeVar("Invoker") class TosaBackend(abc.ABC): @@ -27,6 +27,7 @@ class TosaBackend(abc.ABC): Backends are recommended to raise meaningful exceptions in case of error, ideally with easy reproduction instructions. """ + @abc.abstractmethod def compile(self, module: Module) -> CompiledArtifact: """Compile the provided MLIR module into a compiled artifact. diff --git a/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py index 9317a3020..c9273c1f4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py @@ -7,7 +7,9 @@ from torch_mlir.ir import * from torch_mlir.passmanager import * from torch_mlir.compiler_utils import run_pipeline_with_repro_report -from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( + RefBackendLinalgOnTensorsBackend, +) from .abc import TosaBackend @@ -17,23 +19,25 @@ __all__ = [ # The pipeline of func.func passes that lower the TOSA backend contract to the # Linalg-on-Tensors backend contract accepted by RefBackend. -TOSA_TO_LINALG_FUNC_PIPELINE = ",".join([ - # TOSA legalization may emit tosa.const() ops. These are legalized - # by tosa-to-arith to arith.constants. This mechanical transformation - # must be done prior to TOSA-to-LinAlg so that the latter does not fail. - # This is an artifact of legalizations spread across a collection of simple - # ones in TOSA-to-Standard and the main conversions TOSA-to-LinAlg, - # that depend on TOSA as well as TOSA-to-Standard. - "tosa-to-arith", - "tosa-to-scf", - # Named ops must be legalized prior to general tosa-to-linalg - "tosa-to-linalg-named", - # TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them - # to arith.constants here before proceeding further. - "tosa-to-linalg", - "tosa-to-tensor", - "tosa-to-arith", -]) +TOSA_TO_LINALG_FUNC_PIPELINE = ",".join( + [ + # TOSA legalization may emit tosa.const() ops. These are legalized + # by tosa-to-arith to arith.constants. This mechanical transformation + # must be done prior to TOSA-to-LinAlg so that the latter does not fail. + # This is an artifact of legalizations spread across a collection of simple + # ones in TOSA-to-Standard and the main conversions TOSA-to-LinAlg, + # that depend on TOSA as well as TOSA-to-Standard. + "tosa-to-arith", + "tosa-to-scf", + # Named ops must be legalized prior to general tosa-to-linalg + "tosa-to-linalg-named", + # TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them + # to arith.constants here before proceeding further. + "tosa-to-linalg", + "tosa-to-tensor", + "tosa-to-arith", + ] +) class LinalgOnTensorsTosaBackend(TosaBackend): @@ -60,7 +64,8 @@ class LinalgOnTensorsTosaBackend(TosaBackend): run_pipeline_with_repro_report( imported_module, f"builtin.module(func.func({TOSA_TO_LINALG_FUNC_PIPELINE}))", - "Lowering TOSA backend contract to Linalg-on-Tensors backend contract") + "Lowering TOSA backend contract to Linalg-on-Tensors backend contract", + ) return self.refbackend.compile(imported_module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/utils.py b/projects/pt1/python/torch_mlir_e2e_test/utils.py index e3a76581f..dd9f8d8f8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/utils.py +++ b/projects/pt1/python/torch_mlir_e2e_test/utils.py @@ -6,6 +6,7 @@ from torch_mlir.torchscript import TensorPlaceholder from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME + def convert_annotations_to_placeholders(forward_method): """Converts the annotations on a forward method into tensor placeholders. @@ -17,6 +18,7 @@ def convert_annotations_to_placeholders(forward_method): for annotation in annotations[1:]: if not annotation[2]: raise ValueError( - "Can only compile inputs annotated as having value semantics.") + "Can only compile inputs annotated as having value semantics." + ) placeholders.append(TensorPlaceholder(annotation[0], annotation[1])) return placeholders diff --git a/projects/pt1/test/lit.cfg.py b/projects/pt1/test/lit.cfg.py index 31e3ee388..2f2cfe656 100644 --- a/projects/pt1/test/lit.cfg.py +++ b/projects/pt1/test/lit.cfg.py @@ -19,62 +19,76 @@ from lit.llvm.subst import FindTool # Configuration file for the 'lit' test runner. # name: The name of this test suite. -config.name = 'TORCH_MLIR_PT1' +config.name = "TORCH_MLIR_PT1" config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.mlir', '.py'] +config.suffixes = [".mlir", ".py"] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. -config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test') +config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test") -config.substitutions.append(('%PATH%', config.environment['PATH'])) -config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) +config.substitutions.append(("%PATH%", config.environment["PATH"])) +config.substitutions.append(("%shlibext", config.llvm_shlib_ext)) -llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) +llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"]) -#llvm_config.use_default_substitutions() +# llvm_config.use_default_substitutions() # excludes: A list of directories to exclude from the testsuite. The 'Inputs' # subdirectories contain auxiliary inputs for various tests in their parent # directories. config.excludes = [ - 'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt', - 'lit.cfg.py', 'lit.site.cfg.py' + "Inputs", + "Examples", + "CMakeLists.txt", + "README.txt", + "LICENSE.txt", + "lit.cfg.py", + "lit.site.cfg.py", ] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. -config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test') -config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin') +config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test") +config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, "bin") # Tweak the PATH to include the tools dir. -llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) +llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True) # Tweak the PATH to include the binary build dir, in order to pick up CAPI tests during out-of-tree. -llvm_config.with_environment('PATH', os.path.join(config.llvm_build_dir, 'bin'), append_path=True) +llvm_config.with_environment( + "PATH", os.path.join(config.llvm_build_dir, "bin"), append_path=True +) # On Windows the path to python could contains spaces in which case it needs to # be provided in quotes. This is the equivalent of how %python is setup in # llvm/utils/lit/lit/llvm/config.py. if "Windows" in config.host_os: - config.python_executable = '"%s"' % (config.python_executable) + config.python_executable = '"%s"' % (config.python_executable) -tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir, config.torch_mlir_obj_root] +tool_dirs = [ + config.standalone_tools_dir, + config.llvm_tools_dir, + config.torch_mlir_obj_root, +] tools = [ - 'torch-mlir-opt', - ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'), + "torch-mlir-opt", + ToolSubst("%PYTHON", config.python_executable, unresolved="ignore"), ] llvm_config.add_tool_substitutions(tools, tool_dirs) if config.enable_bindings_python: - llvm_config.with_environment('PYTHONPATH', [ - os.path.join(config.torch_mlir_python_packages_dir, 'torch_mlir'), - ], - append_path=True) + llvm_config.with_environment( + "PYTHONPATH", + [ + os.path.join(config.torch_mlir_python_packages_dir, "torch_mlir"), + ], + append_path=True, + ) diff --git a/projects/pt1/test/python/custom_op_shape_dtype_fn.py b/projects/pt1/test/python/custom_op_shape_dtype_fn.py index a3a2b965d..3c878dc0d 100644 --- a/projects/pt1/test/python/custom_op_shape_dtype_fn.py +++ b/projects/pt1/test/python/custom_op_shape_dtype_fn.py @@ -20,18 +20,26 @@ goofy_lib = torch.library.Library("goofy", "DEF") goofy_lib.define("identity(Tensor t) -> Tensor") goofy_lib.impl("identity", identity) + def goofy〇identity〡shape(t: List[int]) -> List[int]: return t + def goofy〇identity〡dtype(t_rank_dtype: Tuple[int, int]) -> int: t_rank, t_dtype = t_rank_dtype return t_dtype + def goofy〇identity〡has_value_semantics() -> None: return + extra_library = [ - goofy〇identity〡shape, goofy〇identity〡dtype, goofy〇identity〡has_value_semantics] + goofy〇identity〡shape, + goofy〇identity〡dtype, + goofy〇identity〡has_value_semantics, +] + class CustomOpExampleModule(torch.nn.Module): def __init__(self): @@ -52,6 +60,7 @@ class CustomOpExampleModule(torch.nn.Module): mod = CustomOpExampleModule() mod.eval() + def run(): mod = CustomOpExampleModule() mod.eval() @@ -66,6 +75,7 @@ def run(): print(module) + run() # CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} { diff --git a/projects/pt1/test/python/importer/jit_ir/get_registered_ops.py b/projects/pt1/test/python/importer/jit_ir/get_registered_ops.py index b4cdb4b27..13b78345a 100644 --- a/projects/pt1/test/python/importer/jit_ir/get_registered_ops.py +++ b/projects/pt1/test/python/importer/jit_ir/get_registered_ops.py @@ -9,4 +9,4 @@ from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # This check is just for a built-in op that is unlikely to change (and is # otherwise insignificant). # CHECK: {'name': ('aten::mul', 'Tensor'), 'is_c10_op': True, 'is_vararg': False, 'is_varret': False, 'is_mutable': False, 'arguments': [{'name': 'self', 'type': 'Tensor', 'pytype': 'Tensor'}, {'name': 'other', 'type': 'Tensor', 'pytype': 'Tensor'}], 'returns': [{'name': '', 'type': 'Tensor', 'pytype': 'Tensor'}]} -print('\n\n'.join([repr(r) for r in get_registered_ops()])) +print("\n\n".join([repr(r) for r in get_registered_ops()])) diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py index 0979d0422..6ca4b19d8 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py @@ -6,6 +6,7 @@ import typing import torch from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder + # RUN: %PYTHON %s | FileCheck %s mb = ModuleBuilder() @@ -31,7 +32,7 @@ except Exception as e: print(e) try: - annotator.annotateArgs(class_type, ['forward'], [None]) + annotator.annotateArgs(class_type, ["forward"], [None]) except Exception as e: # CHECK: There must be one argument annotation per function parameter. # CHECK-SAME: Including 'self' the number of argument annotations is: 1. @@ -40,7 +41,7 @@ except Exception as e: print(e) try: - annotator.annotateArgs(class_type, ['forward'], [None, ([3, 4], 42, False)]) + annotator.annotateArgs(class_type, ["forward"], [None, ([3, 4], 42, False)]) except Exception as e: # This is just the raw repr of the object in quotes. # CHECK: unsupported scalar type '42' diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-tensor-type-bound.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-tensor-type-bound.py index 6cc2d57b1..9778828f5 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-tensor-type-bound.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-tensor-type-bound.py @@ -6,16 +6,20 @@ import typing import torch from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder + # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() + def forward(self, a, b): return + test_module = TestModule() recursivescriptmodule = torch.jit.script(test_module) @@ -26,11 +30,15 @@ class_type = recursivescriptmodule._c._type() # CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>}, # CHECK-SAME: %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[],f32>} # CHECK-SAME: ) -> !torch.none -annotator.annotateArgs(class_type, ['forward'], [ - None, - ((-1, 1024), torch.int8, True), - ((), torch.float, True), -]) +annotator.annotateArgs( + class_type, + ["forward"], + [ + None, + ((-1, 1024), torch.int8, True), + ((), torch.float, True), + ], +) # # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. mb.import_module(recursivescriptmodule._c, annotator) diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/class-annotator-repr.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/class-annotator-repr.py index 3a2ed4319..6008a6e8f 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/class-annotator-repr.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/class-annotator-repr.py @@ -6,6 +6,7 @@ import typing import torch from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder + # RUN: %PYTHON %s | FileCheck %s mb = ModuleBuilder() @@ -40,13 +41,17 @@ annotator = ClassAnnotator() class_type = recursivescriptmodule._c._type() annotator.exportNone(class_type) -annotator.exportPath(class_type, ['s', 'exported']) -annotator.exportPath(class_type, ['s', 'forward']) -annotator.annotateArgs(class_type, ['forward'], [ - None, - ((1024, 2), torch.float32, False), - ((42, -1, 7), torch.int8, True), -]) +annotator.exportPath(class_type, ["s", "exported"]) +annotator.exportPath(class_type, ["s", "forward"]) +annotator.annotateArgs( + class_type, + ["forward"], + [ + None, + ((1024, 2), torch.float32, False), + ((42, -1, 7), torch.int8, True), + ], +) # "Change detector" test + "documentation" for the repr of `ClassAnnotator`. # This is semi-load-bearing because users interact with this class and repr diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-error.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-error.py index 2a0806f6f..424d97007 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-error.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-error.py @@ -6,16 +6,20 @@ import typing import torch from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder + # RUN: %PYTHON %s | FileCheck %s mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() + def forward(self): return + test_module = TestModule() recursivescriptmodule = torch.jit.script(test_module) @@ -23,7 +27,7 @@ annotator = ClassAnnotator() class_type = recursivescriptmodule._c._type() try: - annotator.exportPath(class_type, ['a']) + annotator.exportPath(class_type, ["a"]) except Exception as e: # CHECK: class '__torch__.TestModule' does not have a method or attribute called 'a' print(e) @@ -34,7 +38,7 @@ except Exception as e: print(e) try: - annotator.exportPath(class_type, ['a', 'b']) + annotator.exportPath(class_type, ["a", "b"]) except Exception as e: # This error is generated by PyTorch itself, so be a bit defensive about changes. # CHECK: __torch__.TestModule {{.*}} 'a' diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-recursive.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-recursive.py index 79b4dccd2..c27dbd73a 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-recursive.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-recursive.py @@ -6,6 +6,7 @@ import typing import torch from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder + # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s mb = ModuleBuilder() @@ -16,18 +17,23 @@ class Submodule(torch.nn.Module): super().__init__() self.exported = 1 self.not_exported = 2 + def forward(self): return self.not_exported_method() + def not_exported_method(self): return + class TestModule(torch.nn.Module): def __init__(self): super().__init__() self.s = Submodule() + def forward(self): return self.s.forward() + test_module = TestModule() recursivescriptmodule = torch.jit.script(test_module) @@ -40,8 +46,8 @@ class_type = recursivescriptmodule._c._type() # CHECK: torch.method private "not_exported_method", @{{.*}} # CHECK: } annotator.exportNone(class_type) -annotator.exportPath(class_type, ['s', 'exported']) -annotator.exportPath(class_type, ['s', 'forward']) +annotator.exportPath(class_type, ["s", "exported"]) +annotator.exportPath(class_type, ["s", "forward"]) # # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. mb.import_module(recursivescriptmodule._c, annotator) diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export.py index 433f8249b..c6ac5c95b 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export.py @@ -6,20 +6,25 @@ import typing import torch from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder + # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() self.exported = 1 self.not_exported = 2 + def forward(self): return self.not_exported_method() + def not_exported_method(self): return + test_module = TestModule() recursivescriptmodule = torch.jit.script(test_module) @@ -32,8 +37,8 @@ class_type = recursivescriptmodule._c._type() # CHECK: torch.method private "not_exported_method", @{{.*}} # CHECK: } annotator.exportNone(class_type) -annotator.exportPath(class_type, ['exported']) -annotator.exportPath(class_type, ['forward']) +annotator.exportPath(class_type, ["exported"]) +annotator.exportPath(class_type, ["forward"]) # # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. mb.import_module(recursivescriptmodule._c, annotator) diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py index 399b45f73..bd21c4e8b 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py @@ -16,6 +16,7 @@ class TestModule(torch.nn.Module): def __init__(self): super().__init__() + test_module = TestModule() recursivescriptmodule = torch.jit.script(test_module) # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/dict.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/dict.py index 117b0cff9..b40e6e456 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/dict.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/dict.py @@ -13,9 +13,9 @@ mb = ModuleBuilder() class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.d = {"key1": torch.tensor(1)} + def __init__(self): + super().__init__() + self.d = {"key1": torch.tensor(1)} # CHECK: torch.class_type @[[CLASSTYPE:.*]] { diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions-that-call-methods.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions-that-call-methods.py index 318e09975..ffd017edb 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions-that-call-methods.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions-that-call-methods.py @@ -30,18 +30,23 @@ mb = ModuleBuilder() # CHECK: return %[[RET]] : !torch.none # CHECK: } -def calls_method(c: 'TestModule', x): + +def calls_method(c: "TestModule", x): return c.method(x) + class TestModule(torch.nn.Module): def __init__(self): super().__init__() + def forward(self, x): return calls_method(self, x) - @torch.jit.export # Needed so that scripting sees it. + + @torch.jit.export # Needed so that scripting sees it. def method(self, x): return + test_module = TestModule() recursivescriptmodule = torch.jit.script(test_module) # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions.py index ee22a495e..13d326926 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions.py @@ -26,15 +26,19 @@ mb = ModuleBuilder() # CHECK: torch.method "forward", @__torch__.TestModule.forward # CHECK: } + def identity(x): return x + class TestModule(torch.nn.Module): def __init__(self): super().__init__() + def forward(self, x): return identity(x) + test_module = TestModule() recursivescriptmodule = torch.jit.script(test_module) # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/list.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/list.py index 0c1b8f2ff..cff6ec0cc 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/list.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/list.py @@ -11,10 +11,13 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() self.l = [1, 2] + + # CHECK: torch.class_type @[[CLASSTYPE:.*]] { # CHECK: torch.attr "l" : !torch.list # CHECK: } diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-derefine.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-derefine.py index fee1b2922..c0c52304a 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-derefine.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-derefine.py @@ -12,21 +12,22 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() + + # CHECK-LABEL: func.func private @__torch__.TestModule.forward( + # CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">) -> !torch.optional { + # CHECK: %[[NONE:.*]] = torch.constant.none + # CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional + # CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[SELF]]["callee"] (%[[DEREFINED]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.optional) -> !torch.optional + # CHECK: return %[[RET]] : !torch.optional + def forward(self): + return self.callee(None) + + def callee(self, o: typing.Optional[int]): + return o - # CHECK-LABEL: func.func private @__torch__.TestModule.forward( - # CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">) -> !torch.optional { - # CHECK: %[[NONE:.*]] = torch.constant.none - # CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional - # CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[SELF]]["callee"] (%[[DEREFINED]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.optional) -> !torch.optional - # CHECK: return %[[RET]] : !torch.optional - def forward(self): - return self.callee(None) - def callee(self, o: typing.Optional[int]): - return o test_module = TestModule() recursivescriptmodule = torch.jit.script(test_module) diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-locations.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-locations.py index 5d38d6e3a..75ed79d45 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-locations.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-locations.py @@ -11,13 +11,16 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x, y): - # CHECK-LABEL: torch.nn_module - # CHECK: loc("{{.*}}methods-locations.py":[[@LINE+1]] - return x * y + def __init__(self): + super().__init__() + + def forward(self, x, y): + # CHECK-LABEL: torch.nn_module + # CHECK: loc("{{.*}}methods-locations.py":[[@LINE+1]] + return x * y + test_module = TestModule() recursivescriptmodule = torch.jit.script(test_module) diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods.py index 0143012bf..510fa66d3 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods.py @@ -30,11 +30,14 @@ mb = ModuleBuilder() # CHECK: torch.method "forward", @__torch__.TestModule.forward # CHECK: } + class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x): - return x + def __init__(self): + super().__init__() + + def forward(self, x): + return x + test_module = TestModule() recursivescriptmodule = torch.jit.script(test_module) diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error-submodule.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error-submodule.py index eae86ec1c..154b59381 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error-submodule.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error-submodule.py @@ -11,15 +11,17 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class Submodule(torch.nn.Module): def __init__(self): super().__init__() - self.t1 = torch.tensor([10., 20.]) + self.t1 = torch.tensor([10.0, 20.0]) # Test a nontrivial recursive case of the diagnostic. # CHECK: Unhandled tensor that shares storage with another tensor. # CHECK-NEXT: Found at path '.m.t2' from root object '__torch__.TestModule' self.t2 = self.t1[0] + class TestModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error.py index 968509acc..951de6665 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error.py @@ -11,12 +11,13 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() # CHECK: Unhandled tensor that shares storage with another tensor. # CHECK-NEXT: Found at path '.t2' from root object '__torch__.TestModule' - self.t1 = torch.tensor([10., 20.]) + self.t1 = torch.tensor([10.0, 20.0]) self.t2 = self.t1[0] diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity.py index 0f6516a27..48bcb084b 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity.py @@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() @@ -18,7 +19,7 @@ class TestModule(torch.nn.Module): # CHECK: torch.nn_module { # CHECK: torch.slot "t1", %[[T]] # CHECK: torch.slot "t2", %[[T]] - self.t1 = self.t2 = torch.tensor([10., 20.]) + self.t1 = self.t2 = torch.tensor([10.0, 20.0]) test_module = TestModule() diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/prim.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/prim.py index e48c327ed..69a947df7 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/prim.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/prim.py @@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() @@ -25,6 +26,7 @@ class TestModule(torch.nn.Module): self.t1 = self.t2 # CHECK: torch.prim.CallMethod %[[SELF]]["callee"] (%{{.*}}, %{{.*}}) self.callee(self.t1, self.t2) + # CHECK-LABEL: func.func private @__torch__.TestModule.callee( # CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">, # CHECK-SAME: %[[X:.*]]: !torch.tensor, @@ -32,6 +34,7 @@ class TestModule(torch.nn.Module): def callee(self, x, y): pass + test_module = TestModule() recursivescriptmodule = torch.jit.script(test_module) # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/primitives.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/primitives.py index 3cb8cf992..273af8fe1 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/primitives.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/primitives.py @@ -11,12 +11,14 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() self.i = 3 self.f = 42.5 + # CHECK: torch.class_type @[[CLASSTYPE:.*]] { # CHECK: torch.attr "training" : !torch.bool # CHECK: torch.attr "i" : !torch.int diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py index d77b98323..e33985fac 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py @@ -17,10 +17,10 @@ class TestModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.quantized.Linear(5, 2, dtype=torch.qint8) - self.linear_no_bias = torch.nn.quantized.Linear(6, - 2, - bias_=False, - dtype=torch.qint8) + self.linear_no_bias = torch.nn.quantized.Linear( + 6, 2, bias_=False, dtype=torch.qint8 + ) + # CHECK: %[[SCALE:.*]] = torch.constant.float # CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0 # CHECK: %[[INT_REPR:.*]] = torch.tensor.literal({{.*}}) : !torch.tensor<[2,5],si8> diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/strings.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/strings.py index b65d6f5ca..7d704875b 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/strings.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/strings.py @@ -11,10 +11,13 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() self.s = "foo" + + # CHECK: torch.class_type @[[CLASSTYPE:.*]] { # TODO: Don't lose element type. # CHECK: torch.attr "s" : !torch.str diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules-select.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules-select.py index 5b2cf04b5..e3c341d9e 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules-select.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules-select.py @@ -11,12 +11,15 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class Submodule(torch.nn.Module): def __init__(self, n): super().__init__() self.n = n + def forward(self): - return self.n + return self.n + class TestModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules.py index d9983628d..e306723fa 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules.py @@ -11,17 +11,20 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class Submodule(torch.nn.Module): def __init__(self, n): super().__init__() self.n = n + class TestModule(torch.nn.Module): def __init__(self): super().__init__() self.s0 = Submodule(0) self.s1 = Submodule(1) + # CHECK-LABEL: torch.class_type @__torch__.TestModule { # CHECK: %[[T:.*]] = torch.constant.bool true diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors-value-semantics.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors-value-semantics.py index 36dfa32f0..c6edaa520 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors-value-semantics.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors-value-semantics.py @@ -11,13 +11,15 @@ from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuil mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() self.ones_i32 = torch.ones(1, dtype=torch.int32) - self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8) + self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8) self.arange = torch.nn.Parameter(torch.arange(3.0)) + # CHECK: %[[ARANGE:.*]] = torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.vtensor<[3],f32> # CHECK: %[[ONES_I32:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi32>) : !torch.vtensor<[1],si32> # CHECK: %[[ONES_QINT8_DATA:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi8>) : !torch.vtensor<[1],si8> diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors.py index 31a89e3e1..b0f2daacc 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors.py @@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() @@ -23,15 +24,20 @@ class TestModule(torch.nn.Module): # Because bools turn anything that is non-zero into `True`, it is # important to check a series of `True`s and `False`s to make sure the # actual values are being imported rather than just garbage. - self.bool_ = torch.tensor([True, False, True, False, True, False], dtype=torch.bool) + self.bool_ = torch.tensor( + [True, False, True, False, True, False], dtype=torch.bool + ) self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16) self.ones_f16 = torch.ones(1, dtype=torch.half) self.ones_ui8 = torch.ones(1, dtype=torch.uint8) self.ones_i8 = torch.ones(1, dtype=torch.int8) - self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8) - self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8) + self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8) + self.ones_quint8 = torch.quantize_per_tensor( + torch.ones(1), 1.0, 0, torch.quint8 + ) self.arange = torch.nn.Parameter(torch.arange(3.0)) + # CHECK: %[[ARANGE:.*]] = torch.tensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.tensor<[3],f32> # CHECK: %[[ONES:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor<[1],f32> # CHECK: %[[ONES_I32:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi32>) : !torch.tensor<[1],si32> diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tuple.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tuple.py index 7bed706ac..76e4e4bd1 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tuple.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tuple.py @@ -11,10 +11,13 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + class TestModule(torch.nn.Module): def __init__(self): super().__init__() self.t = (1, 2) + + # CHECK: torch.class_type @[[CLASSTYPE:.*]] { # TODO: Don't lose element type. # CHECK: } diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/classes.py b/projects/pt1/test/python/importer/jit_ir/node_import/classes.py index 09e2b1b0b..3f791c868 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/classes.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/classes.py @@ -14,10 +14,12 @@ import typing mb = ModuleBuilder() + class BasicClass: def __init__(self, x: int): self.x = x + # CHECK-LABEL: func.func @__torch__.prim_CreateObject( # CHECK-SAME: %[[ARG0:.*]]: !torch.int) -> !torch.nn.Module<"__torch__.BasicClass"> { # CHECK: %[[OBJECT:.*]] = torch.prim.CreateObject !torch.nn.Module<"__torch__.BasicClass"> @@ -28,5 +30,6 @@ class BasicClass: def prim_CreateObject(i: int): return BasicClass(i) + mb.module.operation.print() print() diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py b/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py index bb6ab4ce4..1bc258a42 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py @@ -17,10 +17,11 @@ mb = ModuleBuilder() @mb.import_function @torch.jit.script def add3(t0, t1, t2): - # CHECK-DAG: torch.aten.add.Tensor {{.*}} loc("aten::add"({{.*}}debug-info.py":[[# @LINE + 1]] - intermediate = t0 + t1 - # CHECK-DAG: torch.aten.mul.Tensor {{.*}} loc("aten::mul"({{.*}}debug-info.py":[[# @LINE + 1]] - return intermediate * t2 + # CHECK-DAG: torch.aten.add.Tensor {{.*}} loc("aten::add"({{.*}}debug-info.py":[[# @LINE + 1]] + intermediate = t0 + t1 + # CHECK-DAG: torch.aten.mul.Tensor {{.*}} loc("aten::mul"({{.*}}debug-info.py":[[# @LINE + 1]] + return intermediate * t2 + # Verify again with debug info present. Just checking that it makes it in there. mb.module.operation.print(enable_debug_info=True, use_local_scope=True) diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/dict.py b/projects/pt1/test/python/importer/jit_ir/node_import/dict.py index 0060357b4..5385258b1 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/dict.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/dict.py @@ -18,7 +18,7 @@ mb = ModuleBuilder() @mb.import_function @torch.jit.script def dict_literal_empty() -> Dict[str, torch.Tensor]: - return {} + return {} # CHECK-LABEL: func.func @__torch__.dict_literal( @@ -33,10 +33,9 @@ def dict_literal_empty() -> Dict[str, torch.Tensor]: # CHECK: } @mb.import_function @torch.jit.script -def dict_literal(k0: str, v0, k1: str, - v1) -> Dict[str, Optional[torch.Tensor]]: - my_dict: Dict[str, Optional[torch.Tensor]] = {k0: v0, k1: v1} - return my_dict +def dict_literal(k0: str, v0, k1: str, v1) -> Dict[str, Optional[torch.Tensor]]: + my_dict: Dict[str, Optional[torch.Tensor]] = {k0: v0, k1: v1} + return my_dict mb.module.operation.print() diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/elif.py b/projects/pt1/test/python/importer/jit_ir/node_import/elif.py index 71853b0c0..5ee16e391 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/elif.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/elif.py @@ -28,6 +28,7 @@ def f(b: bool, i: int): else: return i * i + assert isinstance(f, torch.jit.ScriptFunction) mb.module.operation.print() print() diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/errors.py b/projects/pt1/test/python/importer/jit_ir/node_import/errors.py index 2ac801bdd..64f490649 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/errors.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/errors.py @@ -9,8 +9,8 @@ from torch_mlir.jit_ir_importer import ModuleBuilder class Color(enum.Enum): - RED = 1 - GREEN = 2 + RED = 1 + GREEN = 2 # RUN: %PYTHON %s @@ -20,13 +20,13 @@ mb = ModuleBuilder() # To test errors, use a type that we don't support yet. try: - @mb.import_function - @torch.jit.script - def import_class(x: Color): - return x + @mb.import_function + @torch.jit.script + def import_class(x: Color): + return x + except Exception as e: - # TODO: Once diagnostics are enabled, verify the actual error emitted. - assert str( - e) == "unsupported type in function schema: 'Enum<__torch__.Color>'" + # TODO: Once diagnostics are enabled, verify the actual error emitted. + assert str(e) == "unsupported type in function schema: 'Enum<__torch__.Color>'" else: - assert False, "Expected exception" + assert False, "Expected exception" diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py b/projects/pt1/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py index a724f1185..1cb789345 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py @@ -15,10 +15,15 @@ mb = ModuleBuilder() # CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32> # CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor # CHECK: return %[[RESULT]] : !torch.tensor -mb.import_function(create_script_function("__torch__.refined_block_arg", """ +mb.import_function( + create_script_function( + "__torch__.refined_block_arg", + """ graph(%0 : Float(1, 384)): return (%0) -""")) +""", + ) +) mb.module.operation.print() print() diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py b/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py index 89f5604bf..2acde08ca 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py @@ -20,6 +20,7 @@ mb = ModuleBuilder() def optional_return(i: int) -> typing.Optional[int]: return i + # CHECK-LABEL: func.func @__torch__.optional_arg( # CHECK-SAME: %[[ARG:.*]]: !torch.optional) -> !torch.none { @mb.import_function @@ -27,6 +28,7 @@ def optional_return(i: int) -> typing.Optional[int]: def optional_arg(i: typing.Optional[int]) -> None: return + # CHECK-LABEL: func.func @__torch__.calls_optional_arg( # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.none { # CHECK: %[[CALLEE:.*]] = constant @__torch__.optional_arg : (!torch.optional) -> !torch.none diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/if.py b/projects/pt1/test/python/importer/jit_ir/node_import/if.py index 8289e0503..86390f707 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/if.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/if.py @@ -32,6 +32,7 @@ def prim_If(b: bool, i: int): else: return i * i + # CHECK-LABEL: func.func @__torch__.prim_If_derefine( # CHECK-SAME: %[[B:.*]]: !torch.bool, # CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.optional { @@ -51,5 +52,6 @@ def prim_If_derefine(b: bool, i: int): return None return i + mb.module.operation.print() print() diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/list.py b/projects/pt1/test/python/importer/jit_ir/node_import/list.py index 2b30d545b..392280a4d 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/list.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/list.py @@ -15,11 +15,13 @@ mb = ModuleBuilder() # CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !torch.list # CHECK: return %[[RET]] : !torch.list + @mb.import_function @torch.jit.script def f(t0, t1): - return [t0, t1] - + return [t0, t1] + + assert isinstance(f, torch.jit.ScriptFunction) mb.module.operation.print() print() diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/loop.py b/projects/pt1/test/python/importer/jit_ir/node_import/loop.py index d6bb141f2..d432cd6ee 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/loop.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/loop.py @@ -29,6 +29,7 @@ def prim_Loop_forlike(n: int): f += i return f + # CHECK-LABEL: func.func @__torch__.prim_Loop_whilelike( # CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.float { # CHECK: %[[F_INIT:.*]] = torch.constant.float 3.200000e+00 @@ -49,6 +50,7 @@ def prim_Loop_whilelike(n: int): f = f * f return f + # CHECK-LABEL: func.func @__torch__.prim_Loop_derefine( # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional { # CHECK: %[[TRUE:.*]] = torch.constant.bool true @@ -68,5 +70,6 @@ def prim_Loop_derefine(n: int): x = n return x + mb.module.operation.print() print() diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/prim.py b/projects/pt1/test/python/importer/jit_ir/node_import/prim.py index 07a56616e..66959257e 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/prim.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/prim.py @@ -25,6 +25,7 @@ mb = ModuleBuilder() def prim_NumToTensor(i: int): return _to_tensor(i) + # CHECK-LABEL: func.func @__torch__.prim_Print( # CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.none { # CHECK: %[[STR:.*]] = torch.constant.str "x" @@ -34,6 +35,7 @@ def prim_NumToTensor(i: int): def prim_Print(x): print("x", x) + # CHECK-LABEL: func.func @__torch__.prim_RaiseException() -> !torch.none { # CHECK: %[[ERRORSTR:.*]] = torch.constant.str "Error" # CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !torch.none @@ -44,6 +46,7 @@ def prim_Print(x): def prim_RaiseException(): raise Exception("Error") + # CHECK-LABEL: func.func @__torch__.prim_unchecked_cast( # CHECK-SAME: %[[ARG:.*]]: !torch.optional) -> !torch.int { # CHECK: %[[NONE:.*]] = torch.constant.none @@ -63,6 +66,7 @@ def prim_unchecked_cast(i: typing.Optional[int]): return 3 return i + # CHECK-LABEL: func.func @__torch__.prim_TupleUnpack( # CHECK-SAME: %[[ARG:.*]]: !torch.tuple) -> !torch.int { # CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !torch.tuple -> !torch.int, !torch.int @@ -73,6 +77,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]): val, _ = tup return val + # CHECK-LABEL: func.func @__torch__.prim_TupleIndex( # CHECK-SAME: %[[ARG:.*]]: !torch.tuple) -> !torch.tensor { # CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !torch.tuple, !torch.int -> !torch.tensor @@ -82,6 +87,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]): def prim_TupleIndex(tup: typing.Tuple[torch.Tensor, torch.Tensor]): return tup[0] + # CHECK-LABEL: func.func @__torch__.prim_ListUnpack( # CHECK-SAME: %[[ARG:.*]]: !torch.list) -> !torch.int { # CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !torch.list -> !torch.int, !torch.int @@ -92,6 +98,7 @@ def prim_ListUnpack(l: typing.List[int]): _, val, _ = l return val + # CHECK-LABEL: func.func @__torch__.prim_dtype( # CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.int { # CHECK: %[[RET:.*]] = torch.prim.dtype %[[ARG]] : !torch.tensor -> !torch.int @@ -101,6 +108,7 @@ def prim_ListUnpack(l: typing.List[int]): def prim_dtype(x): return x.dtype + # CHECK-LABEL: func.func @__torch__.prim_layout( # CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.int { # CHECK: %[[RET:.*]] = torch.prim.layout %[[ARG]] : !torch.tensor -> !torch.int @@ -110,6 +118,7 @@ def prim_dtype(x): def prim_layout(x): return x.layout + # CHECK-LABEL: func.func @__torch__.prim_device( # CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.Device { # CHECK: %[[RET:.*]] = torch.prim.device %[[ARG]] : !torch.tensor -> !torch.Device @@ -119,6 +128,7 @@ def prim_layout(x): def prim_device(x): return x.device + # CHECK-LABEL: func.func @__torch__.prim_min( # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple { # CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (!torch.int) -> !torch.list @@ -131,7 +141,8 @@ def prim_device(x): @mb.import_function @torch.jit.script def prim_min(x: int): - return min(x), min(x,x), min(x, x, x) + return min(x), min(x, x), min(x, x, x) + # CHECK-LABEL: func.func @__torch__.prim_max( # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple { @@ -145,7 +156,8 @@ def prim_min(x: int): @mb.import_function @torch.jit.script def prim_max(x: int): - return max(x), max(x,x), max(x, x, x) + return max(x), max(x, x), max(x, x, x) + # CHECK-LABEL: func.func @__torch__.prim_Constant_list() -> !torch.list { # CHECK: %[[A:.*]] = torch.constant.int 1 @@ -154,11 +166,16 @@ def prim_max(x: int): # CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[A]], %[[B]], %[[C]] : # CHECK-SAME: (!torch.int, !torch.int, !torch.int) -> !torch.list # CHECK: return %[[RET]] : !torch.list -mb.import_function(create_script_function("__torch__.prim_Constant_list", """ +mb.import_function( + create_script_function( + "__torch__.prim_Constant_list", + """ graph(): %list : int[] = prim::Constant[value=[1, 2, 3]]() return (%list) -""")) +""", + ) +) mb.module.operation.print() print() @@ -169,12 +186,19 @@ print() # CHECK: return %[[RET]] : !torch.number import_options = ImportOptions() import_options.assumeTensorsHaveValueSemantics = False -mb.import_function(create_script_function("__torch__.prim_Constant_scalar", """ +mb.import_function( + create_script_function( + "__torch__.prim_Constant_scalar", + """ graph(): %0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() %1 : Scalar = aten::ScalarImplicit(%0) return (%1) -""", parse_tensor_constants=True), import_options) +""", + parse_tensor_constants=True, + ), + import_options, +) mb.module.operation.print() print() @@ -184,12 +208,19 @@ print() # CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit # CHECK: return %[[RET]] : !torch.number import_options.assumeTensorsHaveValueSemantics = True -mb.import_function(create_script_function("__torch__.prim_Constant_scalar_value_semantics", """ +mb.import_function( + create_script_function( + "__torch__.prim_Constant_scalar_value_semantics", + """ graph(): %0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() %1 : Scalar = aten::ScalarImplicit(%0) return (%1) -""", parse_tensor_constants=True), import_options) +""", + parse_tensor_constants=True, + ), + import_options, +) mb.module.operation.print() print() diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py b/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py index 2dff435cd..a1f06c390 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py @@ -11,8 +11,7 @@ from utils import create_script_function # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s mb = ModuleBuilder() -NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]), - ('f2', Optional[torch.Tensor])]) +NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])]) # CHECK-LABEL: func.func @__torch__.tuple( # CHECK-SAME: %[[T0:.*]]: !torch.tensor, @@ -24,7 +23,7 @@ NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]), @mb.import_function @torch.jit.script def tuple(t0, t1): - return t0, t1 + return t0, t1 # CHECK-LABEL: func.func @__torch__.tuple_optional( @@ -39,9 +38,8 @@ def tuple(t0, t1): # CHECK: return %[[RET]] : !torch.tuple, optional> @mb.import_function @torch.jit.script -def tuple_optional( - t0, t1) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - return t0, t1 +def tuple_optional(t0, t1) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + return t0, t1 # CHECK-LABEL: func.func @__torch__.namedtuple_optional( @@ -55,8 +53,9 @@ def tuple_optional( @mb.import_function @torch.jit.script def namedtuple_optional( - t0, t1) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - return NT(t0, t1) + t0, t1 +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + return NT(t0, t1) # CHECK-LABEL: func.func @__torch__.tuple_construct_arg_needs_refinement( @@ -67,12 +66,17 @@ def namedtuple_optional( # CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[T0_REFINED]], %[[T1_REFINED]] : !torch.tensor<[4],f32>, !torch.tensor<[3],f32> -> !torch.tuple, tensor<[3],f32>> # CHECK: %[[DEREFINED:.*]] = torch.derefine %[[TUPLE]] : !torch.tuple, tensor<[3],f32>> to !torch.tuple # CHECK: return %[[DEREFINED]] : !torch.tuple -mb.import_function(create_script_function("__torch__.tuple_construct_arg_needs_refinement", """ +mb.import_function( + create_script_function( + "__torch__.tuple_construct_arg_needs_refinement", + """ graph(%t0 : Tensor, %t1 : Tensor): %10 : (Float(4), Float(3)) = prim::TupleConstruct(%t1, %t0) return (%10) -""")) +""", + ) +) mb.module.operation.print() print() diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py b/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py index 8da5e0e2c..0a27692fc 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py @@ -17,6 +17,7 @@ def returns_bool(): # CHECK-NEXT: return %[[T]] return True + assert isinstance(returns_bool, torch.jit.ScriptFunction) mb.module.operation.print() print() diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py b/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py index a0e86a66a..16a3359da 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py @@ -17,6 +17,7 @@ def returns_none(): # CHECK-NEXT: return %[[NONE]] pass + assert isinstance(returns_none, torch.jit.ScriptFunction) mb.module.operation.print() print() diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py b/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py index 533ef7586..06a7cc7fc 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py @@ -3,6 +3,7 @@ from torch_mlir import torchscript # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s + class Inner(object): # CHECK-LABEL: func.func private @__torch__.Inner.foo( # CHECK-SAME: %[[ARG:.*]]: !torch.nn.Module<"__torch__.Inner">) { @@ -39,6 +40,7 @@ class Model(torch.nn.Module): with torch.no_grad(): return data + output_type = torchscript.OutputType.RAW mod = torchscript.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type) print(mod) diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/union.py b/projects/pt1/test/python/importer/jit_ir/node_import/union.py index 14eb41a21..eeaee94bf 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/union.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/union.py @@ -14,11 +14,13 @@ mb = ModuleBuilder() # CHECK-LABEL: func.func @__torch__.f( # CHECK-SAME: %{{.*}}: !torch.union) -> !torch.none { + @mb.import_function @torch.jit.script def f(x: Union[int, float]): - return - + return + + assert isinstance(f, torch.jit.ScriptFunction) mb.module.operation.print() print() diff --git a/projects/pt1/test/python/smoketest.py b/projects/pt1/test/python/smoketest.py index bb97927e9..77225a868 100644 --- a/projects/pt1/test/python/smoketest.py +++ b/projects/pt1/test/python/smoketest.py @@ -4,9 +4,9 @@ import torch_mlir.ir from torch_mlir.dialects import torch with torch_mlir.ir.Context() as ctx: - torch.register_dialect(ctx) - with torch_mlir.ir.Location.unknown() as loc: - module = torch_mlir.ir.Module.create(loc) - with torch_mlir.ir.InsertionPoint.at_block_begin(module.body): - n = torch.ConstantNoneOp() - module.operation.print() \ No newline at end of file + torch.register_dialect(ctx) + with torch_mlir.ir.Location.unknown() as loc: + module = torch_mlir.ir.Module.create(loc) + with torch_mlir.ir.InsertionPoint.at_block_begin(module.body): + n = torch.ConstantNoneOp() + module.operation.print() diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py index 6416b88aa..4e5a2f8f8 100644 --- a/python/torch_mlir/compiler_utils.py +++ b/python/torch_mlir/compiler_utils.py @@ -26,17 +26,18 @@ def get_module_name_for_debug_dump(module): class TorchMlirCompilerError(Exception): pass -def run_pipeline_with_repro_report(module, - pipeline: str, - description: str, - enable_ir_printing: bool = False): + +def run_pipeline_with_repro_report( + module, pipeline: str, description: str, enable_ir_printing: bool = False +): """Runs `pipeline` on `module`, with a nice repro report if it fails.""" module_name = get_module_name_for_debug_dump(module) original_stderr = sys.stderr try: sys.stderr = StringIO() asm_for_error_report = module.operation.get_asm( - large_elements_limit=10, enable_debug_info=True) + large_elements_limit=10, enable_debug_info=True + ) # Lower module in place to make it ready for compiler backends. with module.context as ctx: pm = PassManager.parse(pipeline) @@ -54,9 +55,9 @@ def run_pipeline_with_repro_report(module, # - if we do have have colliding filenames, writes should at least # avoid being racy. filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir") - with open(filename, 'w') as f: + with open(filename, "w") as f: f.write(asm_for_error_report) - debug_options="-mlir-print-ir-after-all -mlir-disable-threading" + debug_options = "-mlir-print-ir-after-all -mlir-disable-threading" # Put something descriptive here even if description is empty. description = description or f"{module_name} compile" @@ -70,7 +71,7 @@ def run_pipeline_with_repro_report(module, $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename} Add '{debug_options}' to get the IR dump for debugging purpose. """ - trimmed_message = '\n'.join([m.lstrip() for m in message.split('\n')]) + trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")]) raise TorchMlirCompilerError(trimmed_message) from None finally: sys.stderr = original_stderr @@ -118,8 +119,10 @@ class OutputType(Enum): return spec spec = spec.upper().replace("-", "_") if spec not in OutputType.__members__: - raise ValueError(f"For output_type= argument, expected one of: " - f"{', '.join(OutputType.__members__.keys())}") + raise ValueError( + f"For output_type= argument, expected one of: " + f"{', '.join(OutputType.__members__.keys())}" + ) return OutputType[spec] @@ -134,8 +137,10 @@ def lower_mlir_module(verbose, output_type, module): if output_type == OutputType.TOSA: run_pipeline_with_repro_report( - module, "builtin.module(torch-backend-to-tosa-backend-pipeline)", - "Lowering Torch Backend IR -> TOSA Backend IR") + module, + "builtin.module(torch-backend-to-tosa-backend-pipeline)", + "Lowering Torch Backend IR -> TOSA Backend IR", + ) if verbose: print("\n====================") print("TOSA Backend IR") @@ -146,7 +151,8 @@ def lower_mlir_module(verbose, output_type, module): run_pipeline_with_repro_report( module, "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", - "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR") + "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", + ) if verbose: print("\n====================") print("LINALG Backend IR") @@ -157,7 +163,8 @@ def lower_mlir_module(verbose, output_type, module): run_pipeline_with_repro_report( module, "builtin.module(torch-backend-to-stablehlo-backend-pipeline)", - "Lowering Torch Backend IR -> StableHLO Backend IR") + "Lowering Torch Backend IR -> StableHLO Backend IR", + ) if verbose: print("\n====================") print("StableHLO Backend IR") diff --git a/python/torch_mlir/extras/fx_decomp_util.py b/python/torch_mlir/extras/fx_decomp_util.py index 47a79f955..e049a0149 100644 --- a/python/torch_mlir/extras/fx_decomp_util.py +++ b/python/torch_mlir/extras/fx_decomp_util.py @@ -46,5 +46,6 @@ DEFAULT_DECOMPOSITIONS = [ torch.ops.aten._unsafe_index.Tensor, ] + def get_decomposition_table(): return get_decompositions(DEFAULT_DECOMPOSITIONS) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index d70c6046e..c1eec37aa 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1540,7 +1540,9 @@ class GraphNodeImporter: if op_name is None: return val with loc: - return Operation.create(name=op_name, results=[result_type], operands=[val]).result + return Operation.create( + name=op_name, results=[result_type], operands=[val] + ).result def _import_literal(self, py_value: Any) -> Value: # Apply the conversion callback. diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index f53922d2a..f1064f976 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -75,6 +75,7 @@ from ..dialects import ( func as func_dialect, ) + @dataclass class Config: """Various configuration settings for the importer.""" @@ -274,13 +275,11 @@ class NodeImporter: if func: func_dialect.ReturnOp(outputs) else: - Operation.create( - name="torch.operator_terminator", - operands=outputs) + Operation.create(name="torch.operator_terminator", operands=outputs) def get_none(self): - if '' in self._nv_map: - return self._nv_map[''] + if "" in self._nv_map: + return self._nv_map[""] with InsertionPoint(self._b), Location.name("onnx_importer.none"): nne = Operation.create( @@ -289,7 +288,7 @@ class NodeImporter: operands=[], attributes={}, ).results[0] - self._nv_map[''] = nne + self._nv_map[""] = nne return nne def import_node(self, node: onnx.NodeProto): @@ -328,7 +327,7 @@ class NodeImporter: results=output_types, operands=input_values, attributes=attrs, - regions=regions + regions=regions, ) self.import_regions(node.attribute, custom_op) @@ -375,12 +374,18 @@ class NodeImporter: for name, region in zip(sorted(attr_map.keys()), op.regions): attr = attr_map[name] - block_types = [self._cc.type_proto_to_type(input.type) for input in attr.g.input] + block_types = [ + self._cc.type_proto_to_type(input.type) for input in attr.g.input + ] block_names = [input.name for input in attr.g.input] - region.blocks.append(*block_types, arg_locs=[op.location] * len(block_types)) + region.blocks.append( + *block_types, arg_locs=[op.location] * len(block_types) + ) block = region.blocks[0] graph_info = GraphInfo(None, attr.g) - imp = NodeImporter(graph_info, parent_op=op, block=block, context_cache=self._cc) + imp = NodeImporter( + graph_info, parent_op=op, block=block, context_cache=self._cc + ) for node_name, input_value in zip(block_names, block.arguments): imp._nv_map[node_name] = input_value @@ -389,7 +394,9 @@ class NodeImporter: imp.import_all(False) - def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = None) -> Value: + def import_initializer( + self, initializer: onnx.TensorProto, extern_name: str = None + ) -> Value: # If an explicitly specified name is given, use that; otherwise, pick # up the name from the tensor proto itself iname = extern_name if extern_name else initializer.name @@ -445,6 +452,7 @@ class NodeImporter: self._gi.initializer_map[const_name] = value_proto.t return True + class ContextCache: """Caches per-context lookups of various things.""" @@ -459,8 +467,8 @@ class ContextCache: def __init__(self, context: Context): self._c = context self._elem_type_map: Dict[int, IrType] = {} - self._list_type_map:Dict[str, IrType] = {} - self._optional_type_map:Dict[str, IrType] = {} + self._list_type_map: Dict[str, IrType] = {} + self._optional_type_map: Dict[str, IrType] = {} self._vtensor_type_map: Dict[Tuple[Tuple[Optional[int]], IrType], IrType] = {} def tensor_element_type(self, elem_type: int) -> IrType: @@ -491,7 +499,6 @@ class ContextCache: self._list_type_map[key] = t return t - def get_optional_type(self, element_type: IrType) -> IrType: key = str(element_type) t = self._optional_type_map.get(key) @@ -506,7 +513,6 @@ class ContextCache: self._optional_type_map[key] = t return t - def get_list_element_type(self, tp: onnx.TypeProto) -> IrType: tt = tp.tensor_type if tt.elem_type: @@ -517,8 +523,7 @@ class ContextCache: shape_asm = ",".join("?" if d is None else str(d) for d in dims) return f"vtensor<[{shape_asm}],{element_type}>" - raise OnnxImportError( - f"Unsupport list element type") + raise OnnxImportError(f"Unsupport list element type") def get_optional_element_type(self, tp: onnx.TypeProto) -> IrType: st = tp.sequence_type @@ -535,8 +540,7 @@ class ContextCache: element_type = self.get_list_element_type(st.elem_type) return f"list<{element_type}>" - raise OnnxImportError( - f"Unsupport optional element type") + raise OnnxImportError(f"Unsupport optional element type") def get_vtensor_type( self, dims: Tuple[Optional[int]], element_type: IrType @@ -566,10 +570,7 @@ class ContextCache: try: return RankedTensorType.get(tuple(tp.dims), element_type) except TypeError as e: - raise OnnxImportError( - f"Unsupported builtin tensor type" - ) from e - + raise OnnxImportError(f"Unsupported builtin tensor type") from e def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: if tp == "": @@ -602,9 +603,9 @@ class ContextCache: raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}") def _sanitize_name(self, name): - if not name.isidentifier(): - name = "_" + name - return re.sub("[:/]", "_", name) + if not name.isidentifier(): + name = "_" + name + return re.sub("[:/]", "_", name) def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: tensor_type = self.tensor_proto_to_builtin_type(tp) @@ -654,9 +655,13 @@ ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = { RankedTensorType.get(shape, F32Type.get()), FloatAttr.get_f32(tp.float_data[0]) ), onnx.TensorProto.DataType.INT64: lambda tp, shape: DenseElementsAttr.get_splat( - RankedTensorType.get(shape, IntegerType.get_signed(64)), IntegerAttr.get( - IntegerType.get_signed(64), int.from_bytes(tp.raw_data, "little", - signed=True) if tp.HasField("raw_data") else tp.int64_data[0]) + RankedTensorType.get(shape, IntegerType.get_signed(64)), + IntegerAttr.get( + IntegerType.get_signed(64), + int.from_bytes(tp.raw_data, "little", signed=True) + if tp.HasField("raw_data") + else tp.int64_data[0], + ), ), # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB } @@ -669,8 +674,12 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = { np.asarray(tp.float_data, dtype=np.float32).reshape(tp.dims), signless=False ), onnx.TensorProto.DataType.BOOL: lambda tp: DenseElementsAttr.get( - np.packbits(np.asarray(tp.int32_data, dtype=np.bool_).reshape(tp.dims), - axis=None, bitorder="little"), signless=False + np.packbits( + np.asarray(tp.int32_data, dtype=np.bool_).reshape(tp.dims), + axis=None, + bitorder="little", + ), + signless=False, ), onnx.TensorProto.DataType.INT8: lambda tp: DenseElementsAttr.get( np.asarray(tp.int32_data, dtype=np.int8).reshape(tp.dims), signless=False @@ -758,7 +767,9 @@ ATTRIBUTE_TYPE_HANDLERS = { } -def _get_attr(node: onnx.NodeProto, attr_name: str, is_required: bool = True) -> onnx.AttributeProto: +def _get_attr( + node: onnx.NodeProto, attr_name: str, is_required: bool = True +) -> onnx.AttributeProto: for attr in node.attribute: if attr.name == attr_name: return attr diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index 547fe5339..6fbabb09a 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -10,7 +10,7 @@ Typically, when installed from a wheel, this can be invoked as: torch-mlir-import-onnx some.pb Or from Python: - + python -m torch_mlir.tools.import_onnx ... """ import argparse diff --git a/setup.py b/setup.py index d5e3d055c..6f5f5d5d1 100644 --- a/setup.py +++ b/setup.py @@ -59,10 +59,13 @@ from setuptools.command.build_py import build_py if "develop" in sys.argv: print("Warning: The setup.py script is only used for building the wheel package.") - print("For initializing the development environment," - "please use the cmake commands introduced in the docs/development.md.") + print( + "For initializing the development environment," + "please use the cmake commands introduced in the docs/development.md." + ) sys.exit(1) + def _check_env_flag(name: str, default=None) -> bool: return str(os.getenv(name, default)).upper() in ["ON", "1", "YES", "TRUE", "Y"] @@ -87,7 +90,6 @@ MAX_JOBS = os.getenv("MAX_JOBS", str(multiprocessing.cpu_count())) # Build phase discovery is unreliable. Just tell it what phases to run. class CustomBuild(_build): - def initialize_options(self): _build.initialize_options(self) # Make setuptools not steal the build directory name, @@ -102,7 +104,6 @@ class CustomBuild(_build): class CMakeBuild(build_py): - def cmake_build(self, cmake_build_dir): llvm_dir = str(SRC_DIR / "externals" / "llvm-project" / "llvm") @@ -199,14 +200,12 @@ class CMakeBuild(build_py): class CMakeExtension(Extension): - def __init__(self, name, sourcedir=""): Extension.__init__(self, name, sources=[]) self.sourcedir = os.path.abspath(sourcedir) class NoopBuildExtension(build_ext): - def build_extension(self, ext): pass @@ -229,13 +228,18 @@ NAME = "torch-mlir-core" # If building PyTorch extensions, customize. if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS: import torch + NAME = "torch-mlir" - INSTALL_REQUIRES.extend([ - f"torch=={torch.__version__}".split("+", 1)[0], - ]) - EXT_MODULES.extend([ - CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"), - ]) + INSTALL_REQUIRES.extend( + [ + f"torch=={torch.__version__}".split("+", 1)[0], + ] + ) + EXT_MODULES.extend( + [ + CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"), + ] + ) setup( diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 4608dfb6c..35d5558f8 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -19,62 +19,76 @@ from lit.llvm.subst import FindTool # Configuration file for the 'lit' test runner. # name: The name of this test suite. -config.name = 'TORCH_MLIR' +config.name = "TORCH_MLIR" config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.mlir', '.py', '.runlit'] +config.suffixes = [".mlir", ".py", ".runlit"] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. -config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test') +config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test") -config.substitutions.append(('%PATH%', config.environment['PATH'])) -config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) +config.substitutions.append(("%PATH%", config.environment["PATH"])) +config.substitutions.append(("%shlibext", config.llvm_shlib_ext)) -llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) +llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"]) -#llvm_config.use_default_substitutions() +# llvm_config.use_default_substitutions() # excludes: A list of directories to exclude from the testsuite. The 'Inputs' # subdirectories contain auxiliary inputs for various tests in their parent # directories. config.excludes = [ - 'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt', - 'lit.cfg.py', 'lit.site.cfg.py' + "Inputs", + "Examples", + "CMakeLists.txt", + "README.txt", + "LICENSE.txt", + "lit.cfg.py", + "lit.site.cfg.py", ] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. -config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test') -config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin') +config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test") +config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, "bin") # Tweak the PATH to include the tools dir. -llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) +llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True) # Tweak the PATH to include the binary build dir, in order to pick up CAPI tests during out-of-tree. -llvm_config.with_environment('PATH', os.path.join(config.llvm_build_dir, 'bin'), append_path=True) +llvm_config.with_environment( + "PATH", os.path.join(config.llvm_build_dir, "bin"), append_path=True +) # On Windows the path to python could contains spaces in which case it needs to # be provided in quotes. This is the equivalent of how %python is setup in # llvm/utils/lit/lit/llvm/config.py. if "Windows" in config.host_os: - config.python_executable = '"%s"' % (config.python_executable) + config.python_executable = '"%s"' % (config.python_executable) -tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir, config.torch_mlir_obj_root] +tool_dirs = [ + config.standalone_tools_dir, + config.llvm_tools_dir, + config.torch_mlir_obj_root, +] tools = [ - 'torch-mlir-opt', - ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'), + "torch-mlir-opt", + ToolSubst("%PYTHON", config.python_executable, unresolved="ignore"), ] llvm_config.add_tool_substitutions(tools, tool_dirs) if config.enable_bindings_python: - llvm_config.with_environment('PYTHONPATH', [ - os.path.join(config.torch_mlir_python_packages_dir, 'torch_mlir'), - ], - append_path=True) + llvm_config.with_environment( + "PYTHONPATH", + [ + os.path.join(config.torch_mlir_python_packages_dir, "torch_mlir"), + ], + append_path=True, + ) diff --git a/test/python/compile.py b/test/python/compile.py index 990738085..32b47a254 100644 --- a/test/python/compile.py +++ b/test/python/compile.py @@ -26,9 +26,13 @@ class TinyModel(torch.nn.Module): # CHECK-LABEL: TEST: test_enable_ir_printing @run_test def test_enable_ir_printing(): - torchscript.compile(TinyModel(), - torch.ones(1, 3, 20, 20), - output_type="linalg-on-tensors", - enable_ir_printing=True) + torchscript.compile( + TinyModel(), + torch.ones(1, 3, 20, 20), + output_type="linalg-on-tensors", + enable_ir_printing=True, + ) + + # CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) # CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} { diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 1ac12296b..08ef9fdc9 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -11,7 +11,11 @@ import torch import torch.nn as nn from torch.export import Dim from torch._dynamo.backends.common import aot_autograd -from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_graph_name, set_model_name +from torch._functorch.aot_autograd import ( + make_boxed_compiler, + get_aot_graph_name, + set_model_name, +) from torch_mlir import fx from torch_mlir.compiler_utils import run_pipeline_with_repro_report @@ -81,6 +85,7 @@ def test_import_frozen_exported_program_with_func_name(): m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net") print(m) + @run # CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes # CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32> @@ -94,18 +99,22 @@ def test_import_frozen_exported_program_with_dynamic_shapes(): batch = Dim("batch") dynamic_shapes = {"x": {0: batch}} - m = fx.export_and_import(Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net") + m = fx.export_and_import( + Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net" + ) print(m) - @make_boxed_compiler -def fx_import_aot_autograd_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): +def fx_import_aot_autograd_backend( + gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor] +): print(gm.print_readable(False), flush=True) m = fx.stateless_fx_import(gm, model_name=get_aot_graph_name()) print(m, flush=True) return gm + @run # CHECK-LABEL: test_stateless_fx_import # CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> @@ -114,6 +123,7 @@ def fx_import_aot_autograd_backend(gm: torch.fx.GraphModule, example_inputs: Lis def test_stateless_fx_import(): fx_import_backend = aot_autograd(fw_compiler=fx_import_aot_autograd_backend) set_model_name("basic_forward") + @torch._dynamo.optimize(backend=fx_import_backend) def basic_forward(x): return torch.tanh(x) @@ -130,8 +140,14 @@ def test_full(): super().__init__() def forward(self): - return torch.full([], False, dtype=torch.bool, layout=torch.strided, device='cpu', - pin_memory=False) + return torch.full( + [], + False, + dtype=torch.bool, + layout=torch.strided, + device="cpu", + pin_memory=False, + ) m = fx.export_and_import(Basic(), func_name="test_full", enable_graph_printing=True) run_pipeline_with_repro_report( diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 23b573211..7e94a28a7 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -159,7 +159,7 @@ def sparse_jit(f, *args, **kwargs): params_flat, params_spec = torch.utils._pytree.tree_flatten(params) for p in params_flat: if len(p.shape) > 0: - xargs.append(p.numpy()) + xargs.append(p.numpy()) # Prepare input parameters. Sparse input tensors are split into # their composite tensors. All PyTorch tensors are converted # to their backing numpy arrays. Note that the output consists diff --git a/test/python/onnx_importer/_torch_mlir_config.py b/test/python/onnx_importer/_torch_mlir_config.py index f597b63b4..cdab7fbdc 100644 --- a/test/python/onnx_importer/_torch_mlir_config.py +++ b/test/python/onnx_importer/_torch_mlir_config.py @@ -14,6 +14,8 @@ projects (i.e. by just providing this file on the side). from torch_mlir import ir from torch_mlir.extras import onnx_importer + def configure_context(context): from torch_mlir.dialects import torch as torch_d + torch_d.register_dialect(context) diff --git a/test/python/onnx_importer/command_line_test.py b/test/python/onnx_importer/command_line_test.py index 32dc0cbeb..998e10980 100644 --- a/test/python/onnx_importer/command_line_test.py +++ b/test/python/onnx_importer/command_line_test.py @@ -24,9 +24,7 @@ from torch_mlir.tools.import_onnx import __main__ import numpy from onnx import numpy_helper, TensorProto -from onnx.helper import ( - make_model, make_node, make_graph, - make_tensor_value_info) +from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info from onnx.external_data_helper import convert_model_to_external_data from onnx.checker import check_model @@ -46,14 +44,22 @@ def const_model() -> onnx.ModelProto: # Note: data_path must be relative to model_file const = make_node( - 'Constant', [], ['c_shape'], 'const', - value=numpy_helper.from_array(numpy.array([4], dtype=numpy.int64))) + "Constant", + [], + ["c_shape"], + "const", + value=numpy_helper.from_array(numpy.array([4], dtype=numpy.int64)), + ) cofshape = make_node( - 'ConstantOfShape', ['c_shape'], ['c_out'], 'cofshape', - value=numpy_helper.from_array(numpy.array([1], dtype=numpy.int64))) + "ConstantOfShape", + ["c_shape"], + ["c_out"], + "cofshape", + value=numpy_helper.from_array(numpy.array([1], dtype=numpy.int64)), + ) - outval = make_tensor_value_info('c_out', TensorProto.INT64, [None]) - graph = make_graph([const, cofshape], 'constgraph', [], [outval]) + outval = make_tensor_value_info("c_out", TensorProto.INT64, [None]) + graph = make_graph([const, cofshape], "constgraph", [], [outval]) onnx_model = make_model(graph) check_model(onnx_model) @@ -65,26 +71,23 @@ def linear_model() -> onnx.ModelProto: k_dim = 32 value = numpy.arange(k_dim).reshape([k_dim, 1]) value = numpy.asarray(value, dtype=numpy.float32) - A = numpy_helper.from_array(value, name='A') + A = numpy_helper.from_array(value, name="A") value = numpy.array([0.4], dtype=numpy.float32).reshape([1, 1]) - C = numpy_helper.from_array(value, name='C') + C = numpy_helper.from_array(value, name="C") # the part which does not change - X = make_tensor_value_info('X', TensorProto.FLOAT, [1, k_dim]) - Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, None]) - node1 = make_node('MatMul', ['X', 'A'], ['AX']) - node2 = make_node('Add', ['AX', 'C'], ['Y']) - graph = make_graph([node1, node2], 'lr', [X], [Y], [A, C]) + X = make_tensor_value_info("X", TensorProto.FLOAT, [1, k_dim]) + Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None]) + node1 = make_node("MatMul", ["X", "A"], ["AX"]) + node2 = make_node("Add", ["AX", "C"], ["Y"]) + graph = make_graph([node1, node2], "lr", [X], [Y], [A, C]) onnx_model = make_model(graph) check_model(onnx_model) return onnx_model -ALL_MODELS = [ - const_model, - linear_model -] +ALL_MODELS = [const_model, linear_model] class CommandLineTest(unittest.TestCase): @@ -104,8 +107,7 @@ class CommandLineTest(unittest.TestCase): model_file = run_path / f"{model_name}-i.onnx" mlir_file = run_path / f"{model_name}-i.torch.mlir" onnx.save(onnx_model, model_file) - args = __main__.parse_arguments([ - str(model_file), "-o", str(mlir_file)]) + args = __main__.parse_arguments([str(model_file), "-o", str(mlir_file)]) __main__.main(args) def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str): @@ -116,16 +118,27 @@ class CommandLineTest(unittest.TestCase): model_data_dir = run_path / data_dir_name model_data_dir.mkdir(exist_ok=True) convert_model_to_external_data( - onnx_model, all_tensors_to_one_file=True, + onnx_model, + all_tensors_to_one_file=True, location=data_dir_name + "/data.bin", size_threshold=48, - convert_attribute=True) + convert_attribute=True, + ) onnx.save(onnx_model, model_file) temp_dir = run_path / "temp" temp_dir.mkdir(exist_ok=True) - args = __main__.parse_arguments([ - str(model_file), "-o", str(mlir_file), "--keep-temps", "--temp-dir", - str(temp_dir), "--data-dir", str(run_path)]) + args = __main__.parse_arguments( + [ + str(model_file), + "-o", + str(mlir_file), + "--keep-temps", + "--temp-dir", + str(temp_dir), + "--data-dir", + str(run_path), + ] + ) __main__.main(args) def test_all(self): diff --git a/test/python/onnx_importer/import_smoke_test.py b/test/python/onnx_importer/import_smoke_test.py index bd687ae37..bddd55e87 100644 --- a/test/python/onnx_importer/import_smoke_test.py +++ b/test/python/onnx_importer/import_smoke_test.py @@ -48,6 +48,7 @@ TEST_CAST_XFAILS = [ "node_test_if_opt_model", ] + class ImportSmokeTest(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/utils/bazel/overlay_directories.py b/utils/bazel/overlay_directories.py index c78d075a9..df084c6fa 100644 --- a/utils/bazel/overlay_directories.py +++ b/utils/bazel/overlay_directories.py @@ -21,76 +21,83 @@ import sys def _check_python_version(): - if sys.version_info[0] < 3: - raise RuntimeError( - "Must be invoked with a python 3 interpreter but was %s" % - sys.executable) + if sys.version_info[0] < 3: + raise RuntimeError( + "Must be invoked with a python 3 interpreter but was %s" % sys.executable + ) def _check_dir_exists(path): - if not os.path.isdir(path): - raise OSError(errno.ENOENT, os.strerror(errno.ENOENT), path) + if not os.path.isdir(path): + raise OSError(errno.ENOENT, os.strerror(errno.ENOENT), path) def parse_arguments(): - parser = argparse.ArgumentParser(description=""" + parser = argparse.ArgumentParser( + description=""" Overlays two directories into a target directory using symlinks. Tries to minimize the number of symlinks created (that is, does not symlink every single file). Symlinks every file in the overlay directory. Only symlinks individual files in the source directory if their parent directory is also contained in the overlay directory tree. - """) - parser.add_argument( - "--src", - required=True, - help="Directory that contains most of the content to symlink.") - parser.add_argument( - "--overlay", - required=True, - help="Directory to overlay on top of the source directory.") - parser.add_argument( - "--target", - required=True, - help="Directory in which to place the fused symlink directories.") + """ + ) + parser.add_argument( + "--src", + required=True, + help="Directory that contains most of the content to symlink.", + ) + parser.add_argument( + "--overlay", + required=True, + help="Directory to overlay on top of the source directory.", + ) + parser.add_argument( + "--target", + required=True, + help="Directory in which to place the fused symlink directories.", + ) - args = parser.parse_args() + args = parser.parse_args() - _check_dir_exists(args.target) - _check_dir_exists(args.overlay) - _check_dir_exists(args.src) + _check_dir_exists(args.target) + _check_dir_exists(args.overlay) + _check_dir_exists(args.src) - return args + return args def _symlink_abs(from_path, to_path): - if not os.path.exists(to_path): - os.symlink(os.path.abspath(from_path), os.path.abspath(to_path)) + if not os.path.exists(to_path): + os.symlink(os.path.abspath(from_path), os.path.abspath(to_path)) def main(args): - for root, dirs, files in os.walk(args.overlay): - # We could do something more intelligent here and only symlink individual - # files if the directory is present in both overlay and src. This could also - # be generalized to an arbitrary number of directories without any - # "src/overlay" distinction. In the current use case we only have two and - # the overlay directory is always small, so putting that off for now. - rel_root = os.path.relpath(root, start=args.overlay) - if rel_root != ".": - os.mkdir(os.path.join(args.target, rel_root)) + for root, dirs, files in os.walk(args.overlay): + # We could do something more intelligent here and only symlink individual + # files if the directory is present in both overlay and src. This could also + # be generalized to an arbitrary number of directories without any + # "src/overlay" distinction. In the current use case we only have two and + # the overlay directory is always small, so putting that off for now. + rel_root = os.path.relpath(root, start=args.overlay) + if rel_root != ".": + os.mkdir(os.path.join(args.target, rel_root)) - for file in files: - relpath = os.path.join(rel_root, file) - _symlink_abs(os.path.join(args.overlay, relpath), - os.path.join(args.target, relpath)) + for file in files: + relpath = os.path.join(rel_root, file) + _symlink_abs( + os.path.join(args.overlay, relpath), os.path.join(args.target, relpath) + ) - for src_entry in os.listdir(os.path.join(args.src, rel_root)): - if src_entry not in dirs: - relpath = os.path.join(rel_root, src_entry) - _symlink_abs(os.path.join(args.src, relpath), - os.path.join(args.target, relpath)) + for src_entry in os.listdir(os.path.join(args.src, rel_root)): + if src_entry not in dirs: + relpath = os.path.join(rel_root, src_entry) + _symlink_abs( + os.path.join(args.src, relpath), os.path.join(args.target, relpath) + ) if __name__ == "__main__": - _check_python_version() - main(parse_arguments()) + _check_python_version() + main(parse_arguments())