mirror of https://github.com/llvm/torch-mlir
[NFC reformat] Applies pre-commit formatting to Python files. (#3244)
This is a large change because, prior to this point, Python files in the project were not consistently formatted. This commit reformats them all with black defaults. Based on experience with prior projects, if you have a dev/long-term branch carrying Python patches, you can minimize merge conflicts before rebasing across this commit by running `black` on your modified Python files, squashing, and then rebasing/merging.
parent 5d4b803914
commit 6877302504
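The workflow suggested in the commit message can be illustrated with a short shell sketch. This is not part of the commit; it assumes a hypothetical feature branch named `my-feature` that branched from `main`, and that the same `black` version used by the project's pre-commit hooks is installed:

    # Reformat only the Python files your branch touches (branch names are hypothetical).
    git checkout my-feature
    git diff --name-only main...HEAD -- '*.py' | xargs black
    git commit -am "Apply black to in-flight Python patches"

    # Squash as desired, then rebase across the repo-wide reformat commit;
    # remaining conflicts should reduce to formatting-only differences.
    git rebase main

The intent is that both sides of the rebase are black-formatted, so mechanical whitespace and quoting differences no longer surface as conflicts.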
@@ -30,6 +30,7 @@ if not TORCH_INCLUDE_DIR.is_dir():
TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent


def reindent(text, prefix=""):
    return indent(dedent(text), prefix)


@@ -75,7 +76,11 @@ class GenMlirLazyIr(torchgen.dest.GenLazyIR):
        )

        # Only create this variable if it's used to avoid Wunused-variable
        operand_idx_counter = "size_t i = 0;" if "i++" in (emplace_arguments_str + emplace_kwarguments) else ""
        operand_idx_counter = (
            "size_t i = 0;"
            if "i++" in (emplace_arguments_str + emplace_kwarguments)
            else ""
        )

        return reindent(
            f"""

@@ -111,12 +116,16 @@ class GenTorchMlirLTC:
        )
        assert self.torch_ops_file.exists()
        self.binary_dir = Path(binary_dir)
        assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}"
        assert (
            self.binary_dir.is_dir()
        ), f"Binary directory not found: {self.binary_dir}"
        self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml")
        self.backend_path = TORCH_MLIR_DIR.joinpath(
            "projects", "ltc", "csrc", "base_lazy_backend"
        )
        assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}"
        assert (
            self.backend_path.is_dir()
        ), f"Backend path not found: {self.backend_path}"
        self.generated_path = self.binary_dir.joinpath(
            "projects", "ltc", "csrc", "base_lazy_backend", "generated"
        )

@@ -168,8 +177,9 @@ class GenTorchMlirLTC:
        if ts_native_yaml_path.exists():
            ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), yaml.CLoader)
        else:
            logging.warning(f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}")
            logging.warning(
                f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}"
            )

        parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
        self.native_functions = parsed_yaml.native_functions

@@ -9,19 +9,20 @@ import requests

# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('owner', type=str)
parser.add_argument('repo', type=str)
parser.add_argument("owner", type=str)
parser.add_argument("repo", type=str)
args = parser.parse_args()

# Get releases
response = requests.get(
    f"https://api.github.com/repos/{args.owner}/{args.repo}/releases")
    f"https://api.github.com/repos/{args.owner}/{args.repo}/releases"
)
body = json.loads(response.content)

# Parse releases
releases = []
for row in body:
    for asset in row['assets']:
    for asset in row["assets"]:
        releases.append((asset["name"], asset["browser_download_url"]))

# Output HTML

@@ -25,10 +25,18 @@ from torch_mlir_e2e_test.configs import (
    FxImporterTestConfig,
)

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import LinalgOnTensorsOnnxBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
    RefBackendLinalgOnTensorsBackend,
)
from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import (
    LinalgOnTensorsOnnxBackend,
)
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import (
    LinalgOnTensorsTosaBackend,
)
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import (
    LinalgOnTensorsStablehloBackend,
)

from .xfail_sets import (
    LINALG_XFAIL_SET,

@@ -51,13 +59,28 @@ from .xfail_sets import (

# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests

register_all_tests()


def _get_argparse():
    config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core",
                      "torchdynamo", "onnx", "fx_importer", "fx_importer_stablehlo"]
    config_choices = [
        "native_torch",
        "torchscript",
        "linalg",
        "stablehlo",
        "make_fx_tosa",
        "tosa",
        "lazy_tensor_core",
        "torchdynamo",
        "onnx",
        "fx_importer",
        "fx_importer_stablehlo",
    ]
    parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
    parser.add_argument("-c", "--config",
    parser.add_argument(
        "-c",
        "--config",
        choices=config_choices,
        default="linalg",
        help=f"""

@@ -72,34 +95,52 @@ Meaning of options:
"onnx": export to the model via onnx and reimport using the torch-onnx-to-torch path.
"fx_importer": run the model through the fx importer frontend and execute the graph using Linalg-on-Tensors.
"fx_importer_stablehlo": run the model through the fx importer frontend and execute the graph using Stablehlo backend.
""")
    parser.add_argument("-f", "--filter", default=".*", help="""
""",
    )
    parser.add_argument(
        "-f",
        "--filter",
        default=".*",
        help="""
Regular expression specifying which tests to include in this run.
""")
    parser.add_argument("-v", "--verbose",
""",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=False,
        action="store_true",
        help="report test results with additional detail")
    parser.add_argument("-s", "--sequential",
        help="report test results with additional detail",
    )
    parser.add_argument(
        "-s",
        "--sequential",
        default=False,
        action="store_true",
        help="""Run tests sequentially rather than in parallel.
This can be useful for debugging, since it runs the tests in the same process,
which make it easier to attach a debugger or get a stack trace.""")
    parser.add_argument("--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed",
                        metavar="TEST", type=str, nargs="+",
                        help="A set of tests to not attempt to run, since they crash and cannot be XFAILed.")
    parser.add_argument("--ignore_failures",
which make it easier to attach a debugger or get a stack trace.""",
    )
    parser.add_argument(
        "--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed",
        metavar="TEST",
        type=str,
        nargs="+",
        help="A set of tests to not attempt to run, since they crash and cannot be XFAILed.",
    )
    parser.add_argument(
        "--ignore_failures",
        default=False,
        action="store_true",
        help="return exit code 0 even if the test fails to unblock pipeline")
        help="return exit code 0 even if the test fails to unblock pipeline",
    )
    return parser


def main():
    args = _get_argparse().parse_args()

    all_test_unique_names = set(
        test.unique_name for test in GLOBAL_TEST_REGISTRY)
    all_test_unique_names = set(test.unique_name for test in GLOBAL_TEST_REGISTRY)

    # Find the selected config.
    if args.config == "linalg":

@@ -147,23 +188,26 @@ def main():
        xfail_set = ONNX_XFAIL_SET
        crashing_set = ONNX_CRASHING_SET

    do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []).union(crashing_set)
    available_tests = [test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt]
    do_not_attempt = set(
        args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []
    ).union(crashing_set)
    available_tests = [
        test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt
    ]
    if args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed is not None:
        for arg in args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed:
            if arg not in all_test_unique_names:
                print(f"ERROR: --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument '{arg}' is not a valid test name")
                print(
                    f"ERROR: --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument '{arg}' is not a valid test name"
                )
                sys.exit(1)

    # Find the selected tests, and emit a diagnostic if none are found.
    tests = [
        test for test in available_tests
        if re.match(args.filter, test.unique_name)
        test for test in available_tests if re.match(args.filter, test.unique_name)
    ]
    if len(tests) == 0:
        print(
            f"ERROR: the provided filter {args.filter!r} does not match any tests"
        )
        print(f"ERROR: the provided filter {args.filter!r} does not match any tests")
        print("The available tests are:")
        for test in available_tests:
            print(test.unique_name)

@@ -175,18 +219,25 @@ def main():
    # Report the test results.
    failed = report_results(results, xfail_set, args.verbose, args.config)
    if args.config == "torchdynamo":
        print("\033[91mWarning: the TorchScript based dynamo support is deprecated. "
              "The config for torchdynamo is planned to be removed in the future.\033[0m")
        print(
            "\033[91mWarning: the TorchScript based dynamo support is deprecated. "
            "The config for torchdynamo is planned to be removed in the future.\033[0m"
        )
    if args.ignore_failures:
        sys.exit(0)
    sys.exit(1 if failed else 0)


def _suppress_warnings():
    import warnings

    # Ignore warning due to Python bug:
    # https://stackoverflow.com/questions/4964101/pep-3118-warning-when-using-ctypes-array-as-numpy-array
    warnings.filterwarnings("ignore",
                            message="A builtin ctypes object gave a PEP3118 format string that does not match its itemsize")
    warnings.filterwarnings(
        "ignore",
        message="A builtin ctypes object gave a PEP3118 format string that does not match its itemsize",
    )


if __name__ == "__main__":
    _suppress_warnings()

@ -31,21 +31,17 @@ LINALG_CRASHING_SET = {
|
|||
|
||||
TORCHDYNAMO_XFAIL_SET = {
|
||||
#### General TorchDynamo/PyTorch errors
|
||||
|
||||
# torch._dynamo.exc.Unsupported: Tensor.item
|
||||
"CumsumModule_basic",
|
||||
|
||||
# TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
|
||||
# RuntimeError: Failed running call_function aten.convolution_backward(...
|
||||
# https://github.com/pytorch/pytorch/issues/89629
|
||||
"ConvolutionBackwardModule2DPadded_basic",
|
||||
"ConvolutionBackwardModule2D_basic",
|
||||
|
||||
# Size result mismatch (exposed by downstream canonicalizer
|
||||
# on incompatabile casts).
|
||||
# https://github.com/pytorch/pytorch/issues/119407
|
||||
"ConvolutionBackwardModule2DStrided_basic",
|
||||
|
||||
# RuntimeError: Index tensor must have the same number of dimensions as self tensor
|
||||
# RuntimeError: Failed running call_function aten.nll_loss_backward(...
|
||||
# https://github.com/pytorch/pytorch/issues/89630
|
||||
|
@ -59,196 +55,159 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# RuntimeError: Failed running call_function aten.uniform(...
|
||||
# https://github.com/pytorch/torchdynamo/issues/1954
|
||||
"UniformNoCorrelationModule_basic",
|
||||
|
||||
#### Torch-MLIR internal compiler errors
|
||||
|
||||
# These are probably due to slightly different ops being recorded by
|
||||
# torchdynamo vs. torchscript.
|
||||
|
||||
# No upstream decompositions.
|
||||
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
|
||||
# See also: https://github.com/pytorch/torchdynamo/issues/327
|
||||
"AtenEmbeddingBagSumExample_basic",
|
||||
|
||||
# error: unsupported by backend contract: tensor with unknown rank
|
||||
# note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
|
||||
"ElementwisePreluModule_basic",
|
||||
# error: torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised: AssertionError: Unregistered operation: torch.aten._prelu_kernel
|
||||
"ElementwisePreluStaticModule_basic",
|
||||
|
||||
#ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777)
|
||||
# ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777)
|
||||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
#ERROR: value (-56) is not equal to golden value (200)
|
||||
# ERROR: value (-56) is not equal to golden value (200)
|
||||
"AtenIntTensorByteDtypeModule_basic",
|
||||
# ERROR: assert isinstance(e, FakeTensor)
|
||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||
# ERROR: assert isinstance(e, FakeTensor)
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
|
||||
# ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::squeeze.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa.
|
||||
"PrimsSqueezeModule_basic",
|
||||
"PrimsSqueezeEmptyDimensionsModule_basic",
|
||||
"SplitDimStaticModule_basic",
|
||||
"SplitDimDynamicModule_basic",
|
||||
|
||||
# ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::view_of.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa.
|
||||
"PrimsViewOfModule_basic",
|
||||
"PrimsViewOfZeroRankModule_basic",
|
||||
|
||||
# See https://github.com/llvm/torch-mlir/pull/2040 and corresponding upstream issue
|
||||
# https://github.com/pytorch/pytorch/issues/99752.
|
||||
# torch._dynamo.exc.Unsupported: call_function BuiltinVariable(bool) [TensorVariable()] {}
|
||||
'TensorToBoolZeroRank_basic',
|
||||
'TensorToBool_basic',
|
||||
|
||||
"TensorToBoolZeroRank_basic",
|
||||
"TensorToBool_basic",
|
||||
# START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
||||
'AtenSubFloatModule_basic',
|
||||
'AtenMulFloatModule_basic',
|
||||
'BoolFloatFalseModule_basic',
|
||||
'BoolFloatTrueModule_basic',
|
||||
'CeilFloatModule_basic',
|
||||
'DivFloatModule_basic',
|
||||
'GeFloatIntModule_basic',
|
||||
'GeFloatModule_basic',
|
||||
'GtFloatIntModule_basic',
|
||||
'NeFloatIntModule_basic',
|
||||
'SubFloatModule_basic',
|
||||
'MulFloatModule_basic',
|
||||
'TensorToFloatZeroRank_basic',
|
||||
'TensorToFloat_basic',
|
||||
"AtenSubFloatModule_basic",
|
||||
"AtenMulFloatModule_basic",
|
||||
"BoolFloatFalseModule_basic",
|
||||
"BoolFloatTrueModule_basic",
|
||||
"CeilFloatModule_basic",
|
||||
"DivFloatModule_basic",
|
||||
"GeFloatIntModule_basic",
|
||||
"GeFloatModule_basic",
|
||||
"GtFloatIntModule_basic",
|
||||
"NeFloatIntModule_basic",
|
||||
"SubFloatModule_basic",
|
||||
"MulFloatModule_basic",
|
||||
"TensorToFloatZeroRank_basic",
|
||||
"TensorToFloat_basic",
|
||||
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
||||
|
||||
|
||||
# START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
||||
'AddIntModule_basic',
|
||||
'AtenIntTensorCharDtypeModule_basic',
|
||||
'BoolIntFalseModule_basic',
|
||||
'BoolIntTrueModule_basic',
|
||||
'DivIntModule_basic',
|
||||
'EqIntModule_basic',
|
||||
'GeIntModule_basic',
|
||||
'GtIntModule_basic',
|
||||
'MulIntModule_basic',
|
||||
'NeIntModule_basic',
|
||||
'SqrtIntModule_basic',
|
||||
'SubIntModule_basic',
|
||||
'TensorToIntZeroRank_basic',
|
||||
'TensorToInt_basic',
|
||||
'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic',
|
||||
'ViewCollapseDynamicWithAtenSizeIntModule_basic',
|
||||
"AddIntModule_basic",
|
||||
"AtenIntTensorCharDtypeModule_basic",
|
||||
"BoolIntFalseModule_basic",
|
||||
"BoolIntTrueModule_basic",
|
||||
"DivIntModule_basic",
|
||||
"EqIntModule_basic",
|
||||
"GeIntModule_basic",
|
||||
"GtIntModule_basic",
|
||||
"MulIntModule_basic",
|
||||
"NeIntModule_basic",
|
||||
"SqrtIntModule_basic",
|
||||
"SubIntModule_basic",
|
||||
"TensorToIntZeroRank_basic",
|
||||
"TensorToInt_basic",
|
||||
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
||||
|
||||
# ERROR: torch._dynamo.exc.Unsupported: Tensor.item
|
||||
'AtenItemIntOpModule_basic',
|
||||
'AtenItemFpOpModule_basic',
|
||||
|
||||
"AtenItemIntOpModule_basic",
|
||||
"AtenItemFpOpModule_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)}
|
||||
'SortIntListReverse_basic',
|
||||
|
||||
"SortIntListReverse_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {}
|
||||
'SortIntList_basic',
|
||||
|
||||
"SortIntList_basic",
|
||||
# START tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
|
||||
'AtenFloatScalarModule_basic',
|
||||
'AtenIntBoolOpModule_basic',
|
||||
'QuantizedMLP_basic',
|
||||
'QuantizedSingleLayer_basic',
|
||||
'QuantizedBatchedInputSingleLayer_basic',
|
||||
'QuantizedNoLayer_basic',
|
||||
'ScalarImplicitFloatModule_basic',
|
||||
'ScalarImplicitIntModule_basic',
|
||||
"AtenFloatScalarModule_basic",
|
||||
"AtenIntBoolOpModule_basic",
|
||||
"QuantizedMLP_basic",
|
||||
"QuantizedSingleLayer_basic",
|
||||
"QuantizedBatchedInputSingleLayer_basic",
|
||||
"QuantizedNoLayer_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
"ScalarImplicitIntModule_basic",
|
||||
# END tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
|
||||
|
||||
# START tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
|
||||
'BincountMinlengthModule_basic',
|
||||
'BincountModule_basic',
|
||||
'BincountStaticSizeModule_basic',
|
||||
"BincountMinlengthModule_basic",
|
||||
"BincountModule_basic",
|
||||
"BincountStaticSizeModule_basic",
|
||||
# END tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
|
||||
|
||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool
|
||||
'BoolFloatConstantModule_basic',
|
||||
'BoolIntConstantModule_basic',
|
||||
|
||||
"BoolFloatConstantModule_basic",
|
||||
"BoolIntConstantModule_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
|
||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__
|
||||
'ContainsIntList_False',
|
||||
'ContainsIntList_True',
|
||||
|
||||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all
|
||||
'AllBoolFalseModule_basic',
|
||||
'AllBoolTrueModule_basic',
|
||||
|
||||
"AllBoolFalseModule_basic",
|
||||
"AllBoolTrueModule_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any
|
||||
'AnyBoolFalseModule_basic',
|
||||
'AnyBoolTrueModule_basic',
|
||||
|
||||
"AnyBoolFalseModule_basic",
|
||||
"AnyBoolTrueModule_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt
|
||||
'SqrtIntConstantModule_basic',
|
||||
|
||||
"SqrtIntConstantModule_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
|
||||
'BroadcastDynamicDimModule_basic',
|
||||
|
||||
"BroadcastDynamicDimModule_basic",
|
||||
# START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
|
||||
'AtenIntBoolOpConstFalseModule_basic',
|
||||
'AtenIntBoolOpConstTrueModule_basic',
|
||||
'IntFloatModule_basic',
|
||||
'PowIntFloatModule_basic',
|
||||
"AtenIntBoolOpConstFalseModule_basic",
|
||||
"AtenIntBoolOpConstTrueModule_basic",
|
||||
"IntFloatModule_basic",
|
||||
"PowIntFloatModule_basic",
|
||||
# END tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
|
||||
|
||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len
|
||||
'LenStrModule_basic',
|
||||
|
||||
"LenStrModule_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel
|
||||
'NumelModule_basic',
|
||||
'NumelZeroRankModule_basic',
|
||||
|
||||
"NumelModule_basic",
|
||||
"NumelZeroRankModule_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max
|
||||
'PrimMaxIntModule_basic',
|
||||
|
||||
"PrimMaxIntModule_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min
|
||||
'PrimMinIntModule_basic',
|
||||
'PrimMinIntDynamicModule_basic',
|
||||
|
||||
"PrimMinIntModule_basic",
|
||||
"PrimMinIntDynamicModule_basic",
|
||||
# START tests failing due to: empty graph in dynamo
|
||||
'IsFloatingPointFloat_True',
|
||||
'IsFloatingPointInt_False',
|
||||
'TorchPrimLoopForLikeModule_basic',
|
||||
'TorchPrimLoopWhileLikeModule_basic',
|
||||
"IsFloatingPointFloat_True",
|
||||
"IsFloatingPointInt_False",
|
||||
"TorchPrimLoopForLikeModule_basic",
|
||||
"TorchPrimLoopWhileLikeModule_basic",
|
||||
"ScalarConstantTupleModule_basic",
|
||||
# END tests failing due to: empty graph in dynamo
|
||||
|
||||
# ERROR due to: backend never runs because of empty frame
|
||||
'ConstantBoolParameterModule_basic',
|
||||
|
||||
"ConstantBoolParameterModule_basic",
|
||||
# START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||
"UpSampleNearest2dDynamicSize_basic",
|
||||
"UpSampleNearest2dStaticFactor_basic",
|
||||
"UpSampleNearest2dStaticSize_basic",
|
||||
"UpSampleNearest2d_basic",
|
||||
# END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||
|
||||
# START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||
"ElementwiseAddScalarFloatModule_basic",
|
||||
# END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||
|
||||
# ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
|
||||
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
||||
"HBC_basic",
|
||||
|
||||
# ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||
"ElementwiseDivScalarModule_basic",
|
||||
|
||||
# ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
|
||||
"ElementwiseAtenDivIntScalarModule_basic",
|
||||
|
||||
# ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||
"ElementwiseSubScalarFloatModule_basic",
|
||||
"ElementwiseSubScalarIntModule_basic",
|
||||
|
||||
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
|
||||
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||
"ElementwiseAtenFloorDivideScalarModule_basic",
|
||||
|
@ -258,57 +217,43 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
||||
|
||||
# ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
|
||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||
"AdaptiveAvgPool2dDynamic_basic",
|
||||
"AdaptiveAvgPool2dDynamicNoBatch_basic",
|
||||
|
||||
# ERROR: Exception: Unsupported op: get_attr
|
||||
"NumToTensorFloatModule_basic",
|
||||
"NumToTensorIntModule_basic",
|
||||
"TensorFloatModule_basic",
|
||||
"TensorIntModule_basic",
|
||||
|
||||
# START tests failing due to: complex floating point ops
|
||||
# END tests failing due to: complex floating point ops
|
||||
|
||||
# ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int
|
||||
"UnbindIntListUnpack_Module_basic",
|
||||
"UnbindIntGetItem_Module_basic",
|
||||
|
||||
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
||||
"ScatterValueFloatModule_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
||||
"ScatterValueIntModule_basic",
|
||||
|
||||
# AssertionError: Unregistered operation: torch.aten._unsafe_index_put
|
||||
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||
|
||||
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
|
||||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
|
||||
# AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
|
||||
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
|
||||
"AtenEmbeddingBagStaticModule_basic",
|
||||
|
||||
# Lowering not present for this case
|
||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
|
||||
# torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method add of type object at 0x7f4f8b05a720>(*(FakeTensor(..., size=(3, 4), dtype=torch.int8), 3, 2), **{}): Tensor with dtype torch.int64 is not the expected dtype of torch.int8!
|
||||
"ElementwiseAddScalarInt8Module_basic",
|
||||
|
||||
# ERROR: dtype (torch.int64) is not equal to golden dtype (torch.float32)
|
||||
"ThresholdBackward2dMixedModule_basic",
|
||||
|
||||
# ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
|
||||
"ArangeStartOutViewModule_basic",
|
||||
|
||||
# Dynamo does not support tracing quantized tensors
|
||||
"ElementwiseDequantizePerChannelModule_basic",
|
||||
"ElementwiseDequantizePerTensorModule_basic",
|
||||
|
@ -327,13 +272,10 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"QuantizedReluInt8_basic",
|
||||
"QuantizedReluUint8_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
|
||||
# Dynamo not supporting conv_tbc
|
||||
"ConvTbcModule_basic",
|
||||
|
||||
"FloatImplicitModule_basic",
|
||||
"IntImplicitModule_basic",
|
||||
|
||||
# Others
|
||||
"ExponentialModule_basic",
|
||||
"GridSamplerBasic1_basic",
|
||||
|
@ -383,142 +325,141 @@ TORCHDYNAMO_CRASHING_SET = {
|
|||
"MaxPool3dModule_basic",
|
||||
"MaxPool3dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool3dStaticModule_basic",
|
||||
|
||||
# Looks like incorrect fx graph conversion
|
||||
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
||||
}
|
||||
|
||||
FX_IMPORTER_XFAIL_SET = {
|
||||
'AllBoolFalseModule_basic',
|
||||
'AllBoolTrueModule_basic',
|
||||
'AnyBoolFalseModule_basic',
|
||||
'AnyBoolTrueModule_basic',
|
||||
'ArangeStartOutViewModule_basic',
|
||||
'AtenEmbeddingBagStaticModule_basic',
|
||||
'AtenEmbeddingBagSumExample_basic',
|
||||
'AtenFloatScalarModule_basic',
|
||||
'AtenIntBoolOpConstFalseModule_basic',
|
||||
'AtenIntBoolOpConstTrueModule_basic',
|
||||
'AtenIntBoolOpModule_basic',
|
||||
'AtenItemFpOpModule_basic',
|
||||
'AtenMatmulQMixedSigni8Transpose_basic',
|
||||
'AtenMatmulQMixedSigni8_basic',
|
||||
'AtenMatmulQint8MV_basic',
|
||||
'AtenMatmulQint8_basic',
|
||||
'AtenMatmulQint8VM_basic',
|
||||
'AtenMatmulQint8VV_basic',
|
||||
'AtenMmQMixedSigni8_basic',
|
||||
'AtenMmQint8_basic',
|
||||
'AtenMmQuint8_basic',
|
||||
"AllBoolFalseModule_basic",
|
||||
"AllBoolTrueModule_basic",
|
||||
"AnyBoolFalseModule_basic",
|
||||
"AnyBoolTrueModule_basic",
|
||||
"ArangeStartOutViewModule_basic",
|
||||
"AtenEmbeddingBagStaticModule_basic",
|
||||
"AtenEmbeddingBagSumExample_basic",
|
||||
"AtenFloatScalarModule_basic",
|
||||
"AtenIntBoolOpConstFalseModule_basic",
|
||||
"AtenIntBoolOpConstTrueModule_basic",
|
||||
"AtenIntBoolOpModule_basic",
|
||||
"AtenItemFpOpModule_basic",
|
||||
"AtenMatmulQMixedSigni8Transpose_basic",
|
||||
"AtenMatmulQMixedSigni8_basic",
|
||||
"AtenMatmulQint8MV_basic",
|
||||
"AtenMatmulQint8_basic",
|
||||
"AtenMatmulQint8VM_basic",
|
||||
"AtenMatmulQint8VV_basic",
|
||||
"AtenMmQMixedSigni8_basic",
|
||||
"AtenMmQint8_basic",
|
||||
"AtenMmQuint8_basic",
|
||||
"QuantizedReluInt32_basic",
|
||||
"QuantizedReluInt8_basic",
|
||||
"QuantizedReluUint8_basic",
|
||||
'AtenSubFloatModule_basic',
|
||||
'BincountMinlengthModule_basic',
|
||||
'BincountModule_basic',
|
||||
'BincountStaticSizeModule_basic',
|
||||
'BoolFloatConstantModule_basic',
|
||||
'BoolFloatFalseModule_basic',
|
||||
'BoolFloatTrueModule_basic',
|
||||
'BoolIntConstantModule_basic',
|
||||
'BoolIntFalseModule_basic',
|
||||
'BoolIntTrueModule_basic',
|
||||
'BroadcastDynamicDimModule_basic',
|
||||
'CeilFloatModule_basic',
|
||||
'ConstantBoolParameterModule_basic',
|
||||
'ContainsIntList_False',
|
||||
'ContainsIntList_True',
|
||||
'Conv2dQInt8Module_basic',
|
||||
'Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier',
|
||||
'ConvTbcModule_basic',
|
||||
'ConvolutionBackwardModule2DPadded_basic',
|
||||
'ConvolutionBackwardModule2DStrided_basic',
|
||||
'ConvolutionBackwardModule2D_basic',
|
||||
'CumsumModule_basic',
|
||||
'DivFloatModule_basic',
|
||||
'DivIntModule_basic',
|
||||
'ElementwiseAddScalar_NumToTensorFloat_Module_basic',
|
||||
'ElementwiseDequantizePerChannelModule_basic',
|
||||
'ElementwiseDequantizePerTensorModule_basic',
|
||||
'ElementwiseQuantizePerTensorModule_basic',
|
||||
'ElementwiseQuantizePerTensorUIntModule_basic',
|
||||
'ElementwiseToDtypeI64ToUI8Module_basic',
|
||||
'EqIntModule_basic',
|
||||
'FakeQuantizePerTensorAffineDynamicShapeModule_basic',
|
||||
'FakeQuantizePerTensorAffineModule_basic',
|
||||
'FakeQuantizePerTensorAffineRoundToEvenModule_basic',
|
||||
'FloatImplicitModule_basic',
|
||||
'GeFloatIntModule_basic',
|
||||
'GeFloatModule_basic',
|
||||
'GeIntModule_basic',
|
||||
'GtFloatIntModule_basic',
|
||||
'GtIntModule_basic',
|
||||
'IntFloatModule_basic',
|
||||
'IntImplicitModule_basic',
|
||||
'IsFloatingPointFloat_True',
|
||||
'IsFloatingPointInt_False',
|
||||
'LenStrModule_basic',
|
||||
'MaxPool3dCeilModeTrueModule_basic',
|
||||
'MaxPool3dEmptyStrideStaticModule_basic',
|
||||
'MaxPool3dLargeDatadModule_basic',
|
||||
'MaxPool3dModuleRandomSimple_basic',
|
||||
'MaxPool3dModule_basic',
|
||||
'MaxPool3dStaticCeilModeTrueModule_basic',
|
||||
'MaxPool3dStaticModule_basic',
|
||||
'MulFloatModule_basic',
|
||||
'NativeGroupNormBackwardModule_basic',
|
||||
'NeFloatIntModule_basic',
|
||||
'NeIntModule_basic',
|
||||
'NllLossModuleBackward1DMeanWeight_basic',
|
||||
'NllLossModuleBackward1DMean_basic',
|
||||
'NllLossModuleBackward1DSumWeight_basic',
|
||||
'NllLossModuleBackward1DSum_basic',
|
||||
'NllLossModuleBackward1DWeight_basic',
|
||||
'NllLossModuleBackward1D_basic',
|
||||
'NumToTensorFloatModule_basic',
|
||||
'NumToTensorIntModule_basic',
|
||||
'NumelModule_basic',
|
||||
'NumelZeroRankModule_basic',
|
||||
'PowIntFloatModule_basic',
|
||||
'PrimMaxIntModule_basic',
|
||||
'PrimMinIntDynamicModule_basic',
|
||||
'PrimMinIntModule_basic',
|
||||
'PrimsSqueezeEmptyDimensionsModule_basic',
|
||||
'PrimsSqueezeModule_basic',
|
||||
'PrimsViewOfModule_basic',
|
||||
'PrimsViewOfZeroRankModule_basic',
|
||||
'QuantizedBatchedInputSingleLayer_basic',
|
||||
'QuantizedMLP_basic',
|
||||
'QuantizedNoLayer_basic',
|
||||
'QuantizedSingleLayer_basic',
|
||||
'ReduceMaxAlongDimUnsignedInt_basic',
|
||||
'ReduceMinAlongDimUnsignedInt_basic',
|
||||
'RsubInt0d_NumToTensor_Module_basic',
|
||||
'ScalarConstantTupleModule_basic',
|
||||
'ScalarImplicitFloatModule_basic',
|
||||
'SortIntListReverse_basic',
|
||||
'SortIntList_basic',
|
||||
'SplitDimDynamicModule_basic',
|
||||
'SplitDimStaticModule_basic',
|
||||
'SqrtIntConstantModule_basic',
|
||||
'SqrtIntModule_basic',
|
||||
'SubFloatModule_basic',
|
||||
'TModuleRank0_basic',
|
||||
'TensorToBoolZeroRank_basic',
|
||||
'TensorToBool_basic',
|
||||
'TensorToFloatZeroRank_basic',
|
||||
'TensorToFloat_basic',
|
||||
'TestMultipleTensorAndPrimitiveTypesReturn_basic',
|
||||
'ThresholdBackward2dMixedModule_basic',
|
||||
'TorchPrimLoopForLikeModule_basic',
|
||||
'TorchPrimLoopWhileLikeModule_basic',
|
||||
'UnbindIntGetItem_Module_basic',
|
||||
'UnbindIntListUnpack_Module_basic',
|
||||
'UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic',
|
||||
'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic',
|
||||
'UpSampleNearest2dDynamicFactor_basic',
|
||||
'ViewCollapseDynamicWithAtenSizeIntModule_basic',
|
||||
'ViewSizeFromOtherTensor_basic',
|
||||
"AtenSubFloatModule_basic",
|
||||
"BincountMinlengthModule_basic",
|
||||
"BincountModule_basic",
|
||||
"BincountStaticSizeModule_basic",
|
||||
"BoolFloatConstantModule_basic",
|
||||
"BoolFloatFalseModule_basic",
|
||||
"BoolFloatTrueModule_basic",
|
||||
"BoolIntConstantModule_basic",
|
||||
"BoolIntFalseModule_basic",
|
||||
"BoolIntTrueModule_basic",
|
||||
"BroadcastDynamicDimModule_basic",
|
||||
"CeilFloatModule_basic",
|
||||
"ConstantBoolParameterModule_basic",
|
||||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
"ConvTbcModule_basic",
|
||||
"ConvolutionBackwardModule2DPadded_basic",
|
||||
"ConvolutionBackwardModule2DStrided_basic",
|
||||
"ConvolutionBackwardModule2D_basic",
|
||||
"CumsumModule_basic",
|
||||
"DivFloatModule_basic",
|
||||
"DivIntModule_basic",
|
||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||
"ElementwiseDequantizePerChannelModule_basic",
|
||||
"ElementwiseDequantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
"EqIntModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||
"FloatImplicitModule_basic",
|
||||
"GeFloatIntModule_basic",
|
||||
"GeFloatModule_basic",
|
||||
"GeIntModule_basic",
|
||||
"GtFloatIntModule_basic",
|
||||
"GtIntModule_basic",
|
||||
"IntFloatModule_basic",
|
||||
"IntImplicitModule_basic",
|
||||
"IsFloatingPointFloat_True",
|
||||
"IsFloatingPointInt_False",
|
||||
"LenStrModule_basic",
|
||||
"MaxPool3dCeilModeTrueModule_basic",
|
||||
"MaxPool3dEmptyStrideStaticModule_basic",
|
||||
"MaxPool3dLargeDatadModule_basic",
|
||||
"MaxPool3dModuleRandomSimple_basic",
|
||||
"MaxPool3dModule_basic",
|
||||
"MaxPool3dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool3dStaticModule_basic",
|
||||
"MulFloatModule_basic",
|
||||
"NativeGroupNormBackwardModule_basic",
|
||||
"NeFloatIntModule_basic",
|
||||
"NeIntModule_basic",
|
||||
"NllLossModuleBackward1DMeanWeight_basic",
|
||||
"NllLossModuleBackward1DMean_basic",
|
||||
"NllLossModuleBackward1DSumWeight_basic",
|
||||
"NllLossModuleBackward1DSum_basic",
|
||||
"NllLossModuleBackward1DWeight_basic",
|
||||
"NllLossModuleBackward1D_basic",
|
||||
"NumToTensorFloatModule_basic",
|
||||
"NumToTensorIntModule_basic",
|
||||
"NumelModule_basic",
|
||||
"NumelZeroRankModule_basic",
|
||||
"PowIntFloatModule_basic",
|
||||
"PrimMaxIntModule_basic",
|
||||
"PrimMinIntDynamicModule_basic",
|
||||
"PrimMinIntModule_basic",
|
||||
"PrimsSqueezeEmptyDimensionsModule_basic",
|
||||
"PrimsSqueezeModule_basic",
|
||||
"PrimsViewOfModule_basic",
|
||||
"PrimsViewOfZeroRankModule_basic",
|
||||
"QuantizedBatchedInputSingleLayer_basic",
|
||||
"QuantizedMLP_basic",
|
||||
"QuantizedNoLayer_basic",
|
||||
"QuantizedSingleLayer_basic",
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"ScalarConstantTupleModule_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
"SortIntListReverse_basic",
|
||||
"SortIntList_basic",
|
||||
"SplitDimDynamicModule_basic",
|
||||
"SplitDimStaticModule_basic",
|
||||
"SqrtIntConstantModule_basic",
|
||||
"SqrtIntModule_basic",
|
||||
"SubFloatModule_basic",
|
||||
"TModuleRank0_basic",
|
||||
"TensorToBoolZeroRank_basic",
|
||||
"TensorToBool_basic",
|
||||
"TensorToFloatZeroRank_basic",
|
||||
"TensorToFloat_basic",
|
||||
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
||||
"ThresholdBackward2dMixedModule_basic",
|
||||
"TorchPrimLoopForLikeModule_basic",
|
||||
"TorchPrimLoopWhileLikeModule_basic",
|
||||
"UnbindIntGetItem_Module_basic",
|
||||
"UnbindIntListUnpack_Module_basic",
|
||||
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
}
|
||||
|
||||
FX_IMPORTER_CRASHING_SET = {
|
||||
|
@ -1941,11 +1882,13 @@ TOSA_PASS_SET = {
|
|||
"LinspaceModule_basic",
|
||||
"LinspaceOneSizeModule_basic",
|
||||
"LinspaceTwoSizeModule_basic",
|
||||
"TorchPrimLoopForLikeTensorArgModule_basic"
|
||||
"TorchPrimLoopForLikeTensorArgModule_basic",
|
||||
}
|
||||
|
||||
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
||||
### Tests additionally passing in make_fx_tosa
|
||||
MAKE_FX_TOSA_PASS_SET = (
|
||||
TOSA_PASS_SET
|
||||
| {
|
||||
### Tests additionally passing in make_fx_tosa
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
||||
|
@ -1965,19 +1908,17 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
|||
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
|
||||
"ViewSizeDimLedByCollapsedOnesModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
}) - {
|
||||
### Test failing in make_fx_tosa but not in tosa
|
||||
|
||||
}
|
||||
) - {
|
||||
### Test failing in make_fx_tosa but not in tosa
|
||||
# Dynamic shape, has extra unsupported broadcast ops
|
||||
"Matmul_3d",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
|
||||
# failed to legalize operation 'torch.aten.max_pool2d_with_indices
|
||||
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||
"MaxPool2dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool2dStaticModule_basic",
|
||||
"ResNet18StaticModule_basic",
|
||||
|
||||
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
||||
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
||||
|
@ -1986,18 +1927,14 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
|||
# failed to legalize operation 'torch.aten.to.dtype' that was explicitly marked illegal
|
||||
"AtenEyeModuleInt2D_basic",
|
||||
"AtenEyeMModuleInt2D_basic",
|
||||
|
||||
"Conv2dBiasNoPaddingModule_basic",
|
||||
"Conv2dNoPaddingModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideModule_basic",
|
||||
"Conv2dWithPaddingModule_basic",
|
||||
|
||||
"AtenInstanceNormModule_basic",
|
||||
|
||||
# failed to legalize operation 'torch.operator'
|
||||
"ElementwisePreluModule_basic",
|
||||
"ElementwisePreluStaticModule_basic",
|
||||
|
||||
# Shape Related failures
|
||||
"PrimListUnpackNumMismatchModule_basic",
|
||||
"ReshapeExpandModule_basic",
|
||||
|
@ -2019,8 +1956,7 @@ LTC_CRASHING_SET = {
|
|||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
"TorchPrimLoopForLikeTensorArgModule_basic"
|
||||
"CollapseAllDimensionsModule_basic",
|
||||
"TorchPrimLoopForLikeTensorArgModule_basic" "CollapseAllDimensionsModule_basic",
|
||||
"CollapseRank1DynamicModule_basic",
|
||||
"CollapseStaticModule_basic",
|
||||
"CollapsePartialDynamicModule_basic",
|
||||
|
@ -2162,7 +2098,6 @@ LTC_XFAIL_SET = {
|
|||
ONNX_XFAIL_SET = {
|
||||
# Failure - cast error
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
|
||||
# Failure - expand multiple dynamic dims
|
||||
"EmbeddingModuleF16_basic",
|
||||
"EmbeddingModuleI32_basic",
|
||||
|
@ -2174,7 +2109,6 @@ ONNX_XFAIL_SET = {
|
|||
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
||||
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
|
||||
"IndexTensorSelectDimModule_basic",
|
||||
|
||||
# Failure - incorrect numerics
|
||||
"AvgPool2dDivisorOverrideModule_basic",
|
||||
"BroadcastDynamicDimModule_basic",
|
||||
|
@ -2211,14 +2145,12 @@ ONNX_XFAIL_SET = {
|
|||
"StdCorrectionLargeInputModule_basic",
|
||||
"TupleModule_basic",
|
||||
"VarCorrectionLargeInputModule_basic",
|
||||
|
||||
# Failure - incorrect shape
|
||||
"ArangeStartOutDtypeModule_basic",
|
||||
"ArangeStartOutViewModule_basic",
|
||||
"MoveDimIntNegativeIndexModule_basic",
|
||||
"ReduceL3NormKeepDimModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
|
||||
# Failure - onnx_export
|
||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
|
||||
|
@ -2619,10 +2551,8 @@ ONNX_XFAIL_SET = {
|
|||
"_ConvolutionDeprecated2DCudnnModule_basic",
|
||||
"_ConvolutionDeprecated2DDeterministicModule_basic",
|
||||
"_SoftmaxModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.AveragePool
|
||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.If
|
||||
"DiagonalModule_basic",
|
||||
"DiagonalModule_nonsquare",
|
||||
|
@ -2633,12 +2563,10 @@ ONNX_XFAIL_SET = {
|
|||
"DiagonalModule_with_offset",
|
||||
"TileBigDimsSizeModule_basic",
|
||||
"TileSmallDimsSizeModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.MaxPool
|
||||
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
||||
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
|
||||
"MaxPool2dWithIndicesStaticModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.ReduceProd
|
||||
"ReduceProdFloatModule_basic",
|
||||
"ReduceProdDtypeFloatModule_basic",
|
||||
|
@ -2646,7 +2574,6 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceProdUnsignedIntModule_basic",
|
||||
"ReduceProdSignedIntModule_basic",
|
||||
"ReduceProdDtypeIntModule_basic",
|
||||
|
||||
# ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64)
|
||||
"RandnDtypeDeviceModule_basic",
|
||||
"RandnGeneratorF64Module_basic",
|
||||
|
@ -2656,21 +2583,17 @@ ONNX_XFAIL_SET = {
|
|||
"BernoulliFloatModule_basic",
|
||||
"BernoulliPModule_basic",
|
||||
"BernoulliTensorModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.ReduceProd
|
||||
"ReduceProdDimIntFloatModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.Resize
|
||||
"UpSampleNearest2dDynamicSize_basic",
|
||||
"UpSampleNearest2dStaticSize_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.ScatterElements
|
||||
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||
"ScatterReduceFloatMinModuleIncludeSelf",
|
||||
"ScatterReduceIntMaxModuleIncludeSelf",
|
||||
"ScatterReduceIntMinModuleIncludeSelf",
|
||||
"ScatterValueFloatModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.ScatterND
|
||||
"IndexPut1DFloatAccumulateModule_basic",
|
||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||
|
@ -2696,14 +2619,11 @@ ONNX_XFAIL_SET = {
|
|||
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DIntAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
|
||||
"CrossEntropyLossModule_basic",
|
||||
"CrossEntropyLossNoReductionModule_basic",
|
||||
|
||||
# RuntimeError: unsupported input type: Device
|
||||
"PrimsIotaModule_basic",
|
||||
|
||||
# Failure - unknown
|
||||
"BernoulliModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
|
@ -2741,7 +2661,7 @@ if torch_version_for_comparison() >= version.parse("2.4.0.dev"):
|
|||
"ReduceL1NormWithDTypeModule_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse('2.3.0.dev'):
|
||||
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
|
||||
ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
|
||||
# ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
|
||||
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||
|
|
|
@ -23,34 +23,43 @@ import torch._lazy
|
|||
from datasets import load_dataset
|
||||
from datasets.dataset_dict import DatasetDict
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import BertForSequenceClassification, \
|
||||
BertConfig, BertTokenizer, AdamW, get_scheduler
|
||||
from transformers import (
|
||||
BertForSequenceClassification,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
AdamW,
|
||||
get_scheduler,
|
||||
)
|
||||
|
||||
|
||||
def tokenize_dataset(dataset: DatasetDict) -> DatasetDict:
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples["text"], padding="max_length",
|
||||
truncation=True)
|
||||
return tokenizer(examples["text"], padding="max_length", truncation=True)
|
||||
|
||||
tokenized_datasets = dataset.map(tokenize_function, batched=True)
|
||||
tokenized_datasets = tokenized_datasets.remove_columns(['text'])
|
||||
tokenized_datasets = tokenized_datasets.rename_column('label', 'labels')
|
||||
tokenized_datasets.set_format('torch')
|
||||
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
|
||||
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
|
||||
tokenized_datasets.set_format("torch")
|
||||
|
||||
return tokenized_datasets
|
||||
|
||||
|
||||
def train(model: BertForSequenceClassification,
|
||||
def train(
|
||||
model: BertForSequenceClassification,
|
||||
num_epochs: int,
|
||||
num_training_steps: int,
|
||||
train_dataloader: DataLoader,
|
||||
device: torch.device) -> List[torch.Tensor]:
|
||||
device: torch.device,
|
||||
) -> List[torch.Tensor]:
|
||||
optimizer = AdamW(model.parameters(), lr=5e-5)
|
||||
lr_scheduler = get_scheduler('linear', optimizer=optimizer,
|
||||
lr_scheduler = get_scheduler(
|
||||
"linear",
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=0,
|
||||
num_training_steps=num_training_steps)
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
|
||||
model.train()
|
||||
losses = []
|
||||
|
@ -66,14 +75,14 @@ def train(model: BertForSequenceClassification,
|
|||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if 'lazy' in str(model.device):
|
||||
if "lazy" in str(model.device):
|
||||
print("Calling Mark Step")
|
||||
torch._lazy.mark_step()
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
def main(device='lazy', full_size=False):
|
||||
def main(device="lazy", full_size=False):
|
||||
"""
|
||||
Load model to specified device. Ensure that any backends have been initialized by this point.
|
||||
|
||||
|
@ -82,15 +91,14 @@ def main(device='lazy', full_size=False):
|
|||
"""
|
||||
torch.manual_seed(0)
|
||||
|
||||
tokenized_datasets = tokenize_dataset(load_dataset('imdb'))
|
||||
small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \
|
||||
.select(range(2))
|
||||
tokenized_datasets = tokenize_dataset(load_dataset("imdb"))
|
||||
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(2))
|
||||
|
||||
train_dataloader = DataLoader(small_train_dataset, shuffle=True,
|
||||
batch_size=8)
|
||||
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
|
||||
if full_size:
|
||||
model = BertForSequenceClassification.from_pretrained('bert-base-cased',
|
||||
num_labels=2)
|
||||
model = BertForSequenceClassification.from_pretrained(
|
||||
"bert-base-cased", num_labels=2
|
||||
)
|
||||
else:
|
||||
configuration = BertConfig(
|
||||
vocab_size=28996,
|
||||
|
@ -98,7 +106,7 @@ def main(device='lazy', full_size=False):
|
|||
num_hidden_layers=1,
|
||||
num_attention_heads=2,
|
||||
intermediate_size=32,
|
||||
hidden_act='gelu',
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
max_position_embeddings=512,
|
||||
|
@ -113,12 +121,12 @@ def main(device='lazy', full_size=False):
|
|||
losses = train(model, num_epochs, num_training_steps, train_dataloader, device)
|
||||
|
||||
# Get debug information from LTC
|
||||
if 'torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND' in sys.modules:
|
||||
if "torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND" in sys.modules:
|
||||
computation = lazy_backend.get_latest_computation()
|
||||
if computation:
|
||||
print(computation.debug_string())
|
||||
|
||||
print('Loss: ', losses)
|
||||
print("Loss: ", losses)
|
||||
|
||||
return model, losses
|
||||
|
||||
|
@ -136,7 +144,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument(
|
||||
"-f",
|
||||
"--full_size",
|
||||
action='store_true',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use full sized BERT model instead of one with smaller parameterization",
|
||||
)
|
||||
|
@ -145,6 +153,7 @@ if __name__ == "__main__":
|
|||
if args.device in ("TS", "MLIR_EXAMPLE"):
|
||||
if args.device == "TS":
|
||||
import torch._lazy.ts_backend
|
||||
|
||||
torch._lazy.ts_backend.init()
|
||||
|
||||
elif args.device == "MLIR_EXAMPLE":
|
||||
|
|
|
@@ -13,7 +13,7 @@ import torch._lazy
import torch.nn.functional as F


def main(device='lazy'):
def main(device="lazy"):
    """
    Load model to specified device. Ensure that any backends have been initialized by this point.

@@ -65,7 +65,7 @@ def main(device='lazy'):
        torch._lazy.mark_step()

    # Get debug information from LTC
    if 'torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND' in sys.modules:
    if "torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND" in sys.modules:
        computation = lazy_backend.get_latest_computation()
        if computation:
            print(computation.debug_string())

@@ -90,6 +90,7 @@ if __name__ == "__main__":
    if args.device in ("TS", "MLIR_EXAMPLE"):
        if args.device == "TS":
            import torch._lazy.ts_backend

            torch._lazy.ts_backend.init()

        elif args.device == "MLIR_EXAMPLE":

@ -21,19 +21,18 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
|||
|
||||
def load_and_preprocess_image(url: str):
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
|
||||
}
|
||||
img = Image.open(requests.get(url, headers=headers,
|
||||
stream=True).raw).convert("RGB")
|
||||
img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB")
|
||||
# preprocessing pipeline
|
||||
preprocess = transforms.Compose([
|
||||
preprocess = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
img_preprocessed = preprocess(img)
|
||||
return torch.unsqueeze(img_preprocessed, 0)
|
||||
|
||||
|
@ -62,17 +61,23 @@ def predictions(torch_func, jit_func, img, labels):
|
|||
print("torch-mlir prediction")
|
||||
print(prediction)
|
||||
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
||||
|
||||
image_url = (
|
||||
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
||||
)
|
||||
|
||||
print("load image from " + image_url, file=sys.stderr)
|
||||
img = load_and_preprocess_image(image_url)
|
||||
labels = load_labels()
|
||||
|
||||
|
||||
@make_simple_dynamo_backend
|
||||
def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor]):
|
||||
def refbackend_torchdynamo_backend(
|
||||
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||
):
|
||||
mlir_module = torchscript.compile(
|
||||
fx_graph, example_inputs, output_type="linalg-on-tensors")
|
||||
fx_graph, example_inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||
compiled = backend.compile(mlir_module)
|
||||
loaded = backend.load(compiled)
|
||||
|
@ -85,10 +90,17 @@ def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
|||
else:
|
||||
result = tuple(torch.from_numpy(x) for x in result)
|
||||
return result
|
||||
|
||||
return compiled_callable
|
||||
|
||||
|
||||
resnet18 = models.resnet18(pretrained=True)
|
||||
resnet18.train(False)
|
||||
dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18)
|
||||
|
||||
predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).detach().numpy(), img, labels)
|
||||
predictions(
|
||||
resnet18.forward,
|
||||
lambda x: dynamo_callable(torch.from_numpy(x)).detach().numpy(),
|
||||
img,
|
||||
labels,
|
||||
)
|
||||
|
|
|
@@ -18,19 +18,18 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

def load_and_preprocess_image(url: str):
    headers = {
        'User-Agent':
        'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
    }
    img = Image.open(requests.get(url, headers=headers,
                                  stream=True).raw).convert("RGB")
    img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB")
    # preprocessing pipeline
    preprocess = transforms.Compose([
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    img_preprocessed = preprocess(img)
    return torch.unsqueeze(img_preprocessed, 0)


@@ -59,7 +58,10 @@ def predictions(torch_func, jit_func, img, labels):
    print("torch-mlir prediction")
    print(prediction)

image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"

image_url = (
    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
)

print("load image from " + image_url, file=sys.stderr)
img = load_and_preprocess_image(image_url)

@@ -67,7 +69,9 @@ labels = load_labels()

resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
module = torchscript.compile(
    resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors"
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)

@ -13,8 +13,12 @@ resnet18.eval()
|
|||
|
||||
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
||||
print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
||||
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
|
||||
print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
||||
module = torchscript.compile(
|
||||
resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors"
|
||||
)
|
||||
print(
|
||||
"LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10)
|
||||
)
|
||||
# TODO: Debug why this is so slow.
|
||||
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
|
||||
print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
||||
|
|
|
@ -4,10 +4,12 @@ from torch_mlir import torchscript

model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
data = torch.randn(2, 3, 200, 200)
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"

module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False)
module = torchscript.compile(
model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False
)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))

@ -7,17 +7,22 @@ from transformers import BertForMaskedLM
|
|||
class BertTinyWrapper(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False)
|
||||
self.bert = BertForMaskedLM.from_pretrained(
|
||||
"prajjwal1/bert-tiny", return_dict=False
|
||||
)
|
||||
|
||||
def forward(self, data):
|
||||
return self.bert(data)[0]
|
||||
|
||||
|
||||
model = BertTinyWrapper()
|
||||
model.eval()
|
||||
data = torch.randint(30522, (2, 128))
|
||||
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
|
||||
|
||||
module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True)
|
||||
module = torchscript.compile(
|
||||
model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True
|
||||
)
|
||||
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
outf.write(str(module))
|
||||
|
||||
|
|
|
@ -11,18 +11,23 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
from torch_mlir.jit_ir_importer import ClassAnnotator
|
||||
from torch_mlir.jit_ir_importer.torchscript_annotations import extract_annotations
|
||||
|
||||
|
||||
class MmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 4], torch.float32, False),
|
||||
([4, 5], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
||||
|
||||
module = MmModule()
|
||||
annotator = ClassAnnotator()
|
||||
extract_annotations(module, torch.jit.script(module), annotator)
|
||||
|
|
|
@ -26,6 +26,8 @@ print(torchscript.compile(scripted, example_args))
|
|||
scripted = torch.jit.script(BasicModule())
|
||||
try:
|
||||
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
||||
torchscript.compile(scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
|
||||
torchscript.compile(
|
||||
scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3))
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
|
|
@ -8,10 +8,12 @@
|
|||
import torch
|
||||
from torch_mlir import torchscript
|
||||
|
||||
|
||||
class BasicModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.sin(x)
|
||||
|
||||
|
||||
example_arg = torch.ones(2, 3)
|
||||
example_args = torchscript.ExampleArgs.get(example_arg)
|
||||
|
||||
|
@ -23,6 +25,8 @@ print(torchscript.compile(traced, example_args))
|
|||
traced = torch.jit.trace(BasicModule(), example_arg)
|
||||
try:
|
||||
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
||||
torchscript.compile(traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg))
|
||||
torchscript.compile(
|
||||
traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
|
|
@ -9,15 +9,24 @@ import torch
|
|||
|
||||
from torch_mlir import torchscript
|
||||
|
||||
|
||||
class AddmmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return torch.ops.aten.addmm(x, y, z)
|
||||
|
||||
|
||||
example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)]
|
||||
|
||||
print(torchscript.compile(AddmmModule(), example_args,
|
||||
output_type="torch", backend_legal_ops=["aten.addmm"]))
|
||||
print(
|
||||
torchscript.compile(
|
||||
AddmmModule(),
|
||||
example_args,
|
||||
output_type="torch",
|
||||
backend_legal_ops=["aten.addmm"],
|
||||
)
|
||||
)
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.addmm
|
||||
|
|
|
@ -9,12 +9,15 @@ import torch
|
|||
|
||||
from torch_mlir import torchscript
|
||||
|
||||
|
||||
class TanhModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.tanh(x)
|
||||
|
||||
|
||||
tanh_example_input = torch.ones(2, 3)
|
||||
|
||||
# Simplest case: One example argument.
|
||||
|
@ -35,16 +38,23 @@ print(torchscript.compile(TanhModule(), placeholder))
|
|||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
|
||||
|
||||
# Basic smoke test for the raw output type.
|
||||
print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW))
|
||||
print(
|
||||
torchscript.compile(
|
||||
TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW
|
||||
)
|
||||
)
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule">
|
||||
|
||||
|
||||
class MmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def forward(self, lhs, rhs ):
|
||||
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.ops.aten.mm(lhs, rhs)
|
||||
|
||||
|
||||
# N > 1 inputs.
|
||||
mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)]
|
||||
print(torchscript.compile(MmModule(), mm_example_inputs))
|
||||
|
@ -52,7 +62,10 @@ print(torchscript.compile(MmModule(), mm_example_inputs))
|
|||
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
|
||||
|
||||
# Mixes Tensor's and TensorPlaceholder's.
|
||||
mm_dynamic_inputs = [mm_example_inputs[0], torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
|
||||
mm_dynamic_inputs = [
|
||||
mm_example_inputs[0],
|
||||
torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1]),
|
||||
]
|
||||
print(torchscript.compile(MmModule(), mm_dynamic_inputs))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32>
|
||||
|
|
|
@ -10,11 +10,19 @@ import torch
|
|||
|
||||
from torch_mlir import torchscript
|
||||
|
||||
|
||||
def simple(x):
|
||||
return x * x
|
||||
|
||||
example_input = torch.randn(1,)
|
||||
graph = functorch.make_fx(simple)(torch.randn(1,))
|
||||
|
||||
example_input = torch.randn(
|
||||
1,
|
||||
)
|
||||
graph = functorch.make_fx(simple)(
|
||||
torch.randn(
|
||||
1,
|
||||
)
|
||||
)
|
||||
|
||||
# Simplest case: One example argument.
|
||||
print(torchscript.compile(graph, example_input))
|
||||
|
|
|
@ -9,15 +9,22 @@ import torch
|
|||
|
||||
from torch_mlir import torchscript
|
||||
|
||||
|
||||
class TanhModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.tanh(x)
|
||||
|
||||
|
||||
tanh_example_input = torch.ones(2, 3)
|
||||
|
||||
print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH))
|
||||
print(
|
||||
torchscript.compile(
|
||||
TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH
|
||||
)
|
||||
)
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||
print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch"))
|
||||
|
|
|
@ -14,6 +14,7 @@ class TanhModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.ops.aten.tanh(x)
|
||||
|
||||
|
||||
tanh_example_input = torch.ones(2, 3)
|
||||
|
||||
# Simplest case: One example argument.
|
||||
|
@ -32,10 +33,12 @@ print(torchscript.compile(TanhModule(), [tanh_example_input], use_tracing=True))
|
|||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||
|
||||
# TensorPlaceholder support.
|
||||
placeholder = torchscript.TensorPlaceholder.like(
|
||||
tanh_example_input, dynamic_axes=[1])
|
||||
print(torchscript.compile(TanhModule(), [placeholder],
|
||||
use_tracing=True, ignore_traced_shapes=True))
|
||||
placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
|
||||
print(
|
||||
torchscript.compile(
|
||||
TanhModule(), [placeholder], use_tracing=True, ignore_traced_shapes=True
|
||||
)
|
||||
)
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
|
||||
|
||||
|
@ -55,18 +58,18 @@ except Exception as e:
|
|||
|
||||
class DictModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x['a'] * 2.0
|
||||
return x["a"] * 2.0
|
||||
|
||||
|
||||
try:
|
||||
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
||||
torchscript.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
|
||||
torchscript.compile(DictModule(), {"a": torch.tensor(3.0)}, use_tracing=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
try:
|
||||
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
||||
torchscript.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
|
||||
torchscript.compile(DictModule(), [{"a": torch.tensor(3.0)}], use_tracing=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
|
|
@ -15,8 +15,9 @@ from torch_mlir_e2e_test.debug.lockstep import make_lockstep_debug_backend
|
|||
|
||||
@make_simple_dynamo_backend
|
||||
@make_lockstep_debug_backend()
|
||||
def miscompile_div_as_mul_backend(gm: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor]):
|
||||
def miscompile_div_as_mul_backend(
|
||||
gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||
):
|
||||
# Copy `gm` and rewrite `div` to `mul`.
|
||||
new_g = torch.fx.Graph()
|
||||
new_g.output(new_g.graph_copy(gm.graph, {}))
|
||||
|
@ -41,7 +42,7 @@ def f(x, y):
|
|||
return a, b, c
|
||||
|
||||
|
||||
args = (torch.tensor([1., 2., 3.]), torch.tensor([4., 5., 6.]))
|
||||
args = (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]))
|
||||
try:
|
||||
print(f(*args))
|
||||
except AssertionError as e:
|
||||
|
|
|
@ -11,7 +11,11 @@ import torch
|
|||
import torch.fx
|
||||
import torch._dynamo as dynamo
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_compilation_context, set_model_name
|
||||
from torch._functorch.aot_autograd import (
|
||||
make_boxed_compiler,
|
||||
get_aot_compilation_context,
|
||||
set_model_name,
|
||||
)
|
||||
|
||||
from torch_mlir.compiler_utils import TorchMlirCompilerError
|
||||
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
|
||||
|
@ -19,8 +23,9 @@ from torch_mlir_e2e_test.configs.torchdynamo import jit
|
|||
|
||||
|
||||
@make_boxed_compiler
|
||||
def my_aot_autograd_backend(gm: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor]):
|
||||
def my_aot_autograd_backend(
|
||||
gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||
):
|
||||
print(gm.graph)
|
||||
*_, model_name, nth_graph = get_aot_compilation_context()
|
||||
mlir_module = import_fx_graph_as_func(gm.graph, model_name)
|
||||
|
@ -59,9 +64,7 @@ basic(torch.randn(3, 4))
|
|||
# CHECK: return %[[RANDN]] : !torch.vtensor<[3,4],f16>
|
||||
@dynamo.optimize(my_backend)
|
||||
def literals_list_device_int_none_dtype():
|
||||
return torch.ops.aten.randn([3, 4],
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.float16)
|
||||
return torch.ops.aten.randn([3, 4], device=torch.device("cpu"), dtype=torch.float16)
|
||||
|
||||
|
||||
set_model_name("literals_list_device_int_none_dtype")
|
||||
|
|
|
@ -19,23 +19,23 @@ from lit.llvm.subst import FindTool
|
|||
# Configuration file for the 'lit' test runner.
|
||||
|
||||
# name: The name of this test suite.
|
||||
config.name = 'TORCH_MLIR_PYTHON'
|
||||
config.name = "TORCH_MLIR_PYTHON"
|
||||
|
||||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||
if 'TEST_SRC_PATH' in os.environ:
|
||||
config.environment['TEST_SRC_PATH'] = os.environ['TEST_SRC_PATH']
|
||||
if "TEST_SRC_PATH" in os.environ:
|
||||
config.environment["TEST_SRC_PATH"] = os.environ["TEST_SRC_PATH"]
|
||||
|
||||
# path to our python operation library
|
||||
config.environment['TEST_BUILD_PATH'] = os.path.join(config.torch_mlir_obj_root)
|
||||
config.environment["TEST_BUILD_PATH"] = os.path.join(config.torch_mlir_obj_root)
|
||||
|
||||
# suffixes: A list of file extensions to treat as test files.
|
||||
config.suffixes = ['.py']
|
||||
config.suffixes = [".py"]
|
||||
|
||||
# test_source_root: The root path where tests are located.
|
||||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
||||
# test_exec_root: The root path where tests should be run.
|
||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test")
|
||||
|
||||
# On Windows the path to python could contains spaces in which case it needs to
|
||||
# be provided in quotes. This is the equivalent of how %python is setup in
|
||||
|
@ -43,19 +43,25 @@ config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
|||
if "Windows" in config.host_os:
|
||||
config.python_executable = '"%s"' % (config.python_executable)
|
||||
|
||||
config.substitutions.append(('%PATH%', config.environment['PATH']))
|
||||
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
|
||||
config.substitutions.append(('%PYTHON', config.python_executable))
|
||||
config.substitutions.append(("%PATH%", config.environment["PATH"]))
|
||||
config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
|
||||
config.substitutions.append(("%PYTHON", config.python_executable))
|
||||
|
||||
llvm_config.with_system_environment(
|
||||
['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
|
||||
llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
|
||||
|
||||
llvm_config.use_default_substitutions()
|
||||
|
||||
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
|
||||
# subdirectories contain auxiliary inputs for various tests in their parent
|
||||
# directories.
|
||||
config.excludes = ['lit.cfg.py', 'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
|
||||
config.excludes = [
|
||||
"lit.cfg.py",
|
||||
"Inputs",
|
||||
"Examples",
|
||||
"CMakeLists.txt",
|
||||
"README.txt",
|
||||
"LICENSE.txt",
|
||||
]
|
||||
|
||||
if not bool(int(os.environ.get("TORCH_MLIR_ENABLE_LTC", 0))):
|
||||
config.excludes.append("lazy_backend")
|
||||
|
@ -64,20 +70,23 @@ if not bool(int(os.environ.get("TORCH_MLIR_ENABLE_LTC", 0))):
|
|||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
||||
# test_exec_root: The root path where tests should be run.
|
||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
||||
config.torch_mlir_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')
|
||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test")
|
||||
config.torch_mlir_tools_dir = os.path.join(config.torch_mlir_obj_root, "bin")
|
||||
|
||||
# Tweak the PATH to include the tools dir.
|
||||
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
||||
llvm_config.with_environment('PYTHONPATH', [
|
||||
os.path.join(config.torch_mlir_python_packages_dir, 'torch_mlir'),
|
||||
llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
|
||||
llvm_config.with_environment(
|
||||
"PYTHONPATH",
|
||||
[
|
||||
os.path.join(config.torch_mlir_python_packages_dir, "torch_mlir"),
|
||||
],
|
||||
append_path=True)
|
||||
append_path=True,
|
||||
)
|
||||
|
||||
|
||||
tool_dirs = [config.torch_mlir_tools_dir, config.llvm_tools_dir]
|
||||
tools = [
|
||||
'torch-mlir-opt',
|
||||
"torch-mlir-opt",
|
||||
]
|
||||
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
|
|
@ -40,5 +40,5 @@ def main():
|
|||
report_results(results, set(), verbose=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -43,5 +43,5 @@ def main():
|
|||
report_results(results, set(), verbose=True, config="myconfig")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -122,7 +122,7 @@ class ErroneousModule(torch.nn.Module):
|
|||
@torch.jit.export
|
||||
def test_tensor_value_mismatch(self):
|
||||
if torch.jit.is_scripting():
|
||||
return torch.tensor([1., 2., 3.])
|
||||
return torch.tensor([1.0, 2.0, 3.0])
|
||||
else:
|
||||
return torch.tensor([1.5, 2.5, 3.5])
|
||||
|
||||
|
@ -132,9 +132,9 @@ class ErroneousModule(torch.nn.Module):
|
|||
@torch.jit.export
|
||||
def test_tensor_shape_mismatch(self):
|
||||
if torch.jit.is_scripting():
|
||||
return torch.tensor([1., 2.])
|
||||
return torch.tensor([1.0, 2.0])
|
||||
else:
|
||||
return torch.tensor([1., 2., 3.])
|
||||
return torch.tensor([1.0, 2.0, 3.0])
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ErroneousModule())
|
||||
|
@ -157,5 +157,5 @@ def main():
|
|||
report_results(results, set(), verbose=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -51,5 +51,5 @@ def main():
|
|||
report_results(results, set())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -39,5 +39,5 @@ def main():
|
|||
report_results(results, set(), verbose=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -12,6 +12,7 @@ from torch_mlir_e2e_test.reporting import report_results
|
|||
from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.configs import TorchScriptTestConfig
|
||||
|
||||
|
||||
class Submodule2(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -19,6 +20,7 @@ class Submodule2(torch.nn.Module):
|
|||
def forward(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
||||
|
||||
class Submodule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -43,5 +45,5 @@ def main():
|
|||
report_results(results, set())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
import pdb
|
||||
|
||||
# This file implements a pure-Python importer from a restricted subset of
|
||||
# FX IR into MLIR.
|
||||
#
|
||||
|
@ -77,8 +78,12 @@ def _verify_fx_graph_conforms_to_subset(g: torch.fx.Graph):
|
|||
if len(node.args) != len(node.target._schema.arguments):
|
||||
assert len(node.args) < len(node.target._schema.arguments)
|
||||
for i, argument in enumerate(
|
||||
node.target._schema.arguments[len(node.args):]):
|
||||
if not argument.has_default_value() and argument.name not in node.kwargs:
|
||||
node.target._schema.arguments[len(node.args) :]
|
||||
):
|
||||
if (
|
||||
not argument.has_default_value()
|
||||
and argument.name not in node.kwargs
|
||||
):
|
||||
raise Exception(
|
||||
f"Unsupported: missing default value for argument {i} in schema for {node.target}"
|
||||
)
|
||||
|
@ -152,12 +157,12 @@ def _convert_dtype_to_mlir_type(dtype: torch.dtype) -> str:
|
|||
if dtype == torch.complex128:
|
||||
return "complex<f64>"
|
||||
|
||||
|
||||
raise Exception(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
def _import_fake_tensor_as_mlir_type(
|
||||
fake_tensor: torch._subclasses.FakeTensor) -> ir.Type:
|
||||
fake_tensor: torch._subclasses.FakeTensor,
|
||||
) -> ir.Type:
|
||||
# TODO: Find story for how to get dynamically shaped tensors here.
|
||||
shape = ",".join(str(d) for d in fake_tensor.shape)
|
||||
dtype = _convert_dtype_to_mlir_type(fake_tensor.dtype)
|
||||
|
@ -178,7 +183,8 @@ def _extract_function_type_from_graph(g: torch.fx.Graph) -> ir.FunctionType:
|
|||
if node.op == "output":
|
||||
# TODO(DNS): Test this or add verifier that it can't happen.
|
||||
result_types = torch.fx.map_arg(
|
||||
node.args[0], lambda n: _mlir_types_for_node(n)[0])
|
||||
node.args[0], lambda n: _mlir_types_for_node(n)[0]
|
||||
)
|
||||
# Note: We import directly to the backend contract -- multiple results
|
||||
# are modeled with func.func native multiple results rather than as a
|
||||
# singleton value / tuple.
|
||||
|
@ -191,64 +197,40 @@ def _extract_function_type_from_graph(g: torch.fx.Graph) -> ir.FunctionType:
|
|||
|
||||
DTYPE_TO_INT = {
|
||||
# TODO(DNS): Fill in from AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
|
||||
torch.uint8:
|
||||
0,
|
||||
torch.int8:
|
||||
1,
|
||||
torch.int16:
|
||||
2,
|
||||
torch.int32:
|
||||
3,
|
||||
torch.int64:
|
||||
4,
|
||||
torch.float16:
|
||||
5,
|
||||
torch.float32:
|
||||
6,
|
||||
torch.float64:
|
||||
7,
|
||||
torch.uint8: 0,
|
||||
torch.int8: 1,
|
||||
torch.int16: 2,
|
||||
torch.int32: 3,
|
||||
torch.int64: 4,
|
||||
torch.float16: 5,
|
||||
torch.float32: 6,
|
||||
torch.float64: 7,
|
||||
# torch.complex_half 8
|
||||
torch.complex64:
|
||||
9,
|
||||
torch.complex128:
|
||||
10,
|
||||
torch.bool:
|
||||
11,
|
||||
torch.qint8:
|
||||
12,
|
||||
torch.quint8:
|
||||
13,
|
||||
torch.complex64: 9,
|
||||
torch.complex128: 10,
|
||||
torch.bool: 11,
|
||||
torch.qint8: 12,
|
||||
torch.quint8: 13,
|
||||
# torch.qint32 14
|
||||
torch.bfloat16:
|
||||
15,
|
||||
torch.bfloat16: 15,
|
||||
}
|
||||
|
||||
MEMORY_FORMAT_TO_INT = {
|
||||
# https://github.com/pytorch/pytorch/c10/core/MemoryFormat.h#L28
|
||||
torch.contiguous_format:
|
||||
0,
|
||||
torch.preserve_format:
|
||||
1,
|
||||
torch.channels_last:
|
||||
2,
|
||||
torch.channels_last_3d:
|
||||
3,
|
||||
torch.contiguous_format: 0,
|
||||
torch.preserve_format: 1,
|
||||
torch.channels_last: 2,
|
||||
torch.channels_last_3d: 3,
|
||||
}
|
||||
|
||||
LAYOUT_TO_INT = {
|
||||
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_layouts.cpp
|
||||
torch.strided:
|
||||
0,
|
||||
torch.sparse_coo:
|
||||
1,
|
||||
torch.sparse_csr:
|
||||
2,
|
||||
torch.sparse_csc:
|
||||
3,
|
||||
torch.sparse_bsr:
|
||||
4,
|
||||
torch.sparse_bsc:
|
||||
5,
|
||||
torch.strided: 0,
|
||||
torch.sparse_coo: 1,
|
||||
torch.sparse_csr: 2,
|
||||
torch.sparse_csc: 3,
|
||||
torch.sparse_bsr: 4,
|
||||
torch.sparse_bsc: 5,
|
||||
}
|
||||
|
||||
|
||||
|
@ -264,7 +246,6 @@ def _mlir_location_for_node(node: torch.fx.Node) -> ir.Location:
|
|||
|
||||
|
||||
class _FXGraphImporter:
|
||||
|
||||
def __init__(self, g: torch.fx.Graph, func_name: str):
|
||||
self._g = g
|
||||
self._func_name = func_name
|
||||
|
@ -277,7 +258,8 @@ class _FXGraphImporter:
|
|||
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
|
||||
self._module = ir.Module.create(ir.Location.unknown())
|
||||
self._module.operation.attributes[
|
||||
"torch.debug_module_name"] = ir.StringAttr.get(func_name)
|
||||
"torch.debug_module_name"
|
||||
] = ir.StringAttr.get(func_name)
|
||||
function_type = _extract_function_type_from_graph(g)
|
||||
func = func_dialect.FuncOp(
|
||||
func_name,
|
||||
|
@ -285,8 +267,7 @@ class _FXGraphImporter:
|
|||
loc=ir.Location.unknown(), # TODO: Can we do better?
|
||||
ip=ir.InsertionPoint(self._module.body),
|
||||
)
|
||||
self._body_block = ir.Block.create_at_start(func.body,
|
||||
function_type.inputs)
|
||||
self._body_block = ir.Block.create_at_start(func.body, function_type.inputs)
|
||||
|
||||
def import_graph(self) -> ir.Module:
|
||||
with ir.InsertionPoint(self._body_block):
|
||||
|
@ -294,14 +275,15 @@ class _FXGraphImporter:
|
|||
for node in self._g.nodes:
|
||||
with _mlir_location_for_node(node):
|
||||
if node.op == "placeholder":
|
||||
self._env[(
|
||||
node, 0
|
||||
)] = self._body_block.arguments[num_placeholders_seen]
|
||||
self._env[(node, 0)] = self._body_block.arguments[
|
||||
num_placeholders_seen
|
||||
]
|
||||
num_placeholders_seen += 1
|
||||
if node.op == "call_function":
|
||||
if node.target is operator.getitem:
|
||||
self._env[(node, 0)] = self._env[(node.args[0],
|
||||
node.args[1])]
|
||||
self._env[(node, 0)] = self._env[
|
||||
(node.args[0], node.args[1])
|
||||
]
|
||||
else:
|
||||
self._import_op_overload_call(node)
|
||||
if node.op == "output":
|
||||
|
@ -309,9 +291,7 @@ class _FXGraphImporter:
|
|||
# a tuple of return values (without the single-element special
|
||||
# case)
|
||||
# DNS: Test or verify no literals as results.
|
||||
operands = [
|
||||
self._import_argument(arg) for arg in node.args[0]
|
||||
]
|
||||
operands = [self._import_argument(arg) for arg in node.args[0]]
|
||||
func_dialect.ReturnOp(operands)
|
||||
return self._module
|
||||
|
||||
|
@ -328,7 +308,8 @@ class _FXGraphImporter:
|
|||
|
||||
# DNS: Unregistered ops
|
||||
assert ir.Context.current.is_registered_operation(
|
||||
mlir_op_name), f"Unregistered operation: {mlir_op_name}"
|
||||
mlir_op_name
|
||||
), f"Unregistered operation: {mlir_op_name}"
|
||||
|
||||
# Construct the Operation.
|
||||
result_types = _mlir_types_for_node(node)
|
||||
|
@ -352,9 +333,9 @@ class _FXGraphImporter:
|
|||
for i, value in enumerate(operation.results):
|
||||
self._env[(node, i)] = value
|
||||
|
||||
def _import_argument(self,
|
||||
arg: torch.fx.node.Argument,
|
||||
expected_type_for_literal=None) -> ir.Value:
|
||||
def _import_argument(
|
||||
self, arg: torch.fx.node.Argument, expected_type_for_literal=None
|
||||
) -> ir.Value:
|
||||
"""Import an FX `Argument`, which is analogous to an MLIR `Value`.
|
||||
|
||||
Args:
|
||||
|
@ -371,22 +352,21 @@ class _FXGraphImporter:
|
|||
assert expected_type_for_literal is not None
|
||||
return self._import_literal(arg, expected_type_for_literal)
|
||||
|
||||
def _import_literal(self, arg: torch.fx.node.Argument,
|
||||
expected_type) -> ir.Value:
|
||||
def _import_literal(self, arg: torch.fx.node.Argument, expected_type) -> ir.Value:
|
||||
if arg is None:
|
||||
return torch_dialect.ConstantNoneOp().result
|
||||
if isinstance(expected_type, torch.OptionalType):
|
||||
return self._import_argument(arg, expected_type.getElementType())
|
||||
if isinstance(arg, bool):
|
||||
return torch_dialect.ConstantBoolOp(
|
||||
ir.IntegerAttr.get(ir.IntegerType.get_signless(1), arg)).result
|
||||
ir.IntegerAttr.get(ir.IntegerType.get_signless(1), arg)
|
||||
).result
|
||||
if isinstance(arg, int):
|
||||
return torch_dialect.ConstantIntOp(
|
||||
ir.IntegerAttr.get(ir.IntegerType.get_signless(64),
|
||||
arg)).result
|
||||
ir.IntegerAttr.get(ir.IntegerType.get_signless(64), arg)
|
||||
).result
|
||||
if isinstance(arg, float):
|
||||
return torch_dialect.ConstantFloatOp(
|
||||
ir.FloatAttr.get_f64(arg)).result
|
||||
return torch_dialect.ConstantFloatOp(ir.FloatAttr.get_f64(arg)).result
|
||||
if isinstance(arg, str):
|
||||
return torch_dialect.ConstantStrOp(ir.StringAttr.get(arg)).result
|
||||
if isinstance(arg, torch.dtype):
|
||||
|
@ -394,12 +374,10 @@ class _FXGraphImporter:
|
|||
return self._import_argument(DTYPE_TO_INT[arg], expected_type)
|
||||
if isinstance(arg, torch.device):
|
||||
# TODO(DNS): Device index? arg.index
|
||||
return torch_dialect.ConstantDeviceOp(ir.StringAttr.get(
|
||||
arg.type)).result
|
||||
return torch_dialect.ConstantDeviceOp(ir.StringAttr.get(arg.type)).result
|
||||
if isinstance(arg, torch.memory_format):
|
||||
assert isinstance(expected_type, torch.IntType)
|
||||
return self._import_argument(MEMORY_FORMAT_TO_INT[arg],
|
||||
expected_type)
|
||||
return self._import_argument(MEMORY_FORMAT_TO_INT[arg], expected_type)
|
||||
if isinstance(arg, torch.layout):
|
||||
assert isinstance(expected_type, torch.IntType)
|
||||
return self._import_argument(LAYOUT_TO_INT[arg], expected_type)
|
||||
|
@ -409,14 +387,14 @@ class _FXGraphImporter:
|
|||
if isinstance(element_type, torch.TensorType):
|
||||
assert all(
|
||||
torch.fx.node.map_aggregate(
|
||||
arg, lambda a: _is_valid_meta_val(a.meta.get("val"))))
|
||||
arg, lambda a: _is_valid_meta_val(a.meta.get("val"))
|
||||
)
|
||||
)
|
||||
els = [self._env[e, 0] for e in arg]
|
||||
|
||||
else:
|
||||
element_type = _torch_type_to_mlir_type(element_type)
|
||||
els = [
|
||||
self._import_argument(e, element_type) for e in arg
|
||||
]
|
||||
els = [self._import_argument(e, element_type) for e in arg]
|
||||
|
||||
# import pydevd_pycharm
|
||||
# pydevd_pycharm.settrace('localhost', port=8888, stdoutToServer=True, stderrToServer=True)
|
||||
|
|
|
@ -3,6 +3,5 @@ import torch
|
|||
|
||||
# Register _torch_mlir_custom_op_example.identity as a side-effect of importing.
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
lib = os.path.join(*[current_dir, 'libtorch_mlir_custom_op_example.so'])
|
||||
lib = os.path.join(*[current_dir, "libtorch_mlir_custom_op_example.so"])
|
||||
torch.ops.load_library(lib)
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
from packaging import version
|
||||
import torch
|
||||
|
||||
|
||||
def torch_version_for_comparison():
|
||||
# Ignore +cpu, +cu117m, etc. in comparisons
|
||||
return version.parse(torch.__version__.split("+", 1)[0])
|
||||
|
|
|
@ -4,20 +4,20 @@
|
|||
import sys
|
||||
import os
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
path = sys.argv[1] # dummy script path
|
||||
file_name = sys.argv[2] # dummy script
|
||||
|
||||
contents = '''
|
||||
contents = """
|
||||
# This file was automatically generated due to LTC being disabled in build.
|
||||
|
||||
class LazyTensorCoreTestConfig:
|
||||
def __init__(self):
|
||||
assert False, "LTC is not enabled. Check the value of `TORCH_MLIR_ENABLE_LTC`"
|
||||
'''
|
||||
"""
|
||||
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
with open(os.path.join(path, file_name + '.py'), 'w') as file:
|
||||
with open(os.path.join(path, file_name + ".py"), "w") as file:
|
||||
file.write(contents)
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch._dynamo.backends.common import aot_autograd
|
|||
import functorch
|
||||
|
||||
import warnings
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/89064
|
||||
warnings.filterwarnings("ignore", module="torch.jit._check")
|
||||
|
||||
|
@ -91,8 +92,7 @@ def _adjust_calling_convention(gm: torch.fx.GraphModule) -> bool:
|
|||
did_convert_list_to_tuple = False
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert len(node.args) == 1, \
|
||||
"Output node must have a single argument"
|
||||
assert len(node.args) == 1, "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
|
@ -106,7 +106,7 @@ def _adjust_calling_convention(gm: torch.fx.GraphModule) -> bool:
|
|||
did_convert_list_to_tuple = True
|
||||
break
|
||||
else:
|
||||
node.args= (tuple(node_arg),)
|
||||
node.args = (tuple(node_arg),)
|
||||
did_convert_list_to_tuple = True
|
||||
break
|
||||
|
||||
|
@ -129,10 +129,12 @@ def make_simple_dynamo_backend(user_backend):
Returns:
A function with the signature used by TorchDynamo backends.
"""
def wrapper_backend(gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor]):
did_unwrap_single_element, did_convert_list_to_tuple = \
_adjust_calling_convention(gm)

def wrapper_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
(
did_unwrap_single_element,
did_convert_list_to_tuple,
) = _adjust_calling_convention(gm)
strip_overloads(gm)
user_callable = user_backend(gm, example_inputs)

@ -147,6 +149,9 @@ def make_simple_dynamo_backend(user_backend):
if did_convert_list_to_tuple:
result = list(result)
return result

return dynamo_callable
return aot_autograd(fw_compiler=wrapper_backend,
decompositions=_get_decomposition_table)

return aot_autograd(
fw_compiler=wrapper_backend, decompositions=_get_decomposition_table
)

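# Minimal usage sketch for the wrapper above (illustrative only, not part of this
# diff): `trivial_backend` is a hypothetical user backend that simply returns the
# FX graph's own callable, so the wrapped backend runs the captured graph eagerly.
# Exact integration details may vary across torch versions.
import torch
import torch.fx
import torch._dynamo as dynamo

@make_simple_dynamo_backend
def trivial_backend(gm: torch.fx.GraphModule, example_inputs):
    print(gm.graph)  # inspect the captured FX graph
    return gm.forward  # fall back to eager execution of the graph

@dynamo.optimize(trivial_backend)
def f(x):
    return torch.tanh(x) + 1.0

f(torch.randn(3))
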
@ -15,24 +15,31 @@ from torch_mlir.passmanager import PassManager
|
|||
|
||||
from .registry import Registry
|
||||
|
||||
|
||||
def all_integer_dtypes() -> List[int]:
|
||||
return [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
|
||||
|
||||
|
||||
def is_integer_dtype(dtype: int) -> bool:
|
||||
return dtype in all_integer_dtypes()
|
||||
|
||||
|
||||
def all_complex_dtypes() -> List[int]:
|
||||
return [torch.complex64, torch.complex128]
|
||||
|
||||
|
||||
def is_complex_dtype(dtype: int) -> bool:
|
||||
return dtype in all_complex_dtypes()
|
||||
|
||||
|
||||
def all_float_dtypes() -> List[int]:
|
||||
return [torch.float16, torch.bfloat16, torch.float32, torch.float64]
|
||||
|
||||
|
||||
def is_float_dtype(dtype: int) -> bool:
|
||||
return dtype in all_float_dtypes()
|
||||
|
||||
|
||||
def get_priority_of_dtype(dtype: int) -> int:
|
||||
# If a loop is used to iterate over a list of sorted dtypes, TorchScript
|
||||
# produces a loop with INT64_MAX max trip count, which causes problems
|
||||
|
@ -64,6 +71,7 @@ def get_priority_of_dtype(dtype: int) -> int:
|
|||
return 11
|
||||
assert False, "Cannot determine priority of dtype"
|
||||
|
||||
|
||||
def get_dtype_of_scalar(scalar: Union[int, float, complex]) -> int:
|
||||
# This is hacky. `NumToTensor` is the only PyTorch op for scalars
|
||||
# that when `jit.script`ed converts a float scalar to a tensor
|
||||
|
@ -83,6 +91,7 @@ def get_dtype_of_scalar(scalar: Union[int, float, complex]) -> int:
|
|||
# op.
|
||||
return torch.ops.prim.NumToTensor(scalar).dtype
|
||||
|
||||
|
||||
# When we import into torch-mlir, only the calls to
|
||||
# `__torch_mlir_internal_promote_dtypes` are used to generate the
|
||||
# `torch.promote_dtypes` ops. Therefore, to avoid generating extra
|
||||
|
@ -97,23 +106,29 @@ def _get_scalar_with_dtype(dtype: torch.dtype) -> Union[int, float]:
|
|||
else:
|
||||
raise ValueError(f"Unhandled dtype: {dtype}")
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _promote_scalar_tensor(scalar_dtype: torch.dtype, tensor_rank: int,
|
||||
tensor_dtype: torch.dtype) -> torch.dtype:
|
||||
def _promote_scalar_tensor(
|
||||
scalar_dtype: torch.dtype, tensor_rank: int, tensor_dtype: torch.dtype
|
||||
) -> torch.dtype:
|
||||
scalar = _get_scalar_with_dtype(scalar_dtype)
|
||||
tensor = torch.rand([1] * tensor_rank).to(tensor_dtype)
|
||||
return torch.result_type(scalar, tensor)
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _promote_tensor_tensor(lhs_rank: int, lhs_dtype: torch.dtype,
|
||||
rhs_rank: int, rhs_dtype: torch.dtype) -> torch.dtype:
|
||||
def _promote_tensor_tensor(
|
||||
lhs_rank: int, lhs_dtype: torch.dtype, rhs_rank: int, rhs_dtype: torch.dtype
|
||||
) -> torch.dtype:
|
||||
lhs_tensor = torch.rand([1] * lhs_rank).to(lhs_dtype)
|
||||
rhs_tensor = torch.rand([1] * rhs_rank).to(rhs_dtype)
|
||||
return torch.result_type(lhs_tensor, rhs_tensor)
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _promote_scalar_scalar(lhs_dtype: torch.dtype,
|
||||
rhs_dtype: torch.dtype) -> torch.dtype:
|
||||
def _promote_scalar_scalar(
|
||||
lhs_dtype: torch.dtype, rhs_dtype: torch.dtype
|
||||
) -> torch.dtype:
|
||||
# When `torch.result_type` is used on two scalars, the result
|
||||
# dtype is the dtype one would expect for an op with signature
|
||||
# (Scalar, Scalar) -> (Tensor). However, once a module gets
|
||||
|
@ -122,15 +137,17 @@ def _promote_scalar_scalar(lhs_dtype: torch.dtype,
|
|||
# dtype, we use the tensor-tensor promotion rules.
|
||||
return _promote_tensor_tensor(0, lhs_dtype, 0, rhs_dtype)
|
||||
|
||||
def promote_dtypes(ranks: List[Optional[int]],
|
||||
dtypes: List[torch.dtype]) -> torch.dtype:
|
||||
"""Apply PyTorch dtype promotion rules and return the result type.
|
||||
"""
|
||||
|
||||
def promote_dtypes(
|
||||
ranks: List[Optional[int]], dtypes: List[torch.dtype]
|
||||
) -> torch.dtype:
|
||||
"""Apply PyTorch dtype promotion rules and return the result type."""
|
||||
return __torch_mlir_internal_promote_dtypes(ranks, dtypes)
|
||||
|
||||
def __torch_mlir_internal_promote_dtypes(ranks: List[Optional[int]],
|
||||
dtypes: List[torch.dtype]
|
||||
) -> torch.dtype:
|
||||
|
||||
def __torch_mlir_internal_promote_dtypes(
|
||||
ranks: List[Optional[int]], dtypes: List[torch.dtype]
|
||||
) -> torch.dtype:
|
||||
"""Apply PyTorch dtype promotion rules and return the result type.
|
||||
|
||||
This function serves two purposes:
|
||||
|
@ -145,18 +162,18 @@ def __torch_mlir_internal_promote_dtypes(ranks: List[Optional[int]],
if lhs_optional_rank is None and rhs_optional_rank is None:
lhs_dtype = _promote_scalar_scalar(lhs_dtype, rhs_dtype)
elif lhs_optional_rank is None and rhs_optional_rank is not None:
lhs_dtype = _promote_scalar_tensor(
lhs_dtype, rhs_optional_rank, rhs_dtype)
lhs_dtype = _promote_scalar_tensor(lhs_dtype, rhs_optional_rank, rhs_dtype)
lhs_optional_rank = rhs_optional_rank
elif lhs_optional_rank is not None and rhs_optional_rank is None:
lhs_dtype = _promote_scalar_tensor(
rhs_dtype, lhs_optional_rank, lhs_dtype)
lhs_dtype = _promote_scalar_tensor(rhs_dtype, lhs_optional_rank, lhs_dtype)
elif lhs_optional_rank is not None and rhs_optional_rank is not None:
lhs_dtype = _promote_tensor_tensor(
lhs_optional_rank, lhs_dtype, rhs_optional_rank, rhs_dtype)
lhs_optional_rank, lhs_dtype, rhs_optional_rank, rhs_dtype
)
lhs_optional_rank = max(lhs_optional_rank, rhs_optional_rank)
return lhs_dtype

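# For intuition about the promotion rules the helpers above delegate to, here is a
# small standalone sketch using torch.result_type (illustrative only; the expected
# values in the comments are what stock PyTorch type promotion should produce).
import torch

# Scalar/tensor mix: an integer scalar does not upcast a float16 tensor.
print(torch.result_type(3, torch.ones(2, 2, dtype=torch.float16)))  # expect torch.float16
# Tensor/tensor mix: floating point wins over integral.
print(torch.result_type(torch.ones(2, dtype=torch.int64), torch.ones(2, dtype=torch.float32)))  # expect torch.float32
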
def not_present_in_registry(f):
|
||||
"""Decorator for abstract interpretation functions not present in the registry.
|
||||
|
||||
|
@ -175,6 +192,7 @@ def not_present_in_registry(f):
|
|||
f._not_present_in_registry = None
|
||||
return f
|
||||
|
||||
|
||||
def _verify_signature_matches_registry(f, registry: Registry):
|
||||
source = inspect.getsource(f)
|
||||
signature = None
|
||||
|
@ -183,7 +201,9 @@ def _verify_signature_matches_registry(f, registry: Registry):
|
|||
signature = line
|
||||
break
|
||||
assert signature is not None, f"Could not find signature for {f.__name__}"
|
||||
assert "〡" in signature, f"Malformed signature {signature}. Signature missing the character `〡`"
|
||||
assert (
|
||||
"〡" in signature
|
||||
), f"Malformed signature {signature}. Signature missing the character `〡`"
|
||||
function_name, function_kind = f.__name__.split("〡")
|
||||
atoms = function_name.split("〇")
|
||||
if len(atoms) == 2:
|
||||
|
@ -203,7 +223,10 @@ def _verify_signature_matches_registry(f, registry: Registry):
|
|||
else:
|
||||
raise ValueError(f"Invalid Op signature function kind: '{function_kind}'")
|
||||
if signature != expected_signature:
|
||||
raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}")
|
||||
raise ValueError(
|
||||
f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}"
|
||||
)
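# Illustration of the naming convention being verified above (example values, not
# taken from the diff): a function named "aten〇tanh〡shape" is the shape function
# for the op "aten.tanh"; "〡" separates the op name from the function kind and
# "〇" stands in for ".".
name = "aten〇tanh〡shape"
function_name, function_kind = name.split("〡")  # ("aten〇tanh", "shape")
atoms = function_name.split("〇")  # ["aten", "tanh"]
print(function_kind, ".".join(atoms))  # expect: shape aten.tanh
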
|
||||
|
||||
|
||||
def generate_library(functions: Dict[str, Any]) -> str:
|
||||
"""Convert all op functions in `functions` into MLIR."""
|
||||
|
@ -245,11 +268,13 @@ def generate_library(functions: Dict[str, Any]) -> str:
|
|||
# the format: `__torch__.{namespace_1}.{namespace_2}...{op_name}`
|
||||
# The extra namespaces are not part of the abstract interpretation
|
||||
# function name, so here we simply drop the extra namespaces.
|
||||
namespace = fr"(?:{name}\.)"
|
||||
namespace = rf"(?:{name}\.)"
|
||||
|
||||
asm = re.sub(fr'@"__torch__\.{namespace}*({name}){circle}({name}){line}({name})"',
|
||||
fr'@"__torch_mlir_\3_fn.\1{circle}\2"',
|
||||
asm)
|
||||
asm = re.sub(
|
||||
rf'@"__torch__\.{namespace}*({name}){circle}({name}){line}({name})"',
|
||||
rf'@"__torch_mlir_\3_fn.\1{circle}\2"',
|
||||
asm,
|
||||
)
|
||||
|
||||
# Put the `〇` back to a regular `.`.
|
||||
asm = asm.replace(codecs.decode(circle, "unicode_escape"), ".")
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
@ -15,13 +14,17 @@ import difflib
|
|||
from .utils import TextEmitter
|
||||
|
||||
# Note that this utility exists only in the c-extension.
|
||||
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error
|
||||
from torch_mlir._mlir_libs._jit_ir_importer import (
|
||||
get_registered_ops,
|
||||
) # pytype: disable=import-error
|
||||
|
||||
|
||||
def _rename_python_keyword_parameter_name(parameter_name: str) -> str:
|
||||
if parameter_name == "from":
|
||||
parameter_name = "from_" # Avoid using a Python keyword.
|
||||
return parameter_name
|
||||
|
||||
|
||||
def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
|
||||
default = ""
|
||||
if "default_debug" in arg:
|
||||
|
@ -40,8 +43,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
|
|||
if default_list == "[]":
|
||||
default_debug = "()"
|
||||
else:
|
||||
default_debug = default_list.replace(
|
||||
"[", "(").replace("]", ",)")
|
||||
default_debug = default_list.replace("[", "(").replace("]", ",)")
|
||||
elif arg["pytype"] == "str":
|
||||
default_debug = repr(arg["default_debug"]).replace("'", '"')
|
||||
else:
|
||||
|
@ -49,6 +51,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
|
|||
default = f" = {default_debug}"
|
||||
return default
|
||||
|
||||
|
||||
def _pytype_to_fn_pytype_common(pytype: str) -> str:
|
||||
if "number" in pytype:
|
||||
return pytype.replace("number", "Union[int, float, complex]")
|
||||
|
@ -65,6 +68,7 @@ def _pytype_to_fn_pytype_common(pytype: str) -> str:
|
|||
return "Any"
|
||||
return pytype
|
||||
|
||||
|
||||
def _pytype_to_shape_fn_pytype(pytype: str) -> str:
|
||||
"""Convert a JitOperator pytype to the type relevant in shape functions.
|
||||
|
||||
|
@ -86,6 +90,7 @@ def _pytype_to_shape_fn_pytype(pytype: str) -> str:
|
|||
return pytype.replace("Tensor", "List[int]")
|
||||
return _pytype_to_fn_pytype_common(pytype)
|
||||
|
||||
|
||||
def _pytype_to_dtype_fn_pytype(pytype: str) -> str:
|
||||
"""Convert a JitOperator pytype to the type relevant in dtype functions.
|
||||
|
||||
|
@ -97,13 +102,15 @@ def _pytype_to_dtype_fn_pytype(pytype: str) -> str:
|
|||
return pytype.replace("Tensor", "Tuple[int, int]")
|
||||
return _pytype_to_fn_pytype_common(pytype)
|
||||
|
||||
|
||||
def _pytype_to_decomposition_fn_pytype(pytype: str) -> str:
|
||||
"""Convert a JitOperator pytype to the type relevant in decomposition functions.
|
||||
"""
|
||||
"""Convert a JitOperator pytype to the type relevant in decomposition functions."""
|
||||
return _pytype_to_fn_pytype_common(pytype)
|
||||
|
||||
|
||||
class JitOperator:
|
||||
"""Information about a single registered `torch::jit::Operator`"""
|
||||
|
||||
def __init__(self, op_info: "OP_INFO_DICT"):
|
||||
"""Create a JitOperator from the raw OP_INFO_DICT extracted from
|
||||
the PyTorch JIT operator registry.
|
||||
|
@ -170,6 +177,7 @@ class JitOperator:
|
|||
are useful in the repr for cross referencing, and it's useful to have
|
||||
them in a single point of truth.
|
||||
"""
|
||||
|
||||
def uppercase_first_letter(s):
|
||||
if not s:
|
||||
return s
|
||||
|
@ -184,15 +192,19 @@ class JitOperator:
|
|||
for op_name_atom in op_name_atoms:
|
||||
for s in op_name_atom.split("_"):
|
||||
op_class_name_atoms.append(s if s else "_")
|
||||
cpp_class_name = "".join(
|
||||
uppercase_first_letter(s) for s in op_class_name_atoms) + "Op"
|
||||
cpp_class_name = (
|
||||
"".join(uppercase_first_letter(s) for s in op_class_name_atoms) + "Op"
|
||||
)
|
||||
# Disallow leading underscores in C++ to avoid illegal names.
|
||||
cpp_class_name = cpp_class_name.lstrip("_")
|
||||
return op_name, cpp_class_name
|
||||
|
||||
def _get_function_signature(self, function_kind: str,
|
||||
def _get_function_signature(
|
||||
self,
|
||||
function_kind: str,
|
||||
parameter_decl_builder: Callable[["SIG_ATTR_TYPE"], str],
|
||||
ret_decl_builder: Callable[["SIG_ATTR_TYPE"], str]) -> str:
|
||||
ret_decl_builder: Callable[["SIG_ATTR_TYPE"], str],
|
||||
) -> str:
|
||||
mlir_op_name, _ = self.get_mlir_names()
|
||||
# Replace `.` with a valid Python identifier character.
|
||||
# `〇` vaguely looks like `.`.
|
||||
|
@ -219,6 +231,7 @@ class JitOperator:
|
|||
ops have extra default arguments and stuff that are tedious to write out
|
||||
right.
|
||||
"""
|
||||
|
||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
pytype = _pytype_to_shape_fn_pytype(arg["pytype"])
|
||||
default = _get_default_value(arg)
|
||||
|
@ -229,7 +242,8 @@ class JitOperator:
|
|||
return _pytype_to_shape_fn_pytype(arg["pytype"])
|
||||
|
||||
return self._get_function_signature(
|
||||
"shape", parameter_decl_builder, ret_decl_builder)
|
||||
"shape", parameter_decl_builder, ret_decl_builder
|
||||
)
|
||||
|
||||
def get_dtype_function_signature(self):
|
||||
"""Gets the Python function signature for this op's dtype function.
|
||||
|
@ -239,6 +253,7 @@ class JitOperator:
|
|||
ops have extra default arguments and stuff that are tedious to write out
|
||||
right.
|
||||
"""
|
||||
|
||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
pytype = _pytype_to_dtype_fn_pytype(arg["pytype"])
|
||||
default = _get_default_value(arg)
|
||||
|
@ -257,7 +272,8 @@ class JitOperator:
|
|||
return _pytype_to_dtype_fn_pytype(arg["pytype"])
|
||||
|
||||
return self._get_function_signature(
|
||||
"dtype", parameter_decl_builder, ret_decl_builder)
|
||||
"dtype", parameter_decl_builder, ret_decl_builder
|
||||
)
|
||||
|
||||
def get_decomposition_function_signature(self):
|
||||
"""Gets the Python function signature for this op's decomposition function.
|
||||
|
@ -267,6 +283,7 @@ class JitOperator:
|
|||
ops have extra default arguments and stuff that are tedious to write out
|
||||
right.
|
||||
"""
|
||||
|
||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
pytype = _pytype_to_decomposition_fn_pytype(arg["pytype"])
|
||||
default = _get_default_value(arg)
|
||||
|
@ -277,7 +294,8 @@ class JitOperator:
|
|||
return _pytype_to_decomposition_fn_pytype(arg["pytype"])
|
||||
|
||||
return self._get_function_signature(
|
||||
"decomposition", parameter_decl_builder, ret_decl_builder)
|
||||
"decomposition", parameter_decl_builder, ret_decl_builder
|
||||
)
|
||||
|
||||
def get_has_value_semantics_function_signature(self):
|
||||
"""Gets the Python function signature for this op's has_value_semantics function.
|
||||
|
@ -287,6 +305,7 @@ class JitOperator:
|
|||
ops have extra default arguments and stuff that are tedious to write out
|
||||
right.
|
||||
"""
|
||||
|
||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
return ""
|
||||
|
||||
|
@ -294,7 +313,8 @@ class JitOperator:
|
|||
return ""
|
||||
|
||||
return self._get_function_signature(
|
||||
"has_value_semantics", parameter_decl_builder, ret_decl_builder)
|
||||
"has_value_semantics", parameter_decl_builder, ret_decl_builder
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
f = io.StringIO()
|
||||
|
@ -318,7 +338,9 @@ class JitOperator:
|
|||
p(f"is_mutable = {self.is_mutable}")
|
||||
if any(ret["type"] == "Tensor" for ret in self.returns):
|
||||
p(f"shape_function_signature = {self.get_shape_function_signature()}")
|
||||
p(f"decomposition_function_signature = {self.get_decomposition_function_signature()}")
|
||||
p(
|
||||
f"decomposition_function_signature = {self.get_decomposition_function_signature()}"
|
||||
)
|
||||
if any(ret["type"] in ["Tensor", "Scalar"] for ret in self.returns):
|
||||
p(f"dtype_function_signature = {self.get_dtype_function_signature()}")
|
||||
|
||||
|
@ -354,7 +376,9 @@ class JitOperator:
|
|||
# Note that this is different from MLIR's NoSideEffect which is much
|
||||
# stronger (for example, it cannot be applied to ops that might emit errors
|
||||
# when operand shapes mismatch).
|
||||
if any("alias_info" in x for x in itertools.chain(self.arguments, self.returns)):
|
||||
if any(
|
||||
"alias_info" in x for x in itertools.chain(self.arguments, self.returns)
|
||||
):
|
||||
return False
|
||||
# It seems the FunctionSchema of "prim::unchecked_cast : (t) -> (t)" has
|
||||
# incorrect alias information. The result can alias with other tensors
|
||||
|
@ -363,8 +387,10 @@ class JitOperator:
|
|||
return False
|
||||
# The `is` operator compares object identity, so it does not have
|
||||
# value semantics.
|
||||
if self.unique_key in ("aten::__is__ : (t1, t2) -> (bool)",
|
||||
"aten::__isnot__ : (t1, t2) -> (bool)"):
|
||||
if self.unique_key in (
|
||||
"aten::__is__ : (t1, t2) -> (bool)",
|
||||
"aten::__isnot__ : (t1, t2) -> (bool)",
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
@ -390,6 +416,7 @@ class JitOperator:
|
|||
|
||||
class Registry:
|
||||
"""An indexed collection of JitOperators"""
|
||||
|
||||
def __init__(self, operators: List[JitOperator]):
|
||||
self.by_unique_key = {}
|
||||
self.by_triple = {}
|
||||
|
@ -434,4 +461,3 @@ SIGLIST_TYPE = List[SIG_ATTR_TYPE]
|
|||
# - Tuple[str] (e.g. {'name': ('aten::size', 'int')} )
|
||||
# - SIGLIST_TYPE (e.g. {'arguments': [...], 'returns': [...]} )
|
||||
OP_INFO_DICT = Dict[str, Union[bool, Tuple[str], SIGLIST_TYPE]]
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@ from torch import Tensor
|
|||
# The typical iteration flow is to add invocations to the list and then re-run
|
||||
# `build_tools/update_abstract_interp_lib.sh` to re-run the tests.
|
||||
|
||||
|
||||
class TensorOfShape:
|
||||
"""Symbolic placeholder for a tensor argument to an operation.
|
||||
|
||||
|
@ -60,30 +61,40 @@ class TensorOfShape:
|
|||
This class also tracks a dtype of the tensor, since some ops require a
|
||||
specific dtype.
|
||||
"""
|
||||
def __init__(self, *shape: int, dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*shape: int,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
self.shape = list(shape)
|
||||
self.dtype = dtype
|
||||
self.device = "meta" if device is None else device
|
||||
|
||||
def __repr__(self):
|
||||
args_str = ", ".join(repr(x) for x in self.shape)
|
||||
return f"TensorOfShape({args_str}, dtype={self.dtype}, device={self.device})"
|
||||
|
||||
|
||||
def LongTensorOfShape(*args, **kwargs):
|
||||
"""Helper for indicating a TensorOfShape with integer type."""
|
||||
return TensorOfShape(*args, **kwargs, dtype=torch.long)
|
||||
|
||||
|
||||
def NonZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None):
|
||||
"""Helper for indicating a non-zero dim tensor with custom type."""
|
||||
return TensorOfShape(1, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def ZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None):
|
||||
"""Helper for indicating a zero dim tensor with custom type."""
|
||||
return TensorOfShape(dtype=dtype, device=device)
|
||||
|
||||
|
||||
def _recursively_transform_tensor_args(
|
||||
o: Any,
|
||||
tensor_transformer: Callable[[TensorOfShape], Any]) -> Any:
|
||||
o: Any, tensor_transformer: Callable[[TensorOfShape], Any]
|
||||
) -> Any:
|
||||
"""Replace `TensorOfShape` with the result of `tensor_transformer`"""
|
||||
if o is None or isinstance(o, (float, int, str)):
|
||||
return o
|
||||
|
@ -92,9 +103,12 @@ def _recursively_transform_tensor_args(
|
|||
if isinstance(o, list):
|
||||
return [_recursively_transform_tensor_args(x, tensor_transformer) for x in o]
|
||||
if isinstance(o, tuple):
|
||||
return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o)
|
||||
return tuple(
|
||||
_recursively_transform_tensor_args(x, tensor_transformer) for x in o
|
||||
)
|
||||
raise Exception(f"Unhandled type {type(o)}")
|
||||
|
||||
|
||||
class Invocation:
"""Representation of a single op invocation (i.e. list of args to the op).

@ -111,6 +125,7 @@ class Invocation:
exception for greater precision when interpreting errors raised during
testing.
"""

def __init__(self, *args: Any, **kwargs: Any):
self.args = list(args)
# We assume kwargs don't contain tensors, so they don't need any

@ -134,14 +149,12 @@ class Invocation:
# are ok since they make it a bit easier to write some shape
# functions.
tensor_transformer = lambda o: list(o.shape)
return _recursively_transform_tensor_args(
self.args, tensor_transformer)
return _recursively_transform_tensor_args(self.args, tensor_transformer)

def to_dtype_function_args(self):
"""Gets positional arguments appropriate for a dtype function."""
tensor_transformer = lambda o: (len(o.shape), o.dtype)
return _recursively_transform_tensor_args(
self.args, tensor_transformer)
return _recursively_transform_tensor_args(self.args, tensor_transformer)

def to_real_op_args(self):
"""Gets positional arguments appropriate for the real op."""

@ -155,6 +168,7 @@ class Invocation:
kwargs_str = ", " + ", ".join(f"{k}={v}" for k, v in self.kwargs.items())
return f"Invocation({args_str}{kwargs_str})"

class ErrorInvocation(Invocation):
|
||||
"""An Invocation that raises an exception.
|
||||
|
||||
|
@ -165,9 +179,11 @@ class ErrorInvocation(Invocation):
|
|||
spuriously make the two appear to "agree" that an exception needs to be
|
||||
raised).
|
||||
"""
|
||||
|
||||
def is_expected_to_raise_exception(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _normalize_multiple_results_to_list(t: Any):
|
||||
"""Returns a flat list of results.
|
||||
|
||||
|
@ -182,9 +198,13 @@ def _normalize_multiple_results_to_list(t: Any):
|
|||
return [t]
|
||||
raise ValueError(f"Unexpected type {type(t)}")
|
||||
|
||||
|
||||
def _report(f, invocation: Invocation, error_message: str):
|
||||
fn_type = f.__name__.split("〡")[-1]
|
||||
raise ValueError(f"For {fn_type} function {f.__name__!r} with invocation {invocation}: {error_message}")
|
||||
raise ValueError(
|
||||
f"For {fn_type} function {f.__name__!r} with invocation {invocation}: {error_message}"
|
||||
)
|
||||
|
||||
|
||||
def _get_fn_and_golden_results(f, invocation: List[Invocation]):
|
||||
"""Run the invocation on the library function and torch op.
|
||||
|
@ -201,36 +221,64 @@ def _get_fn_and_golden_results(f, invocation: List[Invocation]):
|
|||
op = getattr(getattr(getattr(torch.ops, ns), unqual), overload)
|
||||
fn_error, op_error, fn_results, golden_results = None, None, None, None
|
||||
try:
|
||||
fn_results = _normalize_multiple_results_to_list(f(
|
||||
fn_results = _normalize_multiple_results_to_list(
|
||||
f(
|
||||
*(getattr(invocation, f"to_{fn_type}_function_args")()),
|
||||
**invocation.kwargs))
|
||||
**invocation.kwargs,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
fn_error = f"{e}"
|
||||
try:
|
||||
golden_results = _normalize_multiple_results_to_list(op(
|
||||
*invocation.to_real_op_args(),
|
||||
**invocation.kwargs))
|
||||
golden_results = _normalize_multiple_results_to_list(
|
||||
op(*invocation.to_real_op_args(), **invocation.kwargs)
|
||||
)
|
||||
except Exception as e:
|
||||
op_error = f"{e}"
|
||||
|
||||
# Check for error behavior.
|
||||
if invocation.is_expected_to_raise_exception():
|
||||
if fn_error is None and op_error is None:
|
||||
_report(f, invocation, f"Expected to raise an exception, but neither {fn_type} function or op raised an exception")
|
||||
_report(
|
||||
f,
|
||||
invocation,
|
||||
f"Expected to raise an exception, but neither {fn_type} function or op raised an exception",
|
||||
)
|
||||
if fn_error is None:
|
||||
_report(f, invocation, f"Op raised error {op_error!r}, but shape/dtype function did not.")
|
||||
_report(
|
||||
f,
|
||||
invocation,
|
||||
f"Op raised error {op_error!r}, but shape/dtype function did not.",
|
||||
)
|
||||
if op_error is None:
|
||||
_report(f, invocation, f"{fn_type} function raised error {fn_error!r}, but op did not.")
|
||||
_report(
|
||||
f,
|
||||
invocation,
|
||||
f"{fn_type} function raised error {fn_error!r}, but op did not.",
|
||||
)
|
||||
else:
|
||||
if fn_error is not None and op_error is not None:
|
||||
_report(f, invocation, f"Both {fn_type} function and op raised errors, but were not expected to. {fn_type} function raised error {fn_error!r} and op raised error {op_error!r}.")
|
||||
_report(
|
||||
f,
|
||||
invocation,
|
||||
f"Both {fn_type} function and op raised errors, but were not expected to. {fn_type} function raised error {fn_error!r} and op raised error {op_error!r}.",
|
||||
)
|
||||
if fn_error is not None:
|
||||
_report(f, invocation, f"{fn_type} function raised error {fn_error!r} but op did not raise any error.")
|
||||
_report(
|
||||
f,
|
||||
invocation,
|
||||
f"{fn_type} function raised error {fn_error!r} but op did not raise any error.",
|
||||
)
|
||||
if op_error is not None:
|
||||
_report(f, invocation, f"Op raised error {op_error!r} but {fn_type} function did not raise any error.")
|
||||
_report(
|
||||
f,
|
||||
invocation,
|
||||
f"Op raised error {op_error!r} but {fn_type} function did not raise any error.",
|
||||
)
|
||||
|
||||
return fn_results, golden_results
|
||||
|
||||
|
||||
def check_shape_function(invocations: List[Invocation]):
|
||||
"""Decorator that automatically tests a shape function.
|
||||
|
||||
|
@ -238,6 +286,7 @@ def check_shape_function(invocations: List[Invocation]):
|
|||
`〇` instead of `.`, is tested against the corresponding op in
|
||||
`torch.ops.*` function using the given invocations.
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
for invocation in invocations:
|
||||
result_shapes, golden_results = _get_fn_and_golden_results(f, invocation)
|
||||
|
@ -245,18 +294,34 @@ def check_shape_function(invocations: List[Invocation]):
|
|||
continue
|
||||
# Check for matching results.
|
||||
if len(result_shapes) != len(golden_results):
|
||||
_report(f, invocation, f"Expected {len(golden_results)} result shapes, got {len(result_shapes)}")
|
||||
_report(
|
||||
f,
|
||||
invocation,
|
||||
f"Expected {len(golden_results)} result shapes, got {len(result_shapes)}",
|
||||
)
|
||||
for result_shape, golden_result in zip(result_shapes, golden_results):
|
||||
result_rank = len(result_shape)
|
||||
golden_rank = len(golden_result.shape)
|
||||
if result_rank != golden_rank:
|
||||
_report(f, invocation, f"Expected result rank {golden_rank}, got {result_rank}")
|
||||
for dimension_size, golden_dimension_size in zip(result_shape, golden_result.shape):
|
||||
_report(
|
||||
f,
|
||||
invocation,
|
||||
f"Expected result rank {golden_rank}, got {result_rank}",
|
||||
)
|
||||
for dimension_size, golden_dimension_size in zip(
|
||||
result_shape, golden_result.shape
|
||||
):
|
||||
if dimension_size != golden_dimension_size:
|
||||
_report(f, invocation, f"Expected result shape {golden_result.shape}, got {result_shape}")
|
||||
_report(
|
||||
f,
|
||||
invocation,
|
||||
f"Expected result shape {golden_result.shape}, got {result_shape}",
|
||||
)
|
||||
return f
|
||||
|
||||
return decorator
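
As a concrete sketch of how the decorator above is used (the op choice and the simplified body are illustrative assumptions, not the library's actual entry for any op; the helpers defined earlier in this file are assumed to be in scope):

from typing import List

# aten.tanh is shape-preserving, so a shape function for it can simply return
# the input shape. The decorator runs every Invocation against both this
# function and torch.ops.aten.tanh at definition time and reports mismatches.
@check_shape_function([
    Invocation(TensorOfShape(2, 3, 4)),       # basic case: shapes must agree
    ErrorInvocation(TensorOfShape(2, 3), 1),  # extra argument: both sides must raise
])
def aten〇tanh〡shape(self: List[int]) -> List[int]:
    return self
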
@torch.jit.script
|
||||
def _convert_dtype_to_int(dtype: torch.dtype) -> int:
|
||||
"""Convert a PyTorch `dtype` into its underlying `int` representation.
|
||||
|
@ -266,6 +331,7 @@ def _convert_dtype_to_int(dtype: torch.dtype) -> int:
|
|||
"""
|
||||
return dtype
|
||||
|
||||
|
||||
def check_dtype_function(invocations: List[Invocation]):
|
||||
"""Decorator that automatically tests a dtype function.
|
||||
|
||||
|
@ -273,6 +339,7 @@ def check_dtype_function(invocations: List[Invocation]):
|
|||
`〇` instead of `.`, is tested against the corresponding op in
|
||||
`torch.ops.*` function using the given invocations.
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
for invocation in invocations:
|
||||
result_dtypes, golden_results = _get_fn_and_golden_results(f, invocation)
|
||||
|
@ -280,7 +347,11 @@ def check_dtype_function(invocations: List[Invocation]):
|
|||
continue
|
||||
|
||||
if len(result_dtypes) != len(golden_results):
|
||||
_report(f, invocation, f"Expected {len(golden_results)} result dtypes, got {len(result_dtypes)}")
|
||||
_report(
|
||||
f,
|
||||
invocation,
|
||||
f"Expected {len(golden_results)} result dtypes, got {len(result_dtypes)}",
|
||||
)
|
||||
for result_dtype, golden_result in zip(result_dtypes, golden_results):
|
||||
if isinstance(golden_result, torch.Tensor):
|
||||
golden_dtype = golden_result.dtype
|
||||
|
@ -294,7 +365,14 @@ def check_dtype_function(invocations: List[Invocation]):
|
|||
# support returning the default `int` value, the comparisons of
|
||||
# the result and golden dtypes are done using their underlying
|
||||
# `int` representation.
|
||||
if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int(golden_dtype):
|
||||
_report(f, invocation, f"Expected result dtype {golden_dtype}, got {result_dtype}")
|
||||
if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int(
|
||||
golden_dtype
|
||||
):
|
||||
_report(
|
||||
f,
|
||||
invocation,
|
||||
f"Expected result dtype {golden_dtype}, got {result_dtype}",
|
||||
)
|
||||
return f
|
||||
|
||||
return decorator
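
The dtype counterpart looks analogous, except that each tensor argument arrives as a `(rank, dtype)` tuple, mirroring `Invocation.to_dtype_function_args` above (again a hedged sketch, not the library's real implementation of this op):

from typing import Tuple
import torch

# For a float input, tanh returns the same dtype; restricting the invocation
# list to float32 keeps this simplified body honest.
@check_dtype_function([
    Invocation(TensorOfShape(2, 3, dtype=torch.float32)),
])
def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    _, self_dtype = self_rank_dtype
    return self_dtype
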
@ -71,18 +71,23 @@ def get_ods_type(type: str, non_value: bool, *, is_result: bool = False):
|
|||
if type.startswith("Dict("):
|
||||
type = "Dict"
|
||||
if non_value:
|
||||
ods_type = TORCH_NON_VALUE_TYPE_TO_ODS_TYPE.get(type) or TORCH_TYPE_TO_ODS_TYPE.get(type)
|
||||
ods_type = TORCH_NON_VALUE_TYPE_TO_ODS_TYPE.get(
|
||||
type
|
||||
) or TORCH_TYPE_TO_ODS_TYPE.get(type)
|
||||
else:
|
||||
ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type)
|
||||
if ods_type is None:
|
||||
raise Exception(
|
||||
f"{type!r} not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!")
|
||||
f"{type!r} not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!"
|
||||
)
|
||||
return ods_type
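
Two small, hypothetical calls to illustrate the lookup above (the concrete ODS names come from the mapping tables defined earlier in this file and are not repeated here):

# Dict-typed arguments are collapsed to the bare "Dict" key before the lookup.
get_ods_type("Dict(str, t)", False)    # resolved via the "Dict" entry

# Unknown types fail loudly so the mapping gets extended rather than guessed.
get_ods_type("SomeNewJitType", False)  # raises: "... Please add it!"
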
def _name_thunk() -> None:
|
||||
# Strictly exists for _get_main_module_name to harvest its __module__.
|
||||
pass
|
||||
|
||||
|
||||
def _get_main_module_name() -> str:
|
||||
# If a Python module is loaded interactively or as part of a module
|
||||
# directory, it uses a BuiltinImporter. If loaded from a file, it uses
|
||||
|
@ -93,6 +98,7 @@ def _get_main_module_name() -> str:
|
|||
except AttributeError:
|
||||
return _name_thunk.__module__
|
||||
|
||||
|
||||
ODS_BANNER = f"""//===-------------------------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
|
@ -117,10 +123,15 @@ ODS_BANNER = f"""//===-------------------------------------------------------*-
|
|||
"""
|
||||
|
||||
|
||||
def raw_emit_op(operator: JitOperator,
|
||||
def raw_emit_op(
|
||||
operator: JitOperator,
|
||||
emitter_td: TextEmitter,
|
||||
*, traits: List[str],
|
||||
has_folder: bool, has_canonicalizer: bool, has_verifier: bool):
|
||||
*,
|
||||
traits: List[str],
|
||||
has_folder: bool,
|
||||
has_canonicalizer: bool,
|
||||
has_verifier: bool,
|
||||
):
|
||||
"""Emit the ODS for a JitOperator to a textual file.
|
||||
|
||||
This is the lowest level of emission and is responsible for low-level
|
||||
|
@ -138,8 +149,7 @@ def raw_emit_op(operator: JitOperator,
|
|||
def generic_result_name(i):
|
||||
return "result" + (str(i) if multiple_results else "")
|
||||
|
||||
p_td(
|
||||
f"def Torch_{cpp_class_name} : Torch_Op<{emitter_td.quote(op_name)}, [")
|
||||
p_td(f"def Torch_{cpp_class_name} : Torch_Op<{emitter_td.quote(op_name)}, [")
|
||||
with emitter_td.indent():
|
||||
with emitter_td.indent():
|
||||
p_td(",\n".join(traits))
|
||||
|
@ -153,20 +163,28 @@ def raw_emit_op(operator: JitOperator,
|
|||
if operator.is_vararg:
|
||||
p_td("Variadic<AnyTorchType>:$operands")
|
||||
else:
|
||||
p_td(",\n".join([
|
||||
p_td(
|
||||
",\n".join(
|
||||
[
|
||||
f"""{get_ods_type(arg["type"], is_non_value_op)}:${arg["name"]}"""
|
||||
for arg in operator.arguments
|
||||
]))
|
||||
]
|
||||
)
|
||||
)
|
||||
p_td(");")
|
||||
p_td(f"let results = (outs")
|
||||
with emitter_td.indent():
|
||||
if operator.is_varret:
|
||||
p_td("Variadic<AnyTorchType>:$results")
|
||||
else:
|
||||
p_td(",\n".join([
|
||||
p_td(
|
||||
",\n".join(
|
||||
[
|
||||
f"""{get_ods_type(ret["type"], is_non_value_op, is_result=True)}:${ret["name"] or generic_result_name(e)}"""
|
||||
for e, ret in enumerate(operator.returns)
|
||||
]))
|
||||
]
|
||||
)
|
||||
)
|
||||
p_td(");")
|
||||
|
||||
if operator.is_vararg or operator.is_varret:
|
||||
|
@ -174,16 +192,19 @@ def raw_emit_op(operator: JitOperator,
|
|||
assembly_operands = "`(` $operands `)`"
|
||||
assembly_operand_types = "qualified(type($operands))"
|
||||
else:
|
||||
assembly_operands = " `,` ".join("$" + arg["name"]
|
||||
for arg in operator.arguments)
|
||||
assembly_operands = " `,` ".join(
|
||||
"$" + arg["name"] for arg in operator.arguments
|
||||
)
|
||||
assembly_operand_types = " `,` ".join(
|
||||
f"""qualified(type(${arg["name"]}))""" for arg in operator.arguments)
|
||||
f"""qualified(type(${arg["name"]}))""" for arg in operator.arguments
|
||||
)
|
||||
if operator.is_varret:
|
||||
assembly_result_types = "qualified(type($results))"
|
||||
else:
|
||||
assembly_result_types = " `,` ".join(
|
||||
f"""qualified(type(${ret["name"] or generic_result_name(e)}))"""
|
||||
for e, ret in enumerate(operator.returns))
|
||||
for e, ret in enumerate(operator.returns)
|
||||
)
|
||||
if assembly_operand_types and assembly_result_types:
|
||||
maybe_arrow = " `->` "
|
||||
else:
|
||||
|
@ -192,7 +213,8 @@ def raw_emit_op(operator: JitOperator,
|
|||
p_td(f"let assemblyFormat = {emitter_td.quote(assembly_format)};")
|
||||
else:
|
||||
p_td(f"let hasCustomAssemblyFormat = 1;")
|
||||
p_td(f"""let extraClassDefinition = [{{
|
||||
p_td(
|
||||
f"""let extraClassDefinition = [{{
|
||||
ParseResult {cpp_class_name}::parse(OpAsmParser &parser, OperationState &result) {{
|
||||
return parseDefaultTorchOp(parser, result, {len(operator.arguments)}, {len(operator.returns)});
|
||||
}}
|
||||
|
@ -200,7 +222,8 @@ def raw_emit_op(operator: JitOperator,
|
|||
printDefaultTorchOp(printer, *this, {len(operator.arguments)}, {len(operator.returns)});
|
||||
}}
|
||||
}}];
|
||||
""")
|
||||
"""
|
||||
)
|
||||
if has_folder:
|
||||
p_td("let hasFolder = 1;")
|
||||
if has_canonicalizer:
|
||||
|
@ -211,13 +234,15 @@ def raw_emit_op(operator: JitOperator,
|
|||
p_td("\n")
|
||||
|
||||
|
||||
def emit_op(operator: JitOperator,
|
||||
def emit_op(
|
||||
operator: JitOperator,
|
||||
emitter_td: TextEmitter,
|
||||
*,
|
||||
traits: Optional[List[str]] = None,
|
||||
has_folder: bool = False,
|
||||
has_canonicalizer: bool = False,
|
||||
has_verifier: bool = False):
|
||||
has_verifier: bool = False,
|
||||
):
|
||||
"""Main entry point for op emission.
|
||||
|
||||
Besides emitting the op, it deduces / adds traits based on the operator
|
||||
|
@ -233,12 +258,14 @@ def emit_op(operator: JitOperator,
|
|||
if operator.is_readonly():
|
||||
traits += ["ReadOnly"]
|
||||
|
||||
raw_emit_op(operator,
|
||||
raw_emit_op(
|
||||
operator,
|
||||
emitter_td,
|
||||
traits=traits,
|
||||
has_folder=has_folder,
|
||||
has_canonicalizer=has_canonicalizer,
|
||||
has_verifier=has_verifier)
|
||||
has_verifier=has_verifier,
|
||||
)
|
||||
|
||||
|
||||
def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||
|
@ -253,9 +280,15 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
ns, unqual, overload = operator.triple
|
||||
# Underscore variant of functional ops should have "functional" part removed.
|
||||
is_functional_op = overload == "functional"
|
||||
emit_op(registry.get_by_triple((ns, unqual + "_", overload if not is_functional_op else "")),
|
||||
emit_op(
|
||||
registry.get_by_triple(
|
||||
(ns, unqual + "_", overload if not is_functional_op else "")
|
||||
),
|
||||
emitter_td,
|
||||
traits=["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else [])
|
||||
traits=["IsTrailingUnderscoreInplaceVariant"]
|
||||
if not is_functional_op
|
||||
else [],
|
||||
)
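
A hedged illustration of the underscore-variant logic above (the op and class names are assumptions inferred from the `Torch_{cpp_class_name}` pattern in `raw_emit_op`, not taken from this diff):

# Hypothetical example: registering an op together with its in-place variant.
emit_with_mutating_variants("aten::tanh : (Tensor) -> (Tensor)")
# Expected effect (class names assumed from the Torch_{cpp_class_name} pattern):
#   def Torch_AtenTanhOp  : Torch_Op<"aten.tanh", [...]>      -- value-semantic form
#   def Torch_AtenTanh_Op : Torch_Op<"aten.tanh_",
#         [IsTrailingUnderscoreInplaceVariant, ...]>          -- in-place twin
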
# ==========================================================================
|
||||
# `aten::` namespace.
|
||||
|
@ -332,45 +365,105 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::square : (Tensor) -> (Tensor)",
|
||||
"aten::zero : (Tensor) -> (Tensor)",
|
||||
"aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)"
|
||||
"aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
]:
|
||||
emit_with_mutating_variants(key)
|
||||
# Shape manipulations:
|
||||
emit_with_mutating_variants("aten::unsqueeze : (Tensor, int) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants(
|
||||
"aten::unsqueeze : (Tensor, int) -> (Tensor)", has_folder=True
|
||||
)
|
||||
|
||||
# Elementwise tensor compute ops that don't have the standard mutating
|
||||
# variants.
|
||||
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
|
||||
emit_with_mutating_variants("aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)", has_canonicalizer=True)
|
||||
emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
emit_with_mutating_variants("aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants(
|
||||
"aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
has_folder=True,
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
has_folder=True,
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
has_folder=True,
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
has_folder=True,
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
has_folder=True,
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
has_folder=True,
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit_with_mutating_variants("aten::log : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||
emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||
emit_with_mutating_variants(
|
||||
"aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
|
||||
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::fake_quantize_per_tensor_affine : (Tensor, float, int, int, int) -> (Tensor)")
|
||||
emit_with_mutating_variants(
|
||||
"aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)"
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::fake_quantize_per_tensor_affine : (Tensor, float, int, int, int) -> (Tensor)"
|
||||
)
|
||||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::mish : (Tensor) -> (Tensor)")
|
||||
emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||
emit(
|
||||
"aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit("aten::gelu : (Tensor, str) -> (Tensor)")
|
||||
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
|
@ -392,7 +485,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::split_with_sizes_copy : (Tensor, int[], int) -> (Tensor[])")
|
||||
|
||||
# Random number generation
|
||||
emit_with_mutating_variants("aten::uniform : (Tensor, float, float, Generator?) -> (Tensor)")
|
||||
emit_with_mutating_variants(
|
||||
"aten::uniform : (Tensor, float, float, Generator?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::rand : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
|
||||
|
@ -400,11 +495,17 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)")
|
||||
emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)")
|
||||
emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)")
|
||||
emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit(
|
||||
"aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
|
||||
emit_with_mutating_variants(
|
||||
"aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit(
|
||||
"aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::random : (Tensor, Generator?) -> (Tensor)")
|
||||
emit("aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)")
|
||||
|
@ -412,10 +513,14 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
|
||||
emit_with_mutating_variants(
|
||||
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
|
||||
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)"
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)")
|
||||
emit("aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)")
|
||||
"aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)"
|
||||
)
|
||||
|
||||
# Non-elementwise tensor compute ops
|
||||
emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
|
||||
|
@ -433,16 +538,32 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
|
||||
)
|
||||
emit("aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
|
||||
emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
|
||||
emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
|
||||
emit(
|
||||
"aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)"
|
||||
)
|
||||
emit("aten::conv_tbc : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
||||
emit("aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
||||
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
|
||||
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
|
||||
emit(
|
||||
"aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)"
|
||||
)
|
||||
emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"),
|
||||
emit("aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit(
|
||||
"aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
|
||||
emit(
|
||||
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
|
||||
|
@ -456,43 +577,33 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
'aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)'
|
||||
)
|
||||
emit("aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)")
|
||||
emit(
|
||||
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
|
||||
)
|
||||
emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
|
||||
emit(
|
||||
"aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)"
|
||||
)
|
||||
emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)")
|
||||
emit(
|
||||
"aten::normal_functional : (Tensor, float, float, Generator?) -> (Tensor)",
|
||||
)
|
||||
emit(
|
||||
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
|
||||
)
|
||||
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||
emit(
|
||||
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
|
||||
)
|
||||
emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||
emit(
|
||||
"aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)"
|
||||
)
|
||||
emit("aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)")
|
||||
emit(
|
||||
"aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
|
||||
)
|
||||
|
@ -505,18 +616,18 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
|
||||
emit("aten::softmax.int : (Tensor, int, int?) -> (Tensor)")
|
||||
emit("aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)")
|
||||
emit("aten::_log_softmax : (Tensor, int, bool) -> (Tensor)")
|
||||
emit_with_mutating_variants(
|
||||
"aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)"
|
||||
emit_with_mutating_variants(
|
||||
"aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::_log_softmax : (Tensor, int, bool) -> (Tensor)"
|
||||
emit_with_mutating_variants(
|
||||
"aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)"
|
||||
)
|
||||
emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||
|
@ -549,13 +660,23 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::var : (Tensor, bool) -> (Tensor)")
|
||||
emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
|
||||
emit("aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)")
|
||||
emit("aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)")
|
||||
emit(
|
||||
"aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)"
|
||||
)
|
||||
emit("aten::var_mean : (Tensor, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::var_mean.dim : (Tensor, int[]?, bool, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
||||
emit("aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
|
||||
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
||||
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
|
||||
emit(
|
||||
"aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
|
||||
)
|
||||
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
|
||||
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
|
||||
emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
|
||||
|
@ -563,17 +684,25 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
|
||||
emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
||||
emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)")
|
||||
emit("aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)")
|
||||
emit(
|
||||
"aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)"
|
||||
)
|
||||
emit("aten::nonzero : (Tensor) -> (Tensor)")
|
||||
emit("aten::nonzero_numpy : (Tensor) -> (Tensor[])")
|
||||
emit("aten::nonzero_static : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::binary_cross_entropy : (Tensor, Tensor, Tensor?, int) -> (Tensor)")
|
||||
emit("aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)")
|
||||
emit(
|
||||
"aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)"
|
||||
)
|
||||
emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)")
|
||||
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)")
|
||||
emit(
|
||||
"aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)"
|
||||
)
|
||||
emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)")
|
||||
|
||||
# Misc tensor ops.
|
||||
|
@ -590,9 +719,13 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
|
||||
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
|
||||
emit("aten::is_floating_point : (Tensor) -> (bool)", has_folder=True)
|
||||
emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
|
||||
emit(
|
||||
"aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
|
||||
emit(
|
||||
"aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
|
@ -611,8 +744,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::any : (Tensor) -> (Tensor)")
|
||||
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
|
||||
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit(
|
||||
"aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)")
|
||||
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
|
||||
emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
|
||||
|
@ -624,19 +761,29 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
|
||||
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)")
|
||||
emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)")
|
||||
emit(
|
||||
"aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::detach : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit("aten::device.with_index : (str, int) -> (Device)", has_canonicalizer=True)
|
||||
emit("aten::cuda : (Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
||||
emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)")
|
||||
emit("aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)")
|
||||
emit(
|
||||
"aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit(
|
||||
"aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit(
|
||||
"aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
|
||||
|
@ -644,7 +791,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
|
||||
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
|
||||
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
|
||||
emit_with_mutating_variants(
|
||||
"aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)"
|
||||
)
|
||||
emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
|
||||
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True)
|
||||
|
@ -670,9 +819,18 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||
emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::amin : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
|
||||
emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True, has_canonicalizer = True)
|
||||
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True)
|
||||
emit(
|
||||
"aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit(
|
||||
"aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)",
|
||||
has_folder=True,
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit(
|
||||
"aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
||||
emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
|
||||
emit("aten::_cast_Float : (Tensor, bool) -> (Tensor)", has_canonicalizer=True)
|
||||
|
@ -681,33 +839,64 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True)
|
||||
emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True)
|
||||
emit(
|
||||
"aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
has_folder=True,
|
||||
)
|
||||
emit(
|
||||
"aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)",
|
||||
has_folder=True,
|
||||
)
|
||||
emit(
|
||||
"aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)")
|
||||
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)", has_folder=True)
|
||||
emit(
|
||||
"aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)",
|
||||
has_folder=True,
|
||||
)
|
||||
emit("aten::len.Tensor : (Tensor) -> (int)")
|
||||
emit("aten::cpu : (Tensor) -> (Tensor)")
|
||||
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)")
|
||||
emit_with_mutating_variants(
|
||||
"aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)"
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)"
|
||||
)
|
||||
emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
|
||||
emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
|
||||
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True)
|
||||
emit(
|
||||
"aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
|
||||
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
||||
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
|
||||
emit("aten::t : (Tensor) -> (Tensor)")
|
||||
emit("aten::numpy_T : (Tensor) -> (Tensor)")
|
||||
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
|
||||
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
|
||||
emit(
|
||||
"aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)",
|
||||
has_folder=True,
|
||||
)
|
||||
emit(
|
||||
"aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)"
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
"aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)"
|
||||
)
|
||||
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
|
||||
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit(
|
||||
"aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True)
|
||||
|
||||
# Functionalization ops
|
||||
|
@ -737,7 +926,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
|
||||
emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
|
||||
emit("aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)")
|
||||
emit(
|
||||
"aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)")
|
||||
|
||||
# Dict ops.
|
||||
|
@ -750,7 +941,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::Delete.Dict_str : (Dict(str, t), str) -> ()")
|
||||
|
||||
# List ops.
|
||||
emit("aten::cat : (Tensor[], int) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
emit(
|
||||
"aten::cat : (Tensor[], int) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
has_folder=True,
|
||||
)
|
||||
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||
emit("aten::append.t : (t[], t) -> (t[])")
|
||||
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
|
||||
|
@ -823,9 +1018,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
||||
emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True)
|
||||
emit("aten::__not__ : (bool) -> (bool)", has_folder=True)
|
||||
emit("aten::len.t : (t[]) -> (int)",
|
||||
has_folder=True,
|
||||
has_canonicalizer=True)
|
||||
emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True)
|
||||
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
|
||||
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
||||
emit("aten::mul : (Scalar, Scalar) -> (Scalar)", has_folder=True)
|
||||
|
@ -849,12 +1042,22 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::hardtanh_backward : (Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
|
||||
emit("aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)")
|
||||
emit("aten::_log_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)")
|
||||
emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit(
|
||||
"aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
|
||||
emit("aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)")
|
||||
emit(
|
||||
"aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)"
|
||||
)
|
||||
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
|
||||
|
||||
# quantized ops
|
||||
|
@ -863,7 +1066,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::dequantize.self : (Tensor) -> (Tensor)")
|
||||
emit("aten::dequantize.tensor : (Tensor) -> (Tensor)")
|
||||
emit("aten::int_repr : (Tensor) -> (Tensor)")
|
||||
emit("aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
||||
emit(
|
||||
"aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)"
|
||||
)
|
||||
emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")
|
||||
|
||||
# ==========================================================================
|
||||
|
@ -881,10 +1086,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("prim::max.self_int : (int[]) -> (int)")
|
||||
emit("prim::max.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("prim::RaiseException : (str, str?) -> ()")
|
||||
emit("prim::Uninitialized : () -> (Any)",
|
||||
has_canonicalizer=True, traits=["Pure"])
|
||||
emit("prim::unchecked_cast : (t) -> (t)", has_folder=True,
|
||||
traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
|
||||
emit("prim::Uninitialized : () -> (Any)", has_canonicalizer=True, traits=["Pure"])
|
||||
emit(
|
||||
"prim::unchecked_cast : (t) -> (t)",
|
||||
has_folder=True,
|
||||
traits=["DeclareOpInterfaceMethods<CastOpInterface>"],
|
||||
)
|
||||
emit("prim::Print : (...) -> ()")
|
||||
emit("prim::tolist : (...) -> (...)")
|
||||
emit("prim::abs.Scalar : (Scalar) -> (Scalar)")
|
||||
|
@ -908,13 +1115,15 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
|
||||
emit(
|
||||
"quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)",
|
||||
traits=["HasValueSemantics"])
|
||||
traits=["HasValueSemantics"],
|
||||
)
|
||||
|
||||
|
||||
def dump_registered_ops(outfile: TextIO, registry: Registry):
|
||||
for _, v in sorted(registry.by_unique_key.items()):
|
||||
outfile.write(repr(v))
|
||||
|
||||
|
||||
def _maybe_import_op_extensions(args: argparse.Namespace):
|
||||
extension_string = str.strip(args.pytorch_op_extensions)
|
||||
if len(extension_string) > 0:
|
||||
|
@ -924,6 +1133,7 @@ def _maybe_import_op_extensions(args: argparse.Namespace):
|
|||
# importing these modules, so we don't need the return value.
|
||||
importlib.import_module(name)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
_maybe_import_op_extensions(args)
|
||||
registry = Registry.load()
|
||||
|
@ -942,15 +1152,18 @@ def _create_argparse() -> argparse.ArgumentParser:
|
|||
parser.add_argument(
|
||||
"--torch_ir_include_dir",
|
||||
required=True,
|
||||
help="Directory in include/ containing the Torch dialect")
|
||||
help="Directory in include/ containing the Torch dialect",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug_registry_dump",
|
||||
help="File to dump the the PyTorch JIT operator registry into")
|
||||
help="File to dump the the PyTorch JIT operator registry into",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_op_extensions",
|
||||
type=str,
|
||||
default="",
|
||||
help="An optional, comma-separated list of Python modules which register additional PyTorch operators upon being imported. These modules can be used to build a torch-mlir which supports PyTorch extensions.")
|
||||
help="An optional, comma-separated list of Python modules which register additional PyTorch operators upon being imported. These modules can be used to build a torch-mlir which supports PyTorch extensions.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
|
|
@ -8,8 +8,10 @@ from typing import TextIO
|
|||
from contextlib import contextmanager
|
||||
import textwrap
|
||||
|
||||
|
||||
class TextEmitter:
|
||||
"""Helper for emitting text files"""
|
||||
|
||||
_INDENT = " "
|
||||
|
||||
def __init__(self, out: TextIO):
|
||||
|
|
|
@ -24,34 +24,38 @@ from torch_mlir.jit_ir_importer import ClassAnnotator
|
|||
|
||||
# Utilities for extracting decorated information into ClassAnnotator.
|
||||
|
||||
|
||||
def _recursively_extract_annotations(
|
||||
module: torch.nn.Module, scripted: torch.jit.ScriptModule,
|
||||
class_annotator: ClassAnnotator):
|
||||
module: torch.nn.Module,
|
||||
scripted: torch.jit.ScriptModule,
|
||||
class_annotator: ClassAnnotator,
|
||||
):
|
||||
assert module.__class__.__name__ == scripted.original_name or (
|
||||
isinstance(module, torch.jit.RecursiveScriptModule) and module is
|
||||
scripted), "script module does not come from specified module"
|
||||
isinstance(module, torch.jit.RecursiveScriptModule) and module is scripted
|
||||
), "script module does not come from specified module"
|
||||
|
||||
# Extract information on methods.
|
||||
for method_name, scripted_method in scripted.__dict__.items():
|
||||
if not isinstance(scripted_method, torch.ScriptMethod):
|
||||
continue
|
||||
method = getattr(module, method_name)
|
||||
if hasattr(method, '_torch_mlir_export'):
|
||||
if hasattr(method, "_torch_mlir_export"):
|
||||
class_annotator.exportPath(scripted._c._type(), [method_name])
|
||||
if hasattr(method, '_torch_mlir_arg_annotations'):
|
||||
if hasattr(method, "_torch_mlir_arg_annotations"):
|
||||
class_annotator.annotateArgs(
|
||||
scripted._c._type(), [method_name],
|
||||
method._torch_mlir_arg_annotations)
|
||||
scripted._c._type(), [method_name], method._torch_mlir_arg_annotations
|
||||
)
|
||||
# Recurse.
|
||||
for name, child in module.named_children():
|
||||
scripted_child = getattr(scripted, name)
|
||||
_recursively_extract_annotations(child, scripted_child,
|
||||
class_annotator)
|
||||
_recursively_extract_annotations(child, scripted_child, class_annotator)
|
||||
|
||||
|
||||
def extract_annotations(program: torch.nn.Module,
|
||||
def extract_annotations(
|
||||
program: torch.nn.Module,
|
||||
scripted: torch.jit.ScriptModule,
|
||||
class_annotator: ClassAnnotator):
|
||||
class_annotator: ClassAnnotator,
|
||||
):
|
||||
"""Populate the ClassAnnotator with annotations extracted from `program`."""
|
||||
class_annotator.exportNone(scripted._c._type())
|
||||
_recursively_extract_annotations(program, scripted, class_annotator)
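
A hedged usage sketch for the extraction above. The decorator names and their import path are assumptions (they are what attach the `_torch_mlir_export` / `_torch_mlir_arg_annotations` attributes read by `_recursively_extract_annotations`), and the annotation tuple format shown is likewise an assumption:

import torch
# from torch_mlir.torchscript.annotations import export, annotate_args  # assumed path

class MyModule(torch.nn.Module):
    @export
    @annotate_args([
        None,                           # the "self" slot
        ([2, 3], torch.float32, True),  # arg0: shape, dtype, value semantics
    ])
    def forward(self, x):
        return x + 1

module = MyModule()
scripted = torch.jit.script(module)
annotator = ClassAnnotator()
extract_annotations(module, scripted, annotator)
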
@ -20,7 +20,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
|
|||
from torch_mlir.compiler_utils import (
|
||||
run_pipeline_with_repro_report,
|
||||
OutputType,
|
||||
lower_mlir_module
|
||||
lower_mlir_module,
|
||||
)
|
||||
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
|
||||
|
@ -105,8 +105,7 @@ class ExampleArgs:
|
|||
self, for chaining.
|
||||
"""
|
||||
assert method_name not in self._example_args
|
||||
self._example_args[method_name] = ExampleArgs._canonicalize_args(
|
||||
example_args)
|
||||
self._example_args[method_name] = ExampleArgs._canonicalize_args(example_args)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
|
@ -129,10 +128,12 @@ class ExampleArgs:

            example_args = [example_args]
        for arg in example_args:
            if not isinstance(arg, (TensorPlaceholder, torch.Tensor)):
                raise Exception(f"Only Tensor's, TensorPlaceholder's, or sequences of "
                raise Exception(
                    f"Only Tensor's, TensorPlaceholder's, or sequences of "
                    f"Tensor's and TensorPlaceholder's are supported as "
                    f"example args for method inputs. "
                    f"Got '{arg}'.")
                    f"Got '{arg}'."
                )
        return tuple(example_args)

    def _get_methods(self):

@ -171,7 +172,8 @@ class ExampleArgs:

                    # "hopefully the trace works for different inputs"
                    # bucket of concerns.
                    raise Exception(
                        "TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`")
                        "TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`"
                    )
                # For any dynamic dimensions, replace them with "7"
                # arbitrarily. If a user is using dynamic dimensions with
                # tracing, they are walking on thin ice already -- assume

@ -182,7 +184,8 @@ class ExampleArgs:

                        example_args_for_trace.append(torch.tensor(1))
                    else:
                        example_args_for_trace.append(
                            torch.ones(*shape, dtype=arg.dtype))
                            torch.ones(*shape, dtype=arg.dtype)
                        )
                else:
                    assert isinstance(arg, torch.Tensor)
                    example_args_for_trace.append(arg)

@ -198,21 +201,33 @@ class ExampleArgs:

# ops in the backend contract, and move these lists somewhere deeper in the
# compiler where each backend can "own" its set of legal ops.
BACKEND_LEGAL_OPS = {
    OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'],
    OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d','aten.adaptive_avg_pool2d', 'aten.unflatten.int'],
    OutputType.TOSA: [
        "aten.flatten.using_ints",
        "aten.native_layer_norm",
        "aten.linear",
    ],
    OutputType.LINALG_ON_TENSORS: [
        "aten.flatten.using_ints",
        "aten.adaptive_avg_pool1d",
        "aten.adaptive_avg_pool2d",
        "aten.unflatten.int",
    ],
    OutputType.STABLEHLO: [],
}


def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra_library.mlir"):
def _canon_extra_library(
    extra_library, extra_library_file_name="custom_op_extra_library.mlir"
):
    if len(extra_library) != 0:
        extra_library_dict = {}
        for library_func in extra_library:
            extra_library_dict[library_func.__name__] = library_func
        mlir_library = generate_library(extra_library_dict)

        extra_library_file = \
            os.path.join(tempfile.gettempdir(), extra_library_file_name)
        extra_library_file = os.path.join(
            tempfile.gettempdir(), extra_library_file_name
        )
        with open(extra_library_file, "w") as f:
            f.write(mlir_library)
        return extra_library_file

@ -220,7 +235,8 @@ def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra
|
|||
return ""
|
||||
|
||||
|
||||
def compile(model: torch.nn.Module,
|
||||
def compile(
|
||||
model: torch.nn.Module,
|
||||
example_args: _example_args,
|
||||
output_type: Union[str, "OutputType"] = OutputType.TORCH,
|
||||
use_tracing: bool = False,
|
||||
|
@ -229,7 +245,8 @@ def compile(model: torch.nn.Module,
|
|||
extra_library: Iterable[Callable] = [],
|
||||
verbose: bool = False,
|
||||
use_make_fx: bool = False,
|
||||
enable_ir_printing: bool = False):
|
||||
enable_ir_printing: bool = False,
|
||||
):
|
||||
"""Convert a PyTorch model to MLIR.
|
||||
|
||||
Args:
|
||||
|
@ -283,18 +300,18 @@ def compile(model: torch.nn.Module,
|
|||
# See `BACKEND_LEGAL_OPS` for more details.
|
||||
if backend_legal_ops is not None:
|
||||
if output_type != OutputType.TORCH:
|
||||
raise Exception("`backend_legal_ops` is only valid with the "
|
||||
"`torch` output type")
|
||||
raise Exception(
|
||||
"`backend_legal_ops` is only valid with the " "`torch` output type"
|
||||
)
|
||||
backend_legal_ops = list(sorted(set(backend_legal_ops)))
|
||||
else:
|
||||
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
||||
|
||||
if use_make_fx:
|
||||
args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)["forward"]
|
||||
model = make_fx(
|
||||
model,
|
||||
decomposition_table=_get_decomposition_table())(*args)
|
||||
|
||||
args = example_args._get_for_tracing(
|
||||
use_tracing=True, ignore_traced_shapes=True
|
||||
)["forward"]
|
||||
model = make_fx(model, decomposition_table=_get_decomposition_table())(*args)
|
||||
|
||||
# For FX-based models, automatically strip overloads.
|
||||
if isinstance(model, torch.fx.GraphModule):
|
||||
|
@ -317,12 +334,12 @@ def compile(model: torch.nn.Module,
|
|||
raise Exception(
|
||||
f"Model does not have exported method '{method_name}', "
|
||||
f"requested in `example_args`. Consider adding "
|
||||
f"`@torch.jit.export` to the method definition.")
|
||||
f"`@torch.jit.export` to the method definition."
|
||||
)
|
||||
scripted = model
|
||||
elif use_tracing:
|
||||
scripted = torch.jit.trace_module(
|
||||
model,
|
||||
example_args._get_for_tracing(use_tracing, ignore_traced_shapes)
|
||||
model, example_args._get_for_tracing(use_tracing, ignore_traced_shapes)
|
||||
)
|
||||
else:
|
||||
# Make sure that all the methods that the user requested get scripted.
|
||||
|
@ -338,8 +355,7 @@ def compile(model: torch.nn.Module,
|
|||
annotation = [None] # `None` is always the annotation for "self".
|
||||
for arg in example_args:
|
||||
annotation.append((arg.shape, arg.dtype, True))
|
||||
class_annotator.annotateArgs(
|
||||
scripted._c._type(), [method_name], annotation)
|
||||
class_annotator.annotateArgs(scripted._c._type(), [method_name], annotation)
|
||||
|
||||
mb = ModuleBuilder()
|
||||
import_options = ImportOptions()
|
||||
|
@ -350,20 +366,27 @@ def compile(model: torch.nn.Module,
|
|||
# Import the TorchScript module to MLIR
|
||||
mb.import_module(scripted._c, class_annotator, import_options)
|
||||
except Exception as e:
|
||||
raise Exception(f"""
|
||||
raise Exception(
|
||||
f"""
|
||||
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
||||
### Importer C++ Exception:
|
||||
{e}
|
||||
### Importer Diagnostics:
|
||||
{sys.stderr.getvalue()}
|
||||
""") from None
|
||||
"""
|
||||
) from None
|
||||
finally:
|
||||
sys.stderr = original_stderr
|
||||
if output_type == OutputType.RAW:
|
||||
return mb.module
|
||||
|
||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
|
||||
" extra-library=" + extra_library_file_name + "}"
|
||||
option_string = (
|
||||
"{backend-legal-ops="
|
||||
+ ",".join(backend_legal_ops)
|
||||
+ " extra-library="
|
||||
+ extra_library_file_name
|
||||
+ "}"
|
||||
)
|
||||
run_pipeline_with_repro_report(
|
||||
mb.module,
|
||||
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
|
||||
|
|
|
@ -22,8 +22,8 @@ import torch

# Attribute names used for annotations.
# These should be kept in sync with their use in
# `torch_mlir/torchscript_annotations.py`.
TORCH_MLIR_EXPORT_ATTR_NAME = '_torch_mlir_export'
TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME = '_torch_mlir_arg_annotations'
TORCH_MLIR_EXPORT_ATTR_NAME = "_torch_mlir_export"
TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME = "_torch_mlir_arg_annotations"


def export(fn):

@ -55,14 +55,20 @@ def jit(
|
|||
output_type = OutputType.get(output_type)
|
||||
if backend_legal_ops is not None:
|
||||
if output_type != OutputType.TORCH:
|
||||
raise Exception("`backend_legal_ops` is only valid with the "
|
||||
"`torch` output type")
|
||||
raise Exception(
|
||||
"`backend_legal_ops` is only valid with the " "`torch` output type"
|
||||
)
|
||||
backend_legal_ops = list(sorted(set(backend_legal_ops)))
|
||||
else:
|
||||
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
||||
|
||||
option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) +
|
||||
" extra-library=" + extra_library_file_name + "}")
|
||||
option_string = (
|
||||
"{backend-legal-ops="
|
||||
+ ",".join(backend_legal_ops)
|
||||
+ " extra-library="
|
||||
+ extra_library_file_name
|
||||
+ "}"
|
||||
)
|
||||
|
||||
mlir_module = fx.export_and_import(prog, func_name=func_name)
|
||||
assert mlir_module is not None
|
||||
|
@ -95,9 +101,11 @@ class FxImporterTestConfig(TestConfig):
|
|||
result: Trace = []
|
||||
for item in trace:
|
||||
prog = torch.export.export(artifact, tuple(item.inputs))
|
||||
module = jit(prog,
|
||||
module = jit(
|
||||
prog,
|
||||
func_name=artifact.__class__.__name__,
|
||||
output_type=self._output_type)
|
||||
output_type=self._output_type,
|
||||
)
|
||||
module = self._backend.compile(module)
|
||||
backend_module = self._backend.load(module)
|
||||
params = {
|
||||
|
@ -107,10 +115,10 @@ class FxImporterTestConfig(TestConfig):
|
|||
params_flat, params_spec = pytree.tree_flatten(params)
|
||||
params_flat = list(params_flat)
|
||||
with torch.no_grad():
|
||||
numpy_inputs = recursively_convert_to_numpy(params_flat +
|
||||
item.inputs)
|
||||
outputs = getattr(backend_module,
|
||||
artifact.__class__.__name__)(*numpy_inputs)
|
||||
numpy_inputs = recursively_convert_to_numpy(params_flat + item.inputs)
|
||||
outputs = getattr(backend_module, artifact.__class__.__name__)(
|
||||
*numpy_inputs
|
||||
)
|
||||
output = refine_result_type(outputs)
|
||||
if isinstance(output, (tuple, list)):
|
||||
user_output = []
|
||||
|
@ -120,7 +128,6 @@ class FxImporterTestConfig(TestConfig):
|
|||
user_output.append(val)
|
||||
output = tuple(user_output)
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=output))
|
||||
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||
)
|
||||
return result
|
||||
|
|
|
@ -23,20 +23,19 @@ class LazyTensorCoreTestConfig(TestConfig):

        lazy_backend._initialize()

    def compile(self, program: torch.nn.Module) -> torch.nn.Module:
        return program.to('lazy')
        return program.to("lazy")

    def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
        result: Trace = []

        for item in trace:
            # We need to move all the inputs to the lazy device before running in LTC.
            lazy_inputs = tree_map(to_device('lazy'), item.inputs)
            lazy_inputs = tree_map(to_device("lazy"), item.inputs)
            output = getattr(artifact, item.symbol)(*lazy_inputs)
            cpu_outputs = tree_map(to_device('cpu'), output)
            cpu_outputs = tree_map(to_device("cpu"), output)

            result.append(
                TraceItem(symbol=item.symbol,
                          inputs=item.inputs,
                          output=cpu_outputs))
                TraceItem(symbol=item.symbol, inputs=item.inputs, output=cpu_outputs)
            )

        return result

@ -24,6 +24,7 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
|
|||
This class handles all the common lowering that torch-mlir does before
|
||||
reaching the linalg-on-tensors abstraction level.
|
||||
"""
|
||||
|
||||
def __init__(self, backend: LinalgOnTensorsBackend):
|
||||
super().__init__()
|
||||
self.backend = backend
|
||||
|
@ -31,12 +32,11 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
|
|||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
example_args = convert_annotations_to_placeholders(program.forward)
|
||||
module = torchscript.compile(
|
||||
program, example_args, output_type="linalg-on-tensors")
|
||||
program, example_args, output_type="linalg-on-tensors"
|
||||
)
|
||||
|
||||
return self.backend.compile(module)
|
||||
|
||||
|
||||
|
||||
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||
backend_module = self.backend.load(artifact)
|
||||
result: Trace = []
|
||||
|
@ -45,7 +45,6 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
|
|||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||
output = recursively_convert_from_numpy(outputs)
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=output))
|
||||
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||
)
|
||||
return result
|
||||
|
|
|
@ -10,6 +10,7 @@ from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem


class NativeTorchTestConfig(TestConfig):
    """TestConfig that just runs the torch.nn.Module without compiling"""

    def __init__(self):
        super().__init__()

@ -23,7 +24,6 @@ class NativeTorchTestConfig(TestConfig):

        for item in trace:
            output = getattr(artifact, item.symbol)(*item.inputs)
            result.append(
                TraceItem(symbol=item.symbol,
                          inputs=item.inputs,
                          output=output))
                TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
            )
        return result

@ -47,7 +47,7 @@ def convert_onnx(model, inputs):
|
|||
input_names = []
|
||||
dynamic_tensors = {}
|
||||
for (index, arg) in enumerate(inputs):
|
||||
shape = map(lambda d : d if d >= 0 else 1, arg.shape)
|
||||
shape = map(lambda d: d if d >= 0 else 1, arg.shape)
|
||||
shape = tuple(shape)
|
||||
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
|
||||
|
||||
|
@ -56,24 +56,27 @@ def convert_onnx(model, inputs):
|
|||
|
||||
dynamic_dims = {}
|
||||
for (dimindex, dim) in enumerate(arg.shape):
|
||||
if (dim < 0):
|
||||
if dim < 0:
|
||||
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)
|
||||
|
||||
if (dynamic_dims):
|
||||
if dynamic_dims:
|
||||
dynamic_tensors[input_name] = dynamic_dims
|
||||
|
||||
|
||||
examples=tuple(examples)
|
||||
torch.onnx.export(model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors)
|
||||
examples = tuple(examples)
|
||||
torch.onnx.export(
|
||||
model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors
|
||||
)
|
||||
buffer = buffer.getvalue()
|
||||
return import_onnx(buffer)
|
||||
|
||||
|
||||
class OnnxBackendTestConfig(TestConfig):
|
||||
"""Base class for TestConfig's that are implemented with ONNX.
|
||||
|
||||
This class handles all the common lowering that torch-mlir does before
|
||||
reaching the ONNX abstraction level.
|
||||
"""
|
||||
|
||||
def __init__(self, backend: OnnxBackend, use_make_fx: bool = False):
|
||||
super().__init__()
|
||||
self.backend = backend
|
||||
|
@ -85,8 +88,6 @@ class OnnxBackendTestConfig(TestConfig):
|
|||
compiled_module = self.backend.compile(onnx_module)
|
||||
return compiled_module
|
||||
|
||||
|
||||
|
||||
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||
backend_module = self.backend.load(artifact)
|
||||
result: Trace = []
|
||||
|
@ -95,7 +96,6 @@ class OnnxBackendTestConfig(TestConfig):
|
|||
outputs = getattr(backend_module, "main_graph")(*numpy_inputs)
|
||||
output = recursively_convert_from_numpy(outputs)
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=output))
|
||||
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||
)
|
||||
return result
@ -48,13 +48,13 @@ def refine_result_type(_result):
|
|||
def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_graph.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert len(
|
||||
node.args) == 1, "Output node must have a single argument"
|
||||
assert len(node.args) == 1, "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if node_arg != ():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor to
|
||||
# torch.aten.add.Scalar/torch.aten.mul.Scalar in case of Scalar argument
|
||||
# Cannot be done on earlier stage, e.g. in _FXGraphImporter as it
|
||||
|
@ -69,7 +69,7 @@ def scalarize_tensor_ops_on_scalars(gm: torch.fx.GraphModule):
|
|||
for node in gm.graph.nodes:
|
||||
# Checks if we're calling a function (i.e:
|
||||
# torch.add)
|
||||
if node.op == 'call_function':
|
||||
if node.op == "call_function":
|
||||
# The target attribute is the function
|
||||
# that call_function calls.
|
||||
# call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {})
|
||||
|
@ -108,20 +108,24 @@ def jit(
|
|||
output_type = OutputType.get(output_type)
|
||||
if backend_legal_ops is not None:
|
||||
if output_type != OutputType.TORCH:
|
||||
raise Exception("`backend_legal_ops` is only valid with the "
|
||||
"`torch` output type")
|
||||
raise Exception(
|
||||
"`backend_legal_ops` is only valid with the " "`torch` output type"
|
||||
)
|
||||
backend_legal_ops = list(sorted(set(backend_legal_ops)))
|
||||
else:
|
||||
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
||||
|
||||
@make_boxed_compiler
|
||||
def my_aot_autograd_backend(gm: torch.fx.GraphModule,
|
||||
_example_inputs: List[torch.Tensor]):
|
||||
def my_aot_autograd_backend(
|
||||
gm: torch.fx.GraphModule, _example_inputs: List[torch.Tensor]
|
||||
):
|
||||
# Torch-MLIR does not support returning an empty tuple. The reason is
|
||||
# that both returning an empty tuple and returning `None` results in MLIR
|
||||
# functions that have as a return type `()`. In other words, there is no
|
||||
# way of differentiating between the two.
|
||||
assert not _returns_empty_tuple(gm), "encountered graph that does not return anything"
|
||||
assert not _returns_empty_tuple(
|
||||
gm
|
||||
), "encountered graph that does not return anything"
|
||||
|
||||
scalarize_tensor_ops_on_scalars(gm)
|
||||
|
||||
|
@ -130,18 +134,24 @@ def jit(
|
|||
mlir_module = import_fx_graph_as_func(gm.graph, model_name)
|
||||
return gm
|
||||
|
||||
my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend,
|
||||
decompositions=_get_decomposition_table)
|
||||
my_backend = aot_autograd(
|
||||
fw_compiler=my_aot_autograd_backend, decompositions=_get_decomposition_table
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
set_model_name(model.__class__.__name__)
|
||||
torch._dynamo.reset()
|
||||
dynamo_f = dynamo.optimize(my_backend, nopython=True)(
|
||||
lambda method, *inputs: method(*inputs))
|
||||
dynamo_f(lambda *inputs: model(*[x.clone() for x in inputs]),
|
||||
*example_args)
|
||||
option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) +
|
||||
" extra-library=" + extra_library_file_name + "}")
|
||||
lambda method, *inputs: method(*inputs)
|
||||
)
|
||||
dynamo_f(lambda *inputs: model(*[x.clone() for x in inputs]), *example_args)
|
||||
option_string = (
|
||||
"{backend-legal-ops="
|
||||
+ ",".join(backend_legal_ops)
|
||||
+ " extra-library="
|
||||
+ extra_library_file_name
|
||||
+ "}"
|
||||
)
|
||||
assert mlir_module is not None
|
||||
run_pipeline_with_repro_report(
|
||||
mlir_module,
|
||||
|
@ -166,9 +176,7 @@ class TorchDynamoTestConfig(TestConfig):
|
|||
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
module = jit(artifact,
|
||||
item.inputs,
|
||||
output_type="linalg-on-tensors")
|
||||
module = jit(artifact, item.inputs, output_type="linalg-on-tensors")
|
||||
module = self.backend.compile(module)
|
||||
backend_module = self.backend.load(module)
|
||||
params = {
|
||||
|
@ -178,13 +186,12 @@ class TorchDynamoTestConfig(TestConfig):
|
|||
params_flat, params_spec = pytree.tree_flatten(params)
|
||||
params_flat = list(params_flat)
|
||||
with torch.no_grad():
|
||||
numpy_inputs = recursively_convert_to_numpy(params_flat +
|
||||
item.inputs)
|
||||
outputs = getattr(backend_module,
|
||||
artifact.__class__.__name__)(*numpy_inputs)
|
||||
numpy_inputs = recursively_convert_to_numpy(params_flat + item.inputs)
|
||||
outputs = getattr(backend_module, artifact.__class__.__name__)(
|
||||
*numpy_inputs
|
||||
)
|
||||
output = refine_result_type(outputs)
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=output))
|
||||
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||
)
|
||||
return result
@ -13,6 +13,7 @@ from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
|||
|
||||
class TorchScriptTestConfig(TestConfig):
|
||||
"""TestConfig that runs the torch.nn.Module through TorchScript"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -26,11 +27,10 @@ class TorchScriptTestConfig(TestConfig):
|
|||
result: Trace = []
|
||||
for item in trace:
|
||||
attr = artifact
|
||||
for part in item.symbol.split('.'):
|
||||
for part in item.symbol.split("."):
|
||||
attr = getattr(attr, part)
|
||||
output = attr(*item.inputs)
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=output))
|
||||
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||
)
|
||||
return result
|
||||
|
|
|
@ -23,6 +23,7 @@ class TosaBackendTestConfig(TestConfig):
|
|||
This class handles all the common lowering that torch-mlir does before
|
||||
reaching the TOSA abstraction level.
|
||||
"""
|
||||
|
||||
def __init__(self, backend: TosaBackend, use_make_fx: bool = False):
|
||||
super().__init__()
|
||||
self.backend = backend
|
||||
|
@ -31,12 +32,11 @@ class TosaBackendTestConfig(TestConfig):
|
|||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
example_args = convert_annotations_to_placeholders(program.forward)
|
||||
module = torchscript.compile(
|
||||
program, example_args, output_type="tosa", use_make_fx=self.use_make_fx)
|
||||
program, example_args, output_type="tosa", use_make_fx=self.use_make_fx
|
||||
)
|
||||
|
||||
return self.backend.compile(module)
|
||||
|
||||
|
||||
|
||||
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||
backend_module = self.backend.load(artifact)
|
||||
result: Trace = []
|
||||
|
@ -45,7 +45,6 @@ class TosaBackendTestConfig(TestConfig):
|
|||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||
output = recursively_convert_from_numpy(outputs)
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=output))
|
||||
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||
)
|
||||
return result
|
||||
|
|
|
@ -27,6 +27,7 @@ def recursively_convert_to_numpy(o: Any):
|
|||
return o
|
||||
raise Exception(f"Unexpected Python function input: {o}")
|
||||
|
||||
|
||||
def recursively_convert_from_numpy(o: Any):
|
||||
if isinstance(o, np.ndarray):
|
||||
return torch.from_numpy(o)
|
||||
|
|
|
@ -24,8 +24,7 @@ def _make_single_op_gm(node) -> torch.fx.GraphModule:
|
|||
return torch.fx.GraphModule(torch.nn.Module(), g)
|
||||
|
||||
|
||||
def _identity_backend(gm: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor]):
|
||||
def _identity_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
||||
"""A backend that just runs the given GraphModule as-is."""
|
||||
return gm
|
||||
|
||||
|
@ -51,6 +50,7 @@ def _make_last_use_map(g: torch.fx.Graph) -> Dict[torch.fx.Node, List[torch.fx.N
|
|||
# Lifetime just ended, so this is the last use.
|
||||
seen.add(use)
|
||||
last_use_map[user].append(use)
|
||||
|
||||
for node in reversed(g.nodes):
|
||||
assert not node.kwargs, "kwargs not supported yet"
|
||||
torch.fx.map_arg(node.args, lambda n: process_use(node, n))
|
||||
|
@ -84,9 +84,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend):
|
|||
Returns:
|
||||
A backend that compares the wrapped backend to `golden_backend`.
|
||||
"""
|
||||
|
||||
def wrapper(user_backend):
|
||||
def backend(gm: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor]):
|
||||
def backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
||||
# We can ignore the example_inputs since we recompile in lockstep
|
||||
# anyway. TorchDynamo should already have appropriate guards in
|
||||
# place so that this doesn't change the compilation result.
|
||||
|
@ -96,7 +96,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend):
|
|||
|
||||
def compiled(*args):
|
||||
env = {}
|
||||
for placeholder, arg in zip([n for n in g.nodes if n.op == "placeholder"], args):
|
||||
for placeholder, arg in zip(
|
||||
[n for n in g.nodes if n.op == "placeholder"], args
|
||||
):
|
||||
env[placeholder] = arg
|
||||
# Evaluate the graph one node at a time, comparing the user and
|
||||
# golden backends. This code currently does not support
|
||||
|
@ -111,7 +113,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend):
|
|||
continue
|
||||
if node.op == "output":
|
||||
return torch.fx.map_arg(node.args[0], lambda n: env[n])
|
||||
assert node.op == "call_function", f"call_module/call_method not supported for {node} -- perhaps call make_simple_dynamo_backend first"
|
||||
assert (
|
||||
node.op == "call_function"
|
||||
), f"call_module/call_method not supported for {node} -- perhaps call make_simple_dynamo_backend first"
|
||||
assert not node.kwargs, "kwargs not supported yet"
|
||||
actual_args = torch.fx.map_arg(node.args, lambda n: env[n])
|
||||
if node not in backend_artifacts:
|
||||
|
@ -128,7 +132,8 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend):
|
|||
assert torch.allclose(user_result, golden_result), (
|
||||
f"User result {user_result} is not close to "
|
||||
f"golden result {golden_result} for "
|
||||
f"node {node} at {node.stack_trace}")
|
||||
f"node {node} at {node.stack_trace}"
|
||||
)
|
||||
# Clean up any tensors that are no longer needed.
|
||||
# TODO: Find a way to test this.
|
||||
# This was tested manually by printing the number of entries
|
||||
|
@ -136,6 +141,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend):
|
|||
for dead_node in last_use_map[node]:
|
||||
env.pop(dead_node)
|
||||
assert False, "not reached -- missing 'output' node"
|
||||
|
||||
return compiled
|
||||
|
||||
return backend
|
||||
|
||||
return wrapper
@ -30,6 +30,7 @@ import traceback

import multiprocess as mp
from multiprocess import set_start_method

try:
    set_start_method("spawn")
except RuntimeError:

@ -38,9 +39,13 @@ except RuntimeError:

import torch

TorchScriptValue = Union[int, float, List['TorchScriptValue'],
                         Dict['TorchScriptValue',
                              'TorchScriptValue'], torch.Tensor]
TorchScriptValue = Union[
    int,
    float,
    List["TorchScriptValue"],
    Dict["TorchScriptValue", "TorchScriptValue"],
    torch.Tensor,
]


class TraceItem(NamedTuple):

@ -91,9 +96,11 @@ def clone_torch_script_value(v: TorchScriptValue):
|
|||
# TODO: Figure out the root cause of the failure and fix properly.
|
||||
def clone_trace(trace: Trace) -> Trace:
|
||||
return [
|
||||
TraceItem(symbol=item.symbol,
|
||||
TraceItem(
|
||||
symbol=item.symbol,
|
||||
inputs=clone_torch_script_value(item.inputs),
|
||||
output=clone_torch_script_value(item.output))
|
||||
output=clone_torch_script_value(item.output),
|
||||
)
|
||||
for item in trace
|
||||
]
|
||||
|
||||
|
@ -101,7 +108,8 @@ def clone_trace(trace: Trace) -> Trace:
|
|||
# A type shared between the result of `TestConfig.compile` and the input
|
||||
# to `TestConfig.run`. Each backend will likely have a different definition of
|
||||
# this type.
|
||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||
CompiledArtifact = TypeVar("CompiledArtifact")
|
||||
|
||||
|
||||
class TestConfig(abc.ABC):
|
||||
"""The interface implemented by backends to run tests.
|
||||
|
@ -136,6 +144,7 @@ class TestConfig(abc.ABC):
|
|||
backend (compiler backend and runtime target) will have an arbitrarily
|
||||
wild and wonderful set of possible configurations that we cannot predict.
|
||||
"""
|
||||
|
||||
# This is not a frontend-lowered module, to allow various testing at the PyTorch level.
|
||||
# We can have a helper class LinalgOnTensorsBackendTestConfig which does that.
|
||||
@abc.abstractmethod
|
||||
|
@ -202,8 +211,8 @@ class TestUtils:
|
|||
|
||||
|
||||
class Test(NamedTuple):
|
||||
"""A description of a test as produced by the test frontend.
|
||||
"""
|
||||
"""A description of a test as produced by the test frontend."""
|
||||
|
||||
# Stable name for error reporting.
|
||||
#
|
||||
# This name's stability is also useful for backend, which want to
|
||||
|
@ -268,14 +277,20 @@ class _Tracer:
|
|||
inputs = [clone_torch_script_value(arg) for arg in args]
|
||||
output = self.__wrapped__(*args, **kwargs)
|
||||
self.__trace__.append(
|
||||
TraceItem(symbol=".".join(self.__property_base_path__),
|
||||
TraceItem(
|
||||
symbol=".".join(self.__property_base_path__),
|
||||
inputs=inputs,
|
||||
output=output))
|
||||
output=output,
|
||||
)
|
||||
)
|
||||
return output
|
||||
|
||||
def __getattr__(self, name):
|
||||
return _Tracer(getattr(self.__wrapped__, name),
|
||||
self.__property_base_path__ + [name], self.__trace__)
|
||||
return _Tracer(
|
||||
getattr(self.__wrapped__, name),
|
||||
self.__property_base_path__ + [name],
|
||||
self.__trace__,
|
||||
)
|
||||
|
||||
|
||||
def generate_golden_trace(test: Test) -> Trace:
|
||||
|
@ -297,40 +312,49 @@ def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any:
|
|||
print(f"Compiling {test.unique_name}...", file=sys.stderr)
|
||||
compiled = config.compile(test.program_factory())
|
||||
except Exception as e:
|
||||
return TestResult(unique_name=test.unique_name,
|
||||
return TestResult(
|
||||
unique_name=test.unique_name,
|
||||
compilation_error="".join(
|
||||
traceback.format_exception(
|
||||
type(e), e, e.__traceback__)),
|
||||
traceback.format_exception(type(e), e, e.__traceback__)
|
||||
),
|
||||
runtime_error=None,
|
||||
trace=None,
|
||||
golden_trace=None)
|
||||
golden_trace=None,
|
||||
)
|
||||
try:
|
||||
if verbose:
|
||||
print(f"Running {test.unique_name}...", file=sys.stderr)
|
||||
trace = config.run(compiled, golden_trace)
|
||||
except Exception as e:
|
||||
return TestResult(unique_name=test.unique_name,
|
||||
return TestResult(
|
||||
unique_name=test.unique_name,
|
||||
compilation_error=None,
|
||||
runtime_error="".join(
|
||||
traceback.format_exception(
|
||||
type(e), e, e.__traceback__)),
|
||||
traceback.format_exception(type(e), e, e.__traceback__)
|
||||
),
|
||||
trace=None,
|
||||
golden_trace=None)
|
||||
return TestResult(unique_name=test.unique_name,
|
||||
golden_trace=None,
|
||||
)
|
||||
return TestResult(
|
||||
unique_name=test.unique_name,
|
||||
compilation_error=None,
|
||||
runtime_error=None,
|
||||
trace=clone_trace(trace),
|
||||
golden_trace=clone_trace(golden_trace))
|
||||
golden_trace=clone_trace(golden_trace),
|
||||
)
|
||||
|
||||
|
||||
def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=False) -> List[TestResult]:
|
||||
def run_tests(
|
||||
tests: List[Test], config: TestConfig, sequential=False, verbose=False
|
||||
) -> List[TestResult]:
|
||||
"""Invoke the given `Test`'s with the provided `TestConfig`."""
|
||||
num_processes = min(int(mp.cpu_count() * 0.8) + 1, len(tests))
|
||||
try:
|
||||
env_concurrency = int(os.getenv("TORCH_MLIR_TEST_CONCURRENCY", "0"))
|
||||
except ValueError as e:
|
||||
raise ValueError("Bad value for TORCH_MLIR_TEST_CONCURRENCY env var: "
|
||||
"Expected integer.") from e
|
||||
raise ValueError(
|
||||
"Bad value for TORCH_MLIR_TEST_CONCURRENCY env var: " "Expected integer."
|
||||
) from e
|
||||
if env_concurrency > 0:
|
||||
num_processes = min(num_processes, env_concurrency)
|
||||
|
||||
|
@ -374,10 +398,11 @@ def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=F
|
|||
TestResult(
|
||||
unique_name=aborted_test_name,
|
||||
compilation_error=None,
|
||||
runtime_error=
|
||||
"Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n",
|
||||
runtime_error="Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n",
|
||||
trace=None,
|
||||
golden_trace=None) for aborted_test_name in aborted_tests
|
||||
golden_trace=None,
|
||||
)
|
||||
for aborted_test_name in aborted_tests
|
||||
]
|
||||
results.extend(aborted_tests_results)
|
||||
results.sort(key=lambda result: result.unique_name)
|
||||
|
|
|
@ -13,12 +13,12 @@ from torch_mlir.ir import Module

# A type shared between the result of `LinalgOnTensorsBackend.compile` and the
# input to `LinalgOnTensorsBackend.load`. Each backend will likely have a
# different definition of this type.
CompiledArtifact = TypeVar('CompiledArtifact')
CompiledArtifact = TypeVar("CompiledArtifact")

# A wrapper around a backend-specific loaded program representation
# that uniformly translates the `x.method(...)` interface expected of
# Torch modules into appropriate lower-level operations.
Invoker = TypeVar('Invoker')
Invoker = TypeVar("Invoker")


class LinalgOnTensorsBackend(abc.ABC):

@ -27,6 +27,7 @@ class LinalgOnTensorsBackend(abc.ABC):

    Backends are recommended to raise meaningful exceptions in case of error,
    ideally with easy reproduction instructions.
    """

    @abc.abstractmethod
    def compile(self, module: Module) -> CompiledArtifact:
        """Compile the provided MLIR module into a compiled artifact.

@ -22,10 +22,20 @@ __all__ = [
|
|||
|
||||
def assert_arg_type_is_supported(ty):
|
||||
SUPPORTED = [
|
||||
np.float16, np.float32, np.float64, np.uint8, np.int8, np.int32,
|
||||
np.int64, np.bool_, np.complex64, np.complex128
|
||||
np.float16,
|
||||
np.float32,
|
||||
np.float64,
|
||||
np.uint8,
|
||||
np.int8,
|
||||
np.int32,
|
||||
np.int64,
|
||||
np.bool_,
|
||||
np.complex64,
|
||||
np.complex128,
|
||||
]
|
||||
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}"
|
||||
assert (
|
||||
ty in SUPPORTED
|
||||
), f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}"
|
||||
|
||||
|
||||
memref_type_to_np_dtype = {
|
||||
|
@ -37,14 +47,14 @@ memref_type_to_np_dtype = {
|
|||
"mri32": np.int32,
|
||||
"mri64": np.int64,
|
||||
"mrc32": np.complex64,
|
||||
"mrc64": np.complex128
|
||||
"mrc64": np.complex128,
|
||||
}
|
||||
elemental_type_to_ctype = {
|
||||
"i1": ctypes.c_bool,
|
||||
"i8": ctypes.c_byte,
|
||||
"i64": ctypes.c_int,
|
||||
"f32": ctypes.c_float,
|
||||
"f64": ctypes.c_double
|
||||
"f64": ctypes.c_double,
|
||||
}
|
||||
|
||||
CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_"
|
||||
|
@ -56,7 +66,7 @@ def get_return_funcs(module):
|
|||
with module.context:
|
||||
for func in module.body:
|
||||
# Returns strings of the form `"refbackend.."` so `"` is deleted.
|
||||
func_name = str(func.attributes["sym_name"]).replace('"', '')
|
||||
func_name = str(func.attributes["sym_name"]).replace('"', "")
|
||||
if func_name[:return_prefix_len] == CONSUME_RETURN_FUNC_PREFIX:
|
||||
return_funcs.append(func_name)
|
||||
|
||||
|
@ -79,7 +89,6 @@ def get_ctype_func(func_name):
|
|||
|
||||
|
||||
class RefBackendInvoker:
|
||||
|
||||
def __init__(self, module):
|
||||
self.ee = ExecutionEngine(module)
|
||||
self.result = None
|
||||
|
@ -90,27 +99,29 @@ class RefBackendInvoker:
|
|||
ctype_wrapper, ret_types = get_ctype_func(ret_func)
|
||||
|
||||
def consume_return_funcs(*args):
|
||||
self.result = tuple([
|
||||
arg if type in elemental_type_to_ctype
|
||||
self.result = tuple(
|
||||
[
|
||||
arg
|
||||
if type in elemental_type_to_ctype
|
||||
else unranked_memref_to_numpy(
|
||||
arg, memref_type_to_np_dtype[type])
|
||||
arg, memref_type_to_np_dtype[type]
|
||||
)
|
||||
for arg, type in zip(args, ret_types)
|
||||
])
|
||||
]
|
||||
)
|
||||
if len(self.result) == 1:
|
||||
self.result = self.result[0]
|
||||
|
||||
self.ee.register_runtime(ret_func,
|
||||
ctype_wrapper(consume_return_funcs))
|
||||
self.ee.register_runtime(ret_func, ctype_wrapper(consume_return_funcs))
|
||||
|
||||
def __getattr__(self, function_name: str):
|
||||
|
||||
def invoke(*args):
|
||||
ffi_args = []
|
||||
for arg in args:
|
||||
assert_arg_type_is_supported(arg.dtype)
|
||||
ffi_args.append(
|
||||
ctypes.pointer(
|
||||
ctypes.pointer(get_unranked_memref_descriptor(arg))))
|
||||
ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(arg)))
|
||||
)
|
||||
|
||||
self.ee.invoke(function_name, *ffi_args)
|
||||
result = self.result
|
||||
|
@ -121,7 +132,10 @@ class RefBackendInvoker:
|
|||
return invoke
|
||||
|
||||
|
||||
LOWERING_PIPELINE = "builtin.module(" + ",".join([
|
||||
LOWERING_PIPELINE = (
|
||||
"builtin.module("
|
||||
+ ",".join(
|
||||
[
|
||||
"func.func(refback-generalize-tensor-pad)",
|
||||
"func.func(refback-generalize-tensor-concat)",
|
||||
# Apply some optimizations. It would be great if MLIR had more useful
|
||||
|
@ -181,7 +195,10 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
|
|||
"convert-cf-to-llvm",
|
||||
"convert-complex-to-llvm",
|
||||
"reconcile-unrealized-casts",
|
||||
]) + ")"
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
|
||||
class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
|
||||
|
@ -204,7 +221,8 @@ class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
|
|||
passed to `load`.
|
||||
"""
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module, LOWERING_PIPELINE,
|
||||
imported_module,
|
||||
LOWERING_PIPELINE,
|
||||
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend",
|
||||
enable_ir_printing=False,
|
||||
)
|
||||
|
|
|
@ -13,12 +13,12 @@ from torch_mlir.ir import Module
|
|||
# A type shared between the result of `OnnxBackend.compile` and the
|
||||
# input to `OnnxBackend.load`. Each backend will likely have a
|
||||
# different definition of this type.
|
||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||
CompiledArtifact = TypeVar("CompiledArtifact")
|
||||
|
||||
# A wrapper around a backend-specific loaded program representation
|
||||
# that uniformly translates the `x.method(...)` interface expected of
|
||||
# Torch modules into appropriate lower-level operations.
|
||||
Invoker = TypeVar('Invoker')
|
||||
Invoker = TypeVar("Invoker")
|
||||
|
||||
|
||||
class OnnxBackend(abc.ABC):
|
||||
|
@ -27,6 +27,7 @@ class OnnxBackend(abc.ABC):
|
|||
Backends are recommended to raise meaningful exceptions in case of error,
|
||||
ideally with easy reproduction instructions.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def compile(self, module: Module) -> CompiledArtifact:
|
||||
"""Compile the provided MLIR module into a compiled artifact.
|
||||
|
|
|
@ -12,7 +12,9 @@ from torch_mlir.compiler_utils import (
|
|||
from torch_mlir.ir import *
|
||||
from torch_mlir.passmanager import *
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
||||
RefBackendLinalgOnTensorsBackend,
|
||||
)
|
||||
|
||||
from .abc import OnnxBackend
|
||||
|
||||
|
@ -22,9 +24,11 @@ __all__ = [
|
|||
|
||||
# The pipeline of func.func passes that lower the ONNX backend contract to the
|
||||
# Linalg-on-Tensors backend contract accepted by RefBackend.
|
||||
ONNX_TO_TORCH_FUNC_PIPELINE = ",".join([
|
||||
ONNX_TO_TORCH_FUNC_PIPELINE = ",".join(
|
||||
[
|
||||
"convert-torch-onnx-to-torch",
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class LinalgOnTensorsOnnxBackend(OnnxBackend):
|
||||
|
@ -50,9 +54,14 @@ class LinalgOnTensorsOnnxBackend(OnnxBackend):
|
|||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
|
||||
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract")
|
||||
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract",
|
||||
)
|
||||
|
||||
backend_legal_ops = ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int']
|
||||
backend_legal_ops = [
|
||||
"aten.flatten.using_ints",
|
||||
"aten.adaptive_avg_pool1d",
|
||||
"aten.unflatten.int",
|
||||
]
|
||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
|
@ -60,7 +69,9 @@ class LinalgOnTensorsOnnxBackend(OnnxBackend):
|
|||
"Lowering TorchFX IR -> Torch Backend IR",
|
||||
)
|
||||
|
||||
imported_module = lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
|
||||
imported_module = lower_mlir_module(
|
||||
False, OutputType.LINALG_ON_TENSORS, imported_module
|
||||
)
|
||||
compiled_module = self.refbackend.compile(imported_module)
|
||||
return compiled_module
|
||||
|
||||
|
|
|
@ -23,18 +23,23 @@ def register_test_case(module_factory: Callable[[], torch.nn.Module]):
|
|||
test's `program_factory` is taken from `module_factory`, and the
|
||||
`program_invoker` is the decorated function.
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
# Ensure that there are no duplicate names in the global test registry.
|
||||
if f.__name__ in _SEEN_UNIQUE_NAMES:
|
||||
raise Exception(
|
||||
f"Duplicate test name: '{f.__name__}'. Please make sure that the function wrapped by `register_test_case` has a unique name.")
|
||||
f"Duplicate test name: '{f.__name__}'. Please make sure that the function wrapped by `register_test_case` has a unique name."
|
||||
)
|
||||
_SEEN_UNIQUE_NAMES.add(f.__name__)
|
||||
|
||||
# Store the test in the registry.
|
||||
GLOBAL_TEST_REGISTRY.append(
|
||||
Test(unique_name=f.__name__,
|
||||
Test(
|
||||
unique_name=f.__name__,
|
||||
program_factory=module_factory,
|
||||
program_invoker=f))
|
||||
program_invoker=f,
|
||||
)
|
||||
)
|
||||
return f
|
||||
|
||||
return decorator
|
||||
|
|
|
@ -19,6 +19,7 @@ from .framework import TestResult, TraceItem
|
|||
|
||||
class TensorSummary:
|
||||
"""A summary of a tensor's contents."""
|
||||
|
||||
def __init__(self, tensor):
|
||||
self.min = torch.min(tensor.type(torch.float64))
|
||||
self.max = torch.max(tensor.type(torch.float64))
|
||||
|
@ -27,7 +28,7 @@ class TensorSummary:
|
|||
self.dtype = tensor.dtype
|
||||
|
||||
def __str__(self):
|
||||
return f'Tensor with shape={self.shape}, dtype={self.dtype}, min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}'
|
||||
return f"Tensor with shape={self.shape}, dtype={self.dtype}, min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}"
|
||||
|
||||
|
||||
class ErrorContext:
|
||||
|
@ -35,6 +36,7 @@ class ErrorContext:
|
|||
|
||||
This is useful for tracking errors across multiple levels of detail.
|
||||
"""
|
||||
|
||||
def __init__(self, contexts: List[str]):
|
||||
self.contexts = contexts
|
||||
|
||||
|
@ -47,17 +49,16 @@ class ErrorContext:
|
|||
return ErrorContext([])
|
||||
|
||||
def chain(self, additional_context: str):
|
||||
"""Chain an additional context onto the current error context.
|
||||
"""
|
||||
"""Chain an additional context onto the current error context."""
|
||||
return ErrorContext(self.contexts + [additional_context])
|
||||
|
||||
def format_error(self, s: str):
|
||||
return '@ ' + '\n@ '.join(self.contexts) + '\n' + 'ERROR: ' + s
|
||||
return "@ " + "\n@ ".join(self.contexts) + "\n" + "ERROR: " + s
|
||||
|
||||
|
||||
class ValueReport:
|
||||
"""A report for a single value processed by the program.
|
||||
"""
|
||||
"""A report for a single value processed by the program."""
|
||||
|
||||
def __init__(self, value, golden_value, context: ErrorContext):
|
||||
self.value = value
|
||||
self.golden_value = golden_value
|
||||
|
@ -70,7 +71,7 @@ class ValueReport:
|
|||
return len(self.failure_reasons) != 0
|
||||
|
||||
def error_str(self):
|
||||
return '\n'.join(self.failure_reasons)
|
||||
return "\n".join(self.failure_reasons)
|
||||
|
||||
def _evaluate_outcome(self):
|
||||
value, golden = self.value, self.golden_value
|
||||
|
@ -80,37 +81,37 @@ class ValueReport:
|
|||
golden = golden[0]
|
||||
if isinstance(golden, float):
|
||||
if not isinstance(value, float):
|
||||
return self._record_mismatch_type_failure('float', value)
|
||||
return self._record_mismatch_type_failure("float", value)
|
||||
if abs(value - golden) / golden > 1e-4:
|
||||
return self._record_failure(
|
||||
f'value ({value!r}) is not close to golden value ({golden!r})'
|
||||
f"value ({value!r}) is not close to golden value ({golden!r})"
|
||||
)
|
||||
return
|
||||
if isinstance(golden, int):
|
||||
if not isinstance(value, int):
|
||||
return self._record_mismatch_type_failure('int', value)
|
||||
return self._record_mismatch_type_failure("int", value)
|
||||
if value != golden:
|
||||
return self._record_failure(
|
||||
f'value ({value!r}) is not equal to golden value ({golden!r})'
|
||||
f"value ({value!r}) is not equal to golden value ({golden!r})"
|
||||
)
|
||||
return
|
||||
if isinstance(golden, str):
|
||||
if not isinstance(value, str):
|
||||
return self._record_mismatch_type_failure('str', value)
|
||||
return self._record_mismatch_type_failure("str", value)
|
||||
if value != golden:
|
||||
return self._record_failure(
|
||||
f'value ({value!r}) is not equal to golden value ({golden!r})'
|
||||
f"value ({value!r}) is not equal to golden value ({golden!r})"
|
||||
)
|
||||
return
|
||||
if isinstance(golden, tuple):
|
||||
if not isinstance(value, tuple):
|
||||
return self._record_mismatch_type_failure('tuple', value)
|
||||
return self._record_mismatch_type_failure("tuple", value)
|
||||
if len(value) != len(golden):
|
||||
return self._record_failure(
|
||||
f'value ({len(value)!r}) is not equal to golden value ({len(golden)!r})'
|
||||
f"value ({len(value)!r}) is not equal to golden value ({len(golden)!r})"
|
||||
)
|
||||
reports = [
|
||||
ValueReport(v, g, self.context.chain(f'tuple element {i}'))
|
||||
ValueReport(v, g, self.context.chain(f"tuple element {i}"))
|
||||
for i, (v, g) in enumerate(zip(value, golden))
|
||||
]
|
||||
for report in reports:
|
||||
|
@ -119,13 +120,13 @@ class ValueReport:
|
|||
return
|
||||
if isinstance(golden, list):
|
||||
if not isinstance(value, list):
|
||||
return self._record_mismatch_type_failure('list', value)
|
||||
return self._record_mismatch_type_failure("list", value)
|
||||
if len(value) != len(golden):
|
||||
return self._record_failure(
|
||||
f'value ({len(value)!r}) is not equal to golden value ({len(golden)!r})'
|
||||
f"value ({len(value)!r}) is not equal to golden value ({len(golden)!r})"
|
||||
)
|
||||
reports = [
|
||||
ValueReport(v, g, self.context.chain(f'list element {i}'))
|
||||
ValueReport(v, g, self.context.chain(f"list element {i}"))
|
||||
for i, (v, g) in enumerate(zip(value, golden))
|
||||
]
|
||||
for report in reports:
|
||||
|
@ -134,16 +135,19 @@ class ValueReport:
|
|||
return
|
||||
if isinstance(golden, dict):
|
||||
if not isinstance(value, dict):
|
||||
return self._record_mismatch_type_failure('dict', value)
|
||||
return self._record_mismatch_type_failure("dict", value)
|
||||
gkeys = list(sorted(golden.keys()))
|
||||
vkeys = list(sorted(value.keys()))
|
||||
if gkeys != vkeys:
|
||||
return self._record_failure(
|
||||
f'dict keys ({vkeys!r}) are not equal to golden keys ({gkeys!r})'
|
||||
f"dict keys ({vkeys!r}) are not equal to golden keys ({gkeys!r})"
|
||||
)
|
||||
reports = [
|
||||
ValueReport(value[k], golden[k],
|
||||
self.context.chain(f'dict element at key {k!r}'))
|
||||
ValueReport(
|
||||
value[k],
|
||||
golden[k],
|
||||
self.context.chain(f"dict element at key {k!r}"),
|
||||
)
|
||||
for k in gkeys
|
||||
]
|
||||
for report in reports:
|
||||
|
@ -152,40 +156,42 @@ class ValueReport:
|
|||
return
|
||||
if isinstance(golden, torch.Tensor):
|
||||
if not isinstance(value, torch.Tensor):
|
||||
return self._record_mismatch_type_failure('torch.Tensor', value)
|
||||
return self._record_mismatch_type_failure("torch.Tensor", value)
|
||||
|
||||
if value.shape != golden.shape:
|
||||
return self._record_failure(
|
||||
f'shape ({value.shape}) is not equal to golden shape ({golden.shape})'
|
||||
f"shape ({value.shape}) is not equal to golden shape ({golden.shape})"
|
||||
)
|
||||
if value.dtype != golden.dtype:
|
||||
return self._record_failure(
|
||||
f'dtype ({value.dtype}) is not equal to golden dtype ({golden.dtype})'
|
||||
f"dtype ({value.dtype}) is not equal to golden dtype ({golden.dtype})"
|
||||
)
|
||||
if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
|
||||
if not torch.allclose(
|
||||
value, golden, rtol=1e-03, atol=1e-07, equal_nan=True
|
||||
):
|
||||
return self._record_failure(
|
||||
f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})'
|
||||
f"value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})"
|
||||
)
|
||||
return
|
||||
return self._record_failure(
|
||||
f'unexpected golden value of type `{golden.__class__.__name__}`')
|
||||
f"unexpected golden value of type `{golden.__class__.__name__}`"
|
||||
)
|
||||
|
||||
def _record_failure(self, s: str):
|
||||
self.failure_reasons.append(self.context.format_error(s))
|
||||
|
||||
def _record_mismatch_type_failure(self, expected: str, actual: Any):
|
||||
self._record_failure(
|
||||
f'expected a value of type `{expected}` but got `{actual.__class__.__name__}`'
|
||||
f"expected a value of type `{expected}` but got `{actual.__class__.__name__}`"
|
||||
)
|
||||
|
||||
|
||||
|
||||
class TraceItemReport:
|
||||
"""A report for a single trace item."""
|
||||
|
||||
failure_reasons: List[str]
|
||||
|
||||
def __init__(self, item: TraceItem, golden_item: TraceItem,
|
||||
context: ErrorContext):
|
||||
def __init__(self, item: TraceItem, golden_item: TraceItem, context: ErrorContext):
|
||||
self.item = item
|
||||
self.golden_item = golden_item
|
||||
self.context = context
|
||||
|
@ -197,36 +203,43 @@ class TraceItemReport:
|
|||
return len(self.failure_reasons) != 0
|
||||
|
||||
def error_str(self):
|
||||
return '\n'.join(self.failure_reasons)
|
||||
return "\n".join(self.failure_reasons)
|
||||
|
||||
def _evaluate_outcome(self):
|
||||
if self.item.symbol != self.golden_item.symbol:
|
||||
self.failure_reasons.append(
|
||||
self.context.format_error(
|
||||
f'not invoking the same symbol: got "{self.item.symbol}", expected "{self.golden_item.symbol}"'
|
||||
))
|
||||
)
|
||||
)
|
||||
if len(self.item.inputs) != len(self.golden_item.inputs):
|
||||
self.failure_reasons.append(
|
||||
self.context.format_error(
|
||||
f'different number of inputs: got "{len(self.item.inputs)}", expected "{len(self.golden_item.inputs)}"'
|
||||
))
|
||||
)
|
||||
)
|
||||
for i, (input, golden_input) in enumerate(
|
||||
zip(self.item.inputs, self.golden_item.inputs)):
|
||||
zip(self.item.inputs, self.golden_item.inputs)
|
||||
):
|
||||
value_report = ValueReport(
|
||||
input, golden_input,
|
||||
self.context.chain(
|
||||
f'input #{i} of call to "{self.item.symbol}"'))
|
||||
input,
|
||||
golden_input,
|
||||
self.context.chain(f'input #{i} of call to "{self.item.symbol}"'),
|
||||
)
|
||||
if value_report.failed:
|
||||
self.failure_reasons.append(value_report.error_str())
|
||||
value_report = ValueReport(
|
||||
self.item.output, self.golden_item.output,
|
||||
self.context.chain(f'output of call to "{self.item.symbol}"'))
|
||||
self.item.output,
|
||||
self.golden_item.output,
|
||||
self.context.chain(f'output of call to "{self.item.symbol}"'),
|
||||
)
|
||||
if value_report.failed:
|
||||
self.failure_reasons.append(value_report.error_str())
|
||||
|
||||
|
||||
class SingleTestReport:
|
||||
"""A report for a single test."""
|
||||
|
||||
item_reports: Optional[List[TraceItemReport]]
|
||||
|
||||
def __init__(self, result: TestResult, context: ErrorContext):
|
||||
|
@ -236,12 +249,15 @@ class SingleTestReport:
|
|||
if result.compilation_error is None and result.runtime_error is None:
|
||||
self.item_reports = []
|
||||
for i, (item, golden_item) in enumerate(
|
||||
zip(result.trace, result.golden_trace)):
|
||||
zip(result.trace, result.golden_trace)
|
||||
):
|
||||
self.item_reports.append(
|
||||
TraceItemReport(
|
||||
item, golden_item,
|
||||
context.chain(
|
||||
f'trace item #{i} - call to "{item.symbol}"')))
|
||||
item,
|
||||
golden_item,
|
||||
context.chain(f'trace item #{i} - call to "{item.symbol}"'),
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def failed(self):
|
||||
|
@ -256,19 +272,21 @@ class SingleTestReport:
|
|||
f = io.StringIO()
|
||||
p = lambda *x: print(*x, file=f)
|
||||
if self.result.compilation_error is not None:
|
||||
return 'Compilation error: ' + self.result.compilation_error
|
||||
return "Compilation error: " + self.result.compilation_error
|
||||
elif self.result.runtime_error is not None:
|
||||
return 'Runtime error: ' + self.result.runtime_error
|
||||
return "Runtime error: " + self.result.runtime_error
|
||||
for report in self.item_reports:
|
||||
if report.failed:
|
||||
p(report.error_str())
|
||||
return f.getvalue()
|
||||
|
||||
|
||||
def report_results(results: List[TestResult],
|
||||
def report_results(
|
||||
results: List[TestResult],
|
||||
expected_failures: Set[str],
|
||||
verbose: bool = False,
|
||||
config: str = ""):
|
||||
config: str = "",
|
||||
):
|
||||
"""Print a basic error report summarizing various TestResult's.
|
||||
|
||||
This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's
|
||||
|
@ -293,49 +311,50 @@ def report_results(results: List[TestResult],
|
|||
if expected_failure:
|
||||
if report.failed:
|
||||
print(f'XFAIL - "{result.unique_name}"')
|
||||
results_by_outcome['XFAIL'].append((result, report))
|
||||
results_by_outcome["XFAIL"].append((result, report))
|
||||
else:
|
||||
print(f'XPASS - "{result.unique_name}"')
|
||||
results_by_outcome['XPASS'].append((result, report))
|
||||
results_by_outcome["XPASS"].append((result, report))
|
||||
else:
|
||||
if not report.failed:
|
||||
print(f'PASS - "{result.unique_name}"')
|
||||
results_by_outcome['PASS'].append((result, report))
|
||||
results_by_outcome["PASS"].append((result, report))
|
||||
else:
|
||||
print(f'FAIL - "{result.unique_name}"')
|
||||
results_by_outcome['FAIL'].append((result, report))
|
||||
results_by_outcome["FAIL"].append((result, report))
|
||||
|
||||
OUTCOME_MEANINGS = collections.OrderedDict()
|
||||
OUTCOME_MEANINGS['PASS'] = 'Passed'
|
||||
OUTCOME_MEANINGS['FAIL'] = 'Failed'
|
||||
OUTCOME_MEANINGS['XFAIL'] = 'Expectedly Failed'
|
||||
OUTCOME_MEANINGS['XPASS'] = 'Unexpectedly Passed'
|
||||
OUTCOME_MEANINGS["PASS"] = "Passed"
|
||||
OUTCOME_MEANINGS["FAIL"] = "Failed"
|
||||
OUTCOME_MEANINGS["XFAIL"] = "Expectedly Failed"
|
||||
OUTCOME_MEANINGS["XPASS"] = "Unexpectedly Passed"
|
||||
|
||||
had_unexpected_results = len(results_by_outcome['FAIL']) != 0 or len(
|
||||
results_by_outcome['XPASS']) != 0
|
||||
had_unexpected_results = (
|
||||
len(results_by_outcome["FAIL"]) != 0 or len(results_by_outcome["XPASS"]) != 0
|
||||
)
|
||||
|
||||
if had_unexpected_results:
|
||||
print(f'\nUnexpected outcome summary: ({config})')
|
||||
print(f"\nUnexpected outcome summary: ({config})")
|
||||
|
||||
# For FAIL and XPASS (unexpected outcomes), print a summary.
|
||||
for outcome, results in results_by_outcome.items():
|
||||
# PASS and XFAIL are "good"/"successful" outcomes.
|
||||
if outcome == 'PASS' or outcome == 'XFAIL':
|
||||
if outcome == "PASS" or outcome == "XFAIL":
|
||||
continue
|
||||
# If there is nothing to report, be quiet.
|
||||
if len(results) == 0:
|
||||
continue
|
||||
print(f'\n****** {OUTCOME_MEANINGS[outcome]} tests - {len(results)} tests')
|
||||
print(f"\n****** {OUTCOME_MEANINGS[outcome]} tests - {len(results)} tests")
|
||||
for result, report in results:
|
||||
print(f' {outcome} - "{result.unique_name}"')
|
||||
# If the test failed, print the error message.
|
||||
if outcome == 'FAIL' and verbose:
|
||||
print(textwrap.indent(report.error_str(), ' ' * 8))
|
||||
if outcome == "FAIL" and verbose:
|
||||
print(textwrap.indent(report.error_str(), " " * 8))
|
||||
|
||||
# Print a summary for easy scanning.
|
||||
print('\nSummary:')
|
||||
print("\nSummary:")
|
||||
|
||||
for key in ['PASS', 'FAIL', 'XFAIL', 'XPASS']:
|
||||
for key in ["PASS", "FAIL", "XFAIL", "XPASS"]:
|
||||
if results_by_outcome[key]:
|
||||
print(f' {OUTCOME_MEANINGS[key]}: {len(results_by_outcome[key])}')
|
||||
print(f" {OUTCOME_MEANINGS[key]}: {len(results_by_outcome[key])}")
|
||||
return had_unexpected_results
|
||||
|
|
|
@ -7,7 +7,9 @@ from torch_mlir.ir import *

from torch_mlir.passmanager import *
from torch_mlir.compiler_utils import run_pipeline_with_repro_report

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
RefBackendLinalgOnTensorsBackend,
)

from .abc import StablehloBackend

@ -17,11 +19,13 @@ __all__ = [

# The pipeline of func.func passes that lower the STABLEHLO backend contract to the
# Linalg-on-Tensors backend contract accepted by RefBackend.
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join(
[
"func.func(stablehlo-aggressive-simplification)",
"stablehlo-legalize-to-linalg",
"canonicalize"
])
"canonicalize",
]
)


class LinalgOnTensorsStablehloBackend(StablehloBackend):

@ -47,7 +51,8 @@ class LinalgOnTensorsStablehloBackend(StablehloBackend):

run_pipeline_with_repro_report(
imported_module,
f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})",
"Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract")
"Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract",
)

return self.refbackend.compile(imported_module)
|
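The pass names in the hunk above are joined into a single textual pipeline and then wrapped in builtin.module(...) before being handed to run_pipeline_with_repro_report. A minimal sketch of just the string construction (pure Python, no torch-mlir import required; the pass names are taken verbatim from the diff):

# Same pass list as in the diff above; joining with "," yields the
# func.func-level pipeline accepted by the pass manager.
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join(
    [
        "func.func(stablehlo-aggressive-simplification)",
        "stablehlo-legalize-to-linalg",
        "canonicalize",
    ]
)

# The backend wraps the pipeline in builtin.module(...) when invoking it.
pipeline = f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})"
print(pipeline)
# builtin.module(func.func(stablehlo-aggressive-simplification),stablehlo-legalize-to-linalg,canonicalize)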
|
@ -19,6 +19,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {

"ElementwiseToDtypeI64ToUI8Module_basic",
}


def register_all_tests():
"""Registers all the built-in E2E tests that Torch-MLIR provides."""
# Side-effecting import statements.
|
|
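Most of the test-suite hunks that follow repeat one mechanical change: black moves the decorator's single bracketed argument onto its own lines and adds a trailing comma where one was missing. A before/after sketch of that pattern is below; the no-op annotate_args stand-in and ExampleModule are hypothetical, added only so the snippet runs without torch-mlir installed.

import torch

# Before this commit the decorator call was written as:
#   @annotate_args([
#       None,
#       ([-1, -1], torch.float32, True),
#   ])
# black reformats it to the exploded form used in ExampleModule below.


def annotate_args(args):
    # Hypothetical no-op stand-in for torch_mlir_e2e_test.annotations.annotate_args.
    def wrap(fn):
        return fn

    return wrap


class ExampleModule(torch.nn.Module):
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        return x + 1.0


print(ExampleModule()(torch.zeros(2, 2)))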
@ -17,13 +17,15 @@ class ArangeIntModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeIntModule())
|
||||
def ArangeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -34,13 +36,15 @@ class ArangeFloatModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(5.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeFloatModule())
|
||||
def ArangeFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -51,31 +55,37 @@ class ArangeZeroElementOutputModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeZeroElementOutputModule())
|
||||
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ArangeStartIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(0, 5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartIntModule())
|
||||
def ArangeStartIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -86,13 +96,15 @@ class ArangeStartFloatModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(0.0, 5.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartFloatModule())
|
||||
def ArangeStartFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -103,13 +115,15 @@ class ArangeNegativeStartIntModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(-10, 5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeNegativeStartIntModule())
|
||||
def ArangeNegativeStartIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -120,31 +134,37 @@ class ArangeNegativeStartFloatModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(-1.4, 5.7)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeNegativeStartFloatModule())
|
||||
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ArangeStartStepIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(0, 5, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartStepIntModule())
|
||||
def ArangeStartStepIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -155,13 +175,15 @@ class ArangeStartStepFloatModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(-1, 5, 1.3)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartStepFloatModule())
|
||||
def ArangeStartStepFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -172,13 +194,15 @@ class ArangeStartNegativeStepIntModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(10, 1, -2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartNegativeStepIntModule())
|
||||
def ArangeStartNegativeStepIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -189,31 +213,37 @@ class ArangeStartNegativeStepFloatModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(-1, -15, -3.4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartNegativeStepFloatModule())
|
||||
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ArangeDtypeFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(-1, 15, dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeDtypeFloatModule())
|
||||
def ArangeDtypeFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -224,110 +254,137 @@ class ArangeDtypeIntModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(0.2, 5.0, dtype=torch.int64)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeDtypeIntModule())
|
||||
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ArangeFalsePinMemoryModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.arange(5.0, dtype=torch.int64, pin_memory=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
|
||||
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ArangeStartOutModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([12], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.arange(start=0, end=12, out=x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartOutModule())
|
||||
def ArangeStartOutModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.zeros(12).to(torch.int64))
|
||||
|
||||
|
||||
class ArangeStartOutViewModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 4], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.arange(start=1, end=13, out=x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartOutViewModule())
|
||||
def ArangeStartOutViewModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.zeros(3, 4).to(torch.int64))
|
||||
|
||||
|
||||
class ArangeStartOutDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([12], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.arange(start=1.1, end=13.1, out=x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
|
||||
def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.zeros(12).to(torch.int64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class LinspaceModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.linspace(-10.1, 10.1, 10)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: LinspaceModule())
|
||||
def LinspaceModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class LinspaceDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.linspace(-10.1, 10.1, 10, dtype=torch.int64)
|
||||
|
||||
|
@ -336,47 +393,59 @@ class LinspaceDtypeModule(torch.nn.Module):
|
|||
def LinspaceDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class LinspaceEmptyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.linspace(-10.1, 10.1, 0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: LinspaceEmptyModule())
|
||||
def LinspaceEmptyModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class LinspaceOneSizeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.linspace(-10.1, 10.1, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: LinspaceOneSizeModule())
|
||||
def LinspaceOneSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class LinspaceTwoSizeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.linspace(-10.1, 10.1, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
|
||||
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -387,12 +456,16 @@ class PrimsIotaModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.ops.prims.iota(77, start=0, step=1, dtype=torch.int64, device='cpu',
|
||||
requires_grad=False)
|
||||
return torch.ops.prims.iota(
|
||||
77, start=0, step=1, dtype=torch.int64, device="cpu", requires_grad=False
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: PrimsIotaModule())
|
||||
def PrimsIotaModule_basic(module, tu: TestUtils):
|
||||
|
|
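The Arange*/Linspace*/PrimsIota modules above exercise these ops end to end through the compiler; the underlying eager-mode calls they wrap are plain PyTorch and can be tried directly. The sketch below only mirrors the argument combinations that appear in the tests (smaller sizes are used for brevity).

import torch

# Variants covered by the Arange* tests above.
print(torch.arange(5))                               # end only
print(torch.arange(0.0, 5.0))                        # float start/end
print(torch.arange(10, 1, -2))                       # negative step
print(torch.arange(-1, 15, dtype=torch.float32))     # explicit dtype
out = torch.zeros(12, dtype=torch.int64)
print(torch.arange(start=0, end=12, out=out))        # writes into `out`

# Variants covered by the Linspace* tests.
print(torch.linspace(-10.1, 10.1, 10))
print(torch.linspace(-10.1, 10.1, 1))                # single element

# prims.iota as called by PrimsIotaModule (length shortened here).
print(
    torch.ops.prims.iota(
        7, start=0, step=1, dtype=torch.int64, device="cpu", requires_grad=False
    )
)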
|
@ -13,21 +13,21 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class SoftmaxBackwardModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, output):
|
||||
return torch.ops.aten._softmax_backward_data(grad_output,
|
||||
output,
|
||||
dim=1,
|
||||
input_dtype=6)
|
||||
return torch.ops.aten._softmax_backward_data(
|
||||
grad_output, output, dim=1, input_dtype=6
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SoftmaxBackwardModule())
|
||||
|
@ -37,16 +37,17 @@ def SoftmaxBackwardModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
class TanhBackwardModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_out, output):
|
||||
return torch.ops.aten.tanh_backward(grad_out, output)
|
||||
|
||||
|
@ -58,40 +59,46 @@ def TanhBackward_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class HardtanhBackwardModule(torch.nn.Module):
|
||||
|
||||
class HardtanhBackwardModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_out, input):
|
||||
return torch.ops.aten.hardtanh_backward(grad_out, input, min_val=0.2, max_val=0.5)
|
||||
return torch.ops.aten.hardtanh_backward(
|
||||
grad_out, input, min_val=0.2, max_val=0.5
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: HardtanhBackwardModule())
|
||||
def HardtanhBackward_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 20), tu.rand(10, 20))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ConvolutionBackwardModule2D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_out, input_vec, weight):
|
||||
return torch.ops.aten.convolution_backward(
|
||||
grad_out,
|
||||
|
@ -104,27 +111,29 @@ class ConvolutionBackwardModule2D(torch.nn.Module):
|
|||
transposed=False,
|
||||
output_padding=[0],
|
||||
groups=1,
|
||||
output_mask=[True, True, True])
|
||||
output_mask=[True, True, True],
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2D())
|
||||
def ConvolutionBackwardModule2D_basic(module, tu: TestUtils):
|
||||
with torch.backends.mkldnn.flags(enabled=False):
|
||||
module.forward(tu.rand(2, 2, 5, 5), tu.rand(2, 2, 6, 6),
|
||||
tu.rand(2, 2, 2, 2))
|
||||
module.forward(tu.rand(2, 2, 5, 5), tu.rand(2, 2, 6, 6), tu.rand(2, 2, 2, 2))
|
||||
|
||||
|
||||
class ConvolutionBackwardModule2DStatic(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 4, 64, 64], torch.float32, True),
|
||||
([1, 320, 64, 64], torch.float32, True),
|
||||
([4, 320, 3, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_out, input_vec, weight):
|
||||
return torch.ops.aten.convolution_backward(
|
||||
grad_out,
|
||||
|
@ -137,28 +146,31 @@ class ConvolutionBackwardModule2DStatic(torch.nn.Module):
|
|||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
output_mask=[True, True, True])
|
||||
output_mask=[True, True, True],
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStatic())
|
||||
def ConvolutionBackwardModule2DStatic_basic(module, tu: TestUtils):
|
||||
with torch.backends.mkldnn.flags(enabled=False):
|
||||
module.forward(tu.rand(1, 4, 64, 64), tu.rand(1, 320, 64, 64),
|
||||
tu.rand(4, 320, 3, 3))
|
||||
module.forward(
|
||||
tu.rand(1, 4, 64, 64), tu.rand(1, 320, 64, 64), tu.rand(4, 320, 3, 3)
|
||||
)
|
||||
|
||||
|
||||
class ConvolutionBackwardModule2DPadded(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_out, input_vec, weight):
|
||||
return torch.ops.aten.convolution_backward(
|
||||
grad_out,
|
||||
|
@ -171,28 +183,29 @@ class ConvolutionBackwardModule2DPadded(torch.nn.Module):
|
|||
transposed=False,
|
||||
output_padding=[0],
|
||||
groups=1,
|
||||
output_mask=[True, True, True])
|
||||
output_mask=[True, True, True],
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DPadded())
|
||||
def ConvolutionBackwardModule2DPadded_basic(module, tu: TestUtils):
|
||||
with torch.backends.mkldnn.flags(enabled=False):
|
||||
module.forward(tu.rand(2, 2, 8, 8), tu.rand(2, 2, 6, 6),
|
||||
tu.rand(2, 2, 3, 3))
|
||||
module.forward(tu.rand(2, 2, 8, 8), tu.rand(2, 2, 6, 6), tu.rand(2, 2, 3, 3))
|
||||
|
||||
|
||||
class ConvolutionBackwardModule2DStrided(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 2, 4, 4], torch.float32, True),
|
||||
([1, 2, 8, 8], torch.float32, True),
|
||||
([2, 2, 3, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_out, input_vec, weight):
|
||||
return torch.ops.aten.convolution_backward(
|
||||
grad_out,
|
||||
|
@ -205,30 +218,31 @@ class ConvolutionBackwardModule2DStrided(torch.nn.Module):
|
|||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
output_mask=[True, True, True])
|
||||
output_mask=[True, True, True],
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStrided())
|
||||
def ConvolutionBackwardModule2DStrided_basic(module, tu: TestUtils):
|
||||
with torch.backends.mkldnn.flags(enabled=False):
|
||||
module.forward(tu.rand(1, 2, 4, 4), tu.rand(1, 2, 8, 8),
|
||||
tu.rand(2, 2, 3, 3))
|
||||
module.forward(tu.rand(1, 2, 4, 4), tu.rand(1, 2, 8, 8), tu.rand(2, 2, 3, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class GeluBackwardModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.gelu_backward(grad, input)
|
||||
|
||||
|
@ -239,21 +253,21 @@ def GeluBackwardModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class LogSoftmaxBackwardModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, output):
|
||||
return torch.ops.aten._log_softmax_backward_data(grad_output,
|
||||
output,
|
||||
dim=1,
|
||||
input_dtype=6)
|
||||
return torch.ops.aten._log_softmax_backward_data(
|
||||
grad_output, output, dim=1, input_dtype=6
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: LogSoftmaxBackwardModule())
|
||||
|
@ -265,18 +279,21 @@ def LogSoftmaxBackwardModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class LeakyReluBackwardModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.leaky_relu_backward(grad, input, negative_slope=0.1, self_is_result=False)
|
||||
return torch.ops.aten.leaky_relu_backward(
|
||||
grad, input, negative_slope=0.1, self_is_result=False
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: LeakyReluBackwardModule())
|
||||
|
@ -285,18 +302,21 @@ def LeakyReluBackwardModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class LeakyReluBackwardStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 4, 5], torch.float32, True),
|
||||
([3, 4, 5], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.leaky_relu_backward(grad, input, negative_slope=0.1, self_is_result=False)
|
||||
return torch.ops.aten.leaky_relu_backward(
|
||||
grad, input, negative_slope=0.1, self_is_result=False
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule())
|
||||
|
|
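The *Backward modules above call the raw aten backward kernels rather than autograd; those kernels are reachable from eager mode via torch.ops.aten, as in the sketch below. Tensor sizes are arbitrary, the keyword arguments mirror the test calls, and exact schemas can differ slightly across torch versions.

import torch

grad = torch.rand(3, 4)
out = torch.rand(3, 4)

# Gradients of tanh / gelu / leaky_relu / hardtanh, as exercised above.
print(torch.ops.aten.tanh_backward(grad, out).shape)
print(torch.ops.aten.gelu_backward(grad, out).shape)
print(
    torch.ops.aten.leaky_relu_backward(
        grad, out, negative_slope=0.1, self_is_result=False
    ).shape
)
print(torch.ops.aten.hardtanh_backward(grad, out, min_val=0.2, max_val=0.5).shape)

# _softmax_backward_data takes the softmax dim and the input dtype
# (6 is the ScalarType code for float32, as in the tests above).
print(torch.ops.aten._softmax_backward_data(grad, out, dim=1, input_dtype=6).shape)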
File diff suppressed because it is too large
|
@ -11,15 +11,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToIntZeroRank(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return int(x)
|
||||
|
||||
|
@ -28,17 +31,21 @@ class TensorToIntZeroRank(torch.nn.Module):
|
|||
def TensorToIntZeroRank_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToInt(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return int(x)
|
||||
|
||||
|
@ -47,17 +54,21 @@ class TensorToInt(torch.nn.Module):
|
|||
def TensorToInt_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(1, 1, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToFloatZeroRank(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return float(x)
|
||||
|
||||
|
@ -66,17 +77,21 @@ class TensorToFloatZeroRank(torch.nn.Module):
|
|||
def TensorToFloatZeroRank_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand().to(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToFloat(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return float(x)
|
||||
|
||||
|
@ -85,17 +100,21 @@ class TensorToFloat(torch.nn.Module):
|
|||
def TensorToFloat_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1).to(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToBoolZeroRank(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.bool, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return bool(x)
|
||||
|
||||
|
@ -104,17 +123,21 @@ class TensorToBoolZeroRank(torch.nn.Module):
|
|||
def TensorToBoolZeroRank_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor(1, dtype=torch.bool))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToBool(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return bool(x)
|
||||
|
||||
|
|
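The TensorTo* modules above rely on Python's int()/float()/bool() conversion of single-element tensors, which is what the importer has to model. In eager PyTorch the behaviour is simply:

import torch

print(int(torch.tensor(7)))                            # zero-rank int tensor -> Python int
print(float(torch.tensor(2.5, dtype=torch.float64)))   # zero-rank float tensor -> Python float
print(bool(torch.tensor(True)))                        # zero-rank bool tensor -> Python bool

# A 1x1 tensor also converts, matching the non-zero-rank test cases above.
print(int(torch.tensor([[3]])))
print(float(torch.tensor([[2.5]], dtype=torch.float64)))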
File diff suppressed because it is too large
|
@ -13,15 +13,13 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TorchPrimLoopForLikeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True)
|
||||
])
|
||||
@annotate_args([None, ([-1, -1], torch.int64, True)])
|
||||
def forward(self, x):
|
||||
x_val = x.size(0)
|
||||
sum = 0
|
||||
|
@ -34,20 +32,18 @@ class TorchPrimLoopForLikeModule(torch.nn.Module):
|
|||
def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(6, 8, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class TorchPrimLoopWhileLikeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True)
|
||||
])
|
||||
@annotate_args([None, ([-1, -1], torch.int64, True)])
|
||||
def forward(self, x):
|
||||
x_val = x.size(0)
|
||||
sum = 0
|
||||
while(x_val > sum):
|
||||
while x_val > sum:
|
||||
sum += 1
|
||||
return sum
|
||||
|
||||
|
@ -59,20 +55,24 @@ def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TorchPrimLoopForLikeTensorArgModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([7,9], torch.float32, True),
|
||||
])
|
||||
([7, 9], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for i in range(50):
|
||||
x = x + i
|
||||
return x
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TorchPrimLoopForLikeTensorArgModule())
|
||||
def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils):
|
||||
x_test = torch.zeros([7, 9]).float()
|
||||
|
|
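The TorchPrimLoop* tests above check that scripted for/while loops survive import as torch.prim.Loop. A quick way to see the loop structure outside of torch-mlir is to script an equivalent function and print its TorchScript graph; the function below is illustrative only, not part of the test harness.

import torch


@torch.jit.script
def count_up(x: torch.Tensor) -> int:
    # Mirrors TorchPrimLoopWhileLikeModule: loop until `total` reaches x.size(0).
    x_val = x.size(0)
    total = 0
    while x_val > total:
        total += 1
    return total


print(count_up(torch.zeros(6, 8)))
print(count_up.graph)  # the while loop appears as a prim::Loop node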
File diff suppressed because it is too large
|
@ -17,15 +17,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export

# the PyTorch op registry permanently.
import torch_mlir._torch_mlir_custom_op_example


class CustomOpExampleModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
])
]
)
def forward(self, a):
return torch.ops._torch_mlir_custom_op_example.identity(a)
|
@ -10,16 +10,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class DiagonalModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a)
|
||||
|
||||
|
@ -28,96 +30,122 @@ class DiagonalModule(torch.nn.Module):
|
|||
def DiagonalModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 3))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalModule())
|
||||
def DiagonalModule_nonsquare(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class DiagonalTransposedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a, dim1=1, dim2=0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalTransposedModule())
|
||||
def DiagonalModule_transposed(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class DiagonalWithDimsModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a, dim1=0, dim2=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalWithDimsModule())
|
||||
def DiagonalModule_with_dims(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class DiagonalWithNegativeDimsModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a, dim1=-2, dim2=-1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalWithNegativeDimsModule())
|
||||
def DiagonalModule_with_negative_dims(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class DiagonalWithOffsetModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a, offset=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalWithOffsetModule())
|
||||
def DiagonalModule_with_offset(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class DiagonalWithDimsOffsetModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a, dim1=0, dim2=1, offset=-1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalWithDimsOffsetModule())
|
||||
def DiagonalModule_with_dims_and_offset(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
|
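torch.ops.aten.diagonal used in the modules above is the same op as torch.diagonal; the dim1/dim2/offset arguments behave as in the sketch below, with shapes chosen to match the test inputs.

import torch

a = torch.arange(12.0).reshape(3, 4)
print(torch.diagonal(a))                  # main diagonal, shape (3,)
print(torch.diagonal(a, offset=1))        # diagonal shifted one column to the right
print(torch.diagonal(a, dim1=1, dim2=0))  # same elements, dims swapped

b = torch.rand(3, 4, 5)
# The extracted diagonal is appended as the last dimension of the result.
print(torch.diagonal(b, dim1=0, dim2=1).shape)    # torch.Size([5, 3])
print(torch.diagonal(b, dim1=-2, dim2=-1).shape)  # torch.Size([3, 4])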
File diff suppressed because it is too large
|
@ -11,15 +11,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGtFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.gt(x, 0.6)
|
||||
|
||||
|
@ -28,17 +31,21 @@ class ElementwiseGtFloatScalarModule(torch.nn.Module):
|
|||
def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGtIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.gt(x, 10)
|
||||
|
||||
|
@ -47,17 +54,21 @@ class ElementwiseGtIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.gt(x, 7)
|
||||
|
||||
|
@ -66,17 +77,21 @@ class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
|
|||
def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 0.6)
|
||||
|
||||
|
@ -85,17 +100,21 @@ class ElementwiseGeFloatScalarModule(torch.nn.Module):
|
|||
def ElementwiseGeFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 10)
|
||||
|
||||
|
@ -104,17 +123,21 @@ class ElementwiseGeIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 7)
|
||||
|
||||
|
@ -123,17 +146,21 @@ class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 7)
|
||||
|
||||
|
@ -142,18 +169,22 @@ class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeFloatTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.ge(x, y)
|
||||
|
||||
|
@ -162,20 +193,25 @@ class ElementwiseGeFloatTensorModule(torch.nn.Module):
|
|||
def ElementwiseGeFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32))
|
||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeIntTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.ge(x, y)
|
||||
|
||||
|
@ -184,18 +220,22 @@ class ElementwiseGeIntTensorModule(torch.nn.Module):
|
|||
def ElementwiseGeIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGtFloatTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.gt(x, y)
|
||||
|
||||
|
@ -204,20 +244,25 @@ class ElementwiseGtFloatTensorModule(torch.nn.Module):
|
|||
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32))
|
||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGtIntTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.gt(x, y)
|
||||
|
||||
|
@ -226,17 +271,21 @@ class ElementwiseGtIntTensorModule(torch.nn.Module):
|
|||
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLtFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 0.6)
|
||||
|
||||
|
@ -245,17 +294,21 @@ class ElementwiseLtFloatScalarModule(torch.nn.Module):
|
|||
def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLtIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 0)
|
||||
|
||||
|
@ -264,37 +317,44 @@ class ElementwiseLtIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 2)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
|
||||
@register_test_case(module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
|
||||
def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.le(x, 0.6)
|
||||
|
||||
|
@ -303,17 +363,21 @@ class ElementwiseLeFloatScalarModule(torch.nn.Module):
|
|||
def ElementwiseLeFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.le(x, 10)
|
||||
|
||||
|
@ -322,17 +386,21 @@ class ElementwiseLeIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.le(x, 7)
|
||||
|
||||
|
@ -341,17 +409,21 @@ class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.le(x, 7)
|
||||
|
||||
|
@ -360,18 +432,22 @@ class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeFloatTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.le(x, y)
|
||||
|
||||
|
@ -380,18 +456,22 @@ class ElementwiseLeFloatTensorModule(torch.nn.Module):
|
|||
def ElementwiseLeFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeFloatTensorNanModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.le(x, y)
|
||||
|
||||
|
@ -400,20 +480,25 @@ class ElementwiseLeFloatTensorNanModule(torch.nn.Module):
|
|||
def ElementwiseLeFloatTensorNanModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32))
|
||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeIntTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.le(x, y)
|
||||
|
||||
|
@ -422,18 +507,22 @@ class ElementwiseLeIntTensorModule(torch.nn.Module):
|
|||
def ElementwiseLeIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLtFloatTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.lt(x, y)
|
||||
|
||||
|
@ -442,20 +531,25 @@ class ElementwiseLtFloatTensorModule(torch.nn.Module):
|
|||
def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32))
|
||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLtIntTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.lt(x, y)
|
||||
|
||||
|
@ -464,17 +558,21 @@ class ElementwiseLtIntTensorModule(torch.nn.Module):
|
|||
def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseEqFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 6.0)
|
||||
|
||||
|
@ -482,19 +580,24 @@ class ElementwiseEqFloatScalarModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
|
||||
def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32))
|
||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32)
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseEqIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 2)
|
||||
|
||||
|
@ -503,17 +606,21 @@ class ElementwiseEqIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(5, 8, low=2, high=4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseEqBoolScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 1)
|
||||
|
||||
|
@ -525,36 +632,42 @@ def ElementwiseEqBoolScalarModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int32, True),
        ]
    )
    def forward(self, x):
        return torch.eq(x, 2)


@register_test_case(module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(5, 8, low=2, high=4).to(torch.int32))


# ==============================================================================


class ElementwiseEqFloatTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
            ([-1], torch.float32, True),
        ]
    )
    def forward(self, x, y):
        return torch.eq(x, y)


@@ -563,20 +676,25 @@ class ElementwiseEqFloatTensorModule(torch.nn.Module):
def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
    module.forward(
        torch.tensor([[1.0, 2.2, 6.0], [torch.nan, 2.0, 3.1]]).to(torch.float32),
        torch.tensor([1.0, 2.4, 6.0]).to(torch.float32),
    )


# ==============================================================================


class ElementwiseEqIntTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int64, True),
            ([-1], torch.int64, True),
        ]
    )
    def forward(self, x, y):
        return torch.eq(x, y)


@@ -585,17 +703,21 @@ class ElementwiseEqIntTensorModule(torch.nn.Module):
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))


# ==============================================================================


class ElementwiseNeFloatScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ne(x, 2.0)


@@ -603,19 +725,24 @@ class ElementwiseNeFloatScalarModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule())
def ElementwiseNeFloatScalarModule_basic(module, tu: TestUtils):
    module.forward(
        torch.tensor([[1.0, 2.2, 2.0], [torch.nan, 2.0, 3.1]]).to(torch.float32)
    )


# ==============================================================================


class ElementwiseNeIntScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int64, True),
        ]
    )
    def forward(self, x):
        return torch.ne(x, 3)


@@ -624,18 +751,22 @@ class ElementwiseNeIntScalarModule(torch.nn.Module):
def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(8, 5, low=2, high=4))


# ==============================================================================


class ElementwiseNeFloatTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x, y):
        return torch.ne(x, y)


@@ -644,20 +775,25 @@ class ElementwiseNeFloatTensorModule(torch.nn.Module):
def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils):
    module.forward(
        torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
        torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32),
    )


# ==============================================================================


class ElementwiseNeIntTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int64, True),
            ([-1], torch.int64, True),
        ]
    )
    def forward(self, x, y):
        return torch.ne(x, y)


@@ -666,18 +802,22 @@ class ElementwiseNeIntTensorModule(torch.nn.Module):
def ElementwiseNeIntTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))


# ==============================================================================


class ElementwiseNeFloatTensorStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 3], torch.float32, True),
            ([2, 3], torch.float32, True),
        ]
    )
    def forward(self, x, y):
        return torch.ne(x, y)


@@ -686,20 +826,25 @@ class ElementwiseNeFloatTensorStaticModule(torch.nn.Module):
def ElementwiseNeFloatTensorStaticModule_basic(module, tu: TestUtils):
    module.forward(
        torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
        torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32),
    )


# ==============================================================================


class ElementwiseNeIntTensorStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([8, 5], torch.int64, True),
            ([5], torch.int64, True),
        ]
    )
    def forward(self, x, y):
        return torch.ne(x, y)


@@ -708,16 +853,20 @@ class ElementwiseNeIntTensorStaticModule(torch.nn.Module):
def ElementwiseNeIntTensorStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
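A note on the pattern repeated in every module above, for readers new to these e2e tests: `@export` marks the method as part of the module's exported interface, and `@annotate_args` supplies one annotation per argument of `forward`, starting with `None` for `self`, followed by `(shape, dtype, flag)` tuples in which `-1` marks a dynamic dimension; as far as I understand the convention, the trailing `True` is the value-semantics flag. A minimal, hypothetical module in that style (the class name and shapes are invented for illustration, and the extra imports are assumed to match the headers of these test files):

import torch

from torch_mlir_e2e_test.annotations import annotate_args, export
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case


class ElementwiseGtScalarSketchModule(torch.nn.Module):
    # Hypothetical example module, not part of the suite.
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,  # annotation slot for `self`
            ([-1, -1], torch.float32, True),  # dynamically shaped 2-D float32 input
        ]
    )
    def forward(self, x):
        return torch.gt(x, 0.5)


@register_test_case(module_factory=lambda: ElementwiseGtScalarSketchModule())
def ElementwiseGtScalarSketchModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))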
# ==============================================================================


class AnyBoolTrueModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
        ]
    )
    def forward(self):
        input = [False, False, True]
        return torch.ops.aten.any(input)


@@ -733,9 +882,11 @@ class AnyBoolFalseModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
        ]
    )
    def forward(self):
        input = [False, False, False]
        return torch.ops.aten.any(input)


@@ -745,17 +896,20 @@ class AnyBoolFalseModule(torch.nn.Module):
def AnyBoolFalseModule_basic(module, tu: TestUtils):
    module.forward()


# =================================================================================


class AllBoolTrueModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
        ]
    )
    def forward(self):
        input = [True, True, True, True, True]
        return torch.ops.aten.all(input)


@@ -765,36 +919,44 @@ class AllBoolTrueModule(torch.nn.Module):
def AllBoolTrueModule_basic(module, tu: TestUtils):
    module.forward()


# =================================================================================


class AllBoolFalseModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
        ]
    )
    def forward(self):
        input = [True, False, True, True, False]
        return torch.ops.aten.all(input)


@register_test_case(module_factory=lambda: AllBoolFalseModule())
def AllBoolFalseModule_basic(module, tu: TestUtils):
    module.forward()


# ==============================================================================


class ElementwiseIsnanModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.isnan(x)


@@ -804,17 +966,21 @@ def ElementwiseIsnanModule_basic(module, tu: TestUtils):
    x = torch.tensor([1.0, torch.nan, torch.inf, -torch.inf])
    module.forward(x)


# ==============================================================================


class ElementwiseIsinfModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.isinf(x)
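The NaN-bearing inputs above exercise plain IEEE-754 semantics rather than anything torch-mlir-specific: NaN compares unequal to every value including itself, so `torch.eq` yields False and `torch.ne` yields True at those positions, while `torch.isnan` and `torch.isinf` are the ops that actually classify the special values. A quick eager-mode sketch of the expected reference behaviour:

import torch

x = torch.tensor([1.0, float("nan"), 2.0])
y = torch.tensor([1.0, float("nan"), 3.0])

print(torch.eq(x, y))   # tensor([ True, False, False])  (NaN != NaN)
print(torch.ne(x, y))   # tensor([False,  True,  True])
print(torch.isnan(x))   # tensor([False,  True, False])
print(torch.isinf(torch.tensor([1.0, float("inf"), -float("inf")])))
# tensor([False,  True,  True])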
@@ -11,92 +11,107 @@ from torch_mlir_e2e_test.annotations import annotate_args, export

# ==============================================================================


class GridSamplerBasic1(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([7, 8, 12, 4], torch.float32, True),
            ([7, 11, 13, 2], torch.float32, True),
        ]
    )
    def forward(self, x, g):
        interpolation_mode = (0,)
        padding_mode = (0,)
        align_corners = (True,)
        tRes = torch.ops.aten.grid_sampler(
            x, g, interpolation_mode[0], padding_mode[0], align_corners[0]
        )
        return tRes


@register_test_case(module_factory=lambda: GridSamplerBasic1())
def GridSamplerBasic1_basic(module, tu: TestUtils):
    inp = torch.rand(7, 8, 12, 4)
    grd = torch.rand(7, 11, 13, 2) * 2.0 - 1.0
    module.forward(inp, grd)


class GridSamplerBasic2(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [None, ([1, 1, 4, 4], torch.float32, True), ([1, 1, 3, 2], torch.float32, True)]
    )
    def forward(self, x, g):
        interpolation_mode = (0,)
        padding_mode = (0,)
        align_corners = (True,)
        tRes = torch.ops.aten.grid_sampler(
            x, g, interpolation_mode[0], padding_mode[0], align_corners[0]
        )
        return tRes


@register_test_case(module_factory=lambda: GridSamplerBasic2())
def GridSamplerBasic2_basic(module, tu: TestUtils):
    inp = torch.tensor(
        [
            [
                [
                    [0.4963, 0.7682, 0.0885, 0.1320],
                    [0.3074, 0.6341, 0.4901, 0.8964],
                    [0.4556, 0.6323, 0.3489, 0.4017],
                    [0.0223, 0.1689, 0.2939, 0.5185],
                ]
            ]
        ]
    ).type(torch.FloatTensor)
    grd = torch.tensor(
        [[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]]
    ).type(torch.FloatTensor)
    module.forward(inp, grd)


class GridSamplerBasic3(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [None, ([1, 1, 4, 4], torch.float32, True), ([1, 1, 3, 2], torch.float32, True)]
    )
    def forward(self, x, g):
        interpolation_mode = (0,)
        padding_mode = (0,)
        align_corners = (False,)
        tRes = torch.ops.aten.grid_sampler(
            x, g, interpolation_mode[0], padding_mode[0], align_corners[0]
        )
        return tRes


@register_test_case(module_factory=lambda: GridSamplerBasic3())
def GridSamplerBasic3_basic(module, tu: TestUtils):
    inp = torch.tensor(
        [
            [
                [
                    [0.4963, 0.7682, 0.0885, 0.1320],
                    [0.3074, 0.6341, 0.4901, 0.8964],
                    [0.4556, 0.6323, 0.3489, 0.4017],
                    [0.0223, 0.1689, 0.2939, 0.5185],
                ]
            ]
        ]
    ).type(torch.FloatTensor)
    grd = torch.tensor(
        [[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]]
    ).type(torch.FloatTensor)
    module.forward(inp, grd)
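The raw `torch.ops.aten.grid_sampler` overload used in these three tests takes the interpolation mode and padding mode as integers and `align_corners` as a bool, unlike the string-based `torch.nn.functional.grid_sample` wrapper. Assuming the usual encoding (0 = bilinear interpolation, 0 = zeros padding), the calls above should correspond to the functional form in this sketch:

import torch
import torch.nn.functional as F

x = torch.rand(7, 8, 12, 4)
g = torch.rand(7, 11, 13, 2) * 2.0 - 1.0  # grid coordinates in [-1, 1]

# Raw op, as in the tests: interpolation_mode=0, padding_mode=0, align_corners=True.
out_raw = torch.ops.aten.grid_sampler(x, g, 0, 0, True)

# Assumed string equivalents of the integer modes.
out_fn = F.grid_sample(x, g, mode="bilinear", padding_mode="zeros", align_corners=True)

assert torch.allclose(out_raw, out_fn)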
@@ -16,6 +16,7 @@ NUM_SEGMENTS = 42
NUM_BINS = 5000
NUM_LOGITS = 5000


class HistogramBinningCalibrationByFeature(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -45,31 +46,34 @@ class HistogramBinningCalibrationByFeature(torch.nn.Module):
        self._iteration = 0

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.int32, True),
            ([-1], torch.int32, True),
            ([-1], torch.float32, True),
        ]
    )
    def forward(self, segment_value, segment_lengths, logit):
        origin_prediction = torch.sigmoid(logit + torch.log(self.positive_weight))
        # TODO: If in the future this test is removed from xfail for LTC, we will probably hit some device related
        # issues below when new tensors are created on the CPU, which is currently the default behaviour.
        # The solution would be to move these tensors to ensure they are on the same device as the existing ones.
        dense_segment_value = torch.zeros(logit.numel(), dtype=torch.int32)
        validoffsets = torch.gt(
            segment_lengths[1 : self._num_logits + 1],
            segment_lengths[0 : self._num_logits],
        )
        gathered_segment_values = (
            segment_value[segment_lengths[0 : self._num_logits].long()] + 1
        ).int()
        dense_segment_value = torch.where(
            validoffsets, gathered_segment_values, dense_segment_value
        )
        zeros = torch.empty_like(dense_segment_value, dtype=torch.int32).fill_(0)
        isnotvalid = torch.gt(dense_segment_value, self._num_segments)
        dense_segment_value = torch.where(isnotvalid, zeros, dense_segment_value)
        bin_ids_data = torch.ceil(origin_prediction / self.step) - 1
        bin_ids_data = bin_ids_data.long()
        curr_segment_value = dense_segment_value * self._num_bins
        bin_ids_data2 = bin_ids_data

@@ -78,12 +82,14 @@ class HistogramBinningCalibrationByFeature(torch.nn.Module):
        curr_bin_num_examples = self._bin_num_examples[bin_ids_data]
        curr_segment_value = curr_segment_value / curr_bin_num_examples
        curr_segment_value = curr_segment_value.float()
        curr_segment_value = (
            curr_segment_value * self.bin_ctr_weight_value
            + origin_prediction * self.oneminusbin_ctr_weight_value
        )
        isvalid = torch.gt(curr_bin_num_examples, self.bin_ctr_in_use_after)
        calibrated_prediction_data = torch.where(
            isvalid, curr_segment_value, origin_prediction.float()
        )
        return calibrated_prediction_data, bin_ids_data


@@ -92,11 +98,11 @@ def HBC_basic(module, tu: TestUtils):
    logits = tu.rand(NUM_LOGITS)
    segment_lengths: Tensor = tu.randint(NUM_LOGITS, high=2).to(torch.int)
    segment_offsets: Tensor = torch.cumsum(segment_lengths, 0)
    segment_offsets: Tensor = torch.cat((torch.tensor([0]), segment_offsets), 0)
    num_values: int = int(torch.sum(segment_lengths).item())
    segment_values: Tensor = tu.randint(num_values, high=NUM_SEGMENTS)
    segment_values = torch.cat(
        (segment_values, torch.zeros(NUM_LOGITS - segment_values.numel())), 0
    )
    module.forward(segment_values.int(), segment_offsets.int(), logits)
    # input shape (5000, 5001, 5000)
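The TODO inside `HistogramBinningCalibrationByFeature.forward` notes that the tensors it creates (`dense_segment_value`, `zeros`) default to the CPU and would have to follow the inputs' device if this test ever ran on an accelerator backend. A minimal sketch of what that kind of fix could look like; this is purely hypothetical and not part of this commit:

import torch


def make_dense_segment_value(logit: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: allocate on the same device as the incoming logits
    # instead of relying on the CPU default, as the TODO above suggests.
    return torch.zeros(logit.numel(), dtype=torch.int32, device=logit.device)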
@@ -17,15 +17,17 @@ class IndexSelectSingleIdxModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4, 5, 6], torch.float32, True),
            ([1], torch.int64, True),
        ]
    )
    def forward(self, input, indices):
        return torch.index_select(input, 1, indices)


@register_test_case(module_factory=lambda: IndexSelectSingleIdxModule())
def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6), torch.tensor([2]))


@@ -36,33 +38,38 @@ class IndexSelectRank0IdxModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4, 5, 6], torch.float32, True),
            ([], torch.int64, True),
        ]
    )
    def forward(self, input, indices):
        return torch.index_select(input, 1, indices)


@register_test_case(module_factory=lambda: IndexSelectRank0IdxModule())
def IndexSelectRank0IdxModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6), torch.tensor(2))


class IndexSelectNegativeDimModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4, 5, 6], torch.float32, True),
            ([1], torch.int64, True),
        ]
    )
    def forward(self, input, indices):
        return torch.index_select(input, -1, indices)


@register_test_case(module_factory=lambda: IndexSelectNegativeDimModule())
def IndexSelectNegativeDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6), torch.tensor([2]))


@@ -73,15 +80,17 @@ class IndexSelectTwoIdxModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4, 5, 6], torch.float32, True),
            ([2], torch.int64, True),
        ]
    )
    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)


@register_test_case(module_factory=lambda: IndexSelectTwoIdxModule())
def IndexSelectTwoIdxModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6), torch.tensor([2, 4]))


@@ -92,15 +101,17 @@ class IndexSelectWholeDimensionModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4, 5, 6], torch.float32, True),
            ([4], torch.int64, True),
        ]
    )
    def forward(self, input, indices):
        return torch.index_select(input, 0, indices)


@register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule())
def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6), torch.tensor([0, 1, 2, 3]))


@@ -111,15 +122,17 @@ class IndexSelectWholeTensorModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([3], torch.float32, True),
            ([3], torch.int64, True),
        ]
    )
    def forward(self, input, indices):
        return torch.index_select(input, 0, indices)


@register_test_case(module_factory=lambda: IndexSelectWholeTensorModule())
def IndexSelectWholeTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3), torch.tensor([0, 1, 2]))


@@ -130,15 +143,17 @@ class IndexSelectDynamicModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
            ([-1], torch.int64, True),
        ]
    )
    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)


@register_test_case(module_factory=lambda: IndexSelectDynamicModule())
def IndexSelectDynamicModulebasic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6), torch.tensor([0, 4]))


@@ -149,15 +164,17 @@ class IndexSelectDynamicInputSizeModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
            ([2], torch.int64, True),
        ]
    )
    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)


@register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule())
def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6), torch.tensor([0, 2]))


@@ -168,15 +185,17 @@ class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4, 5, 6], torch.float32, True),
            ([-1], torch.int64, True),
        ]
    )
    def forward(self, input, indices):
        return torch.index_select(input, 1, indices)


@register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule())
def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6), torch.tensor([1, 2]))
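As a reminder of the semantics these modules exercise: `torch.index_select(input, dim, index)` gathers whole slices of `input` along `dim` at the positions given by the `index` tensor, so with a 1-D index the result keeps the rank of `input` with `dim` resized to `len(index)`, and negative `dim` values count from the last dimension. A small eager sketch:

import torch

t = torch.arange(24, dtype=torch.float32).reshape(4, 6)

rows = torch.index_select(t, 0, torch.tensor([0, 3]))   # shape (2, 6)
cols = torch.index_select(t, 1, torch.tensor([2, 5]))   # shape (4, 2)
last = torch.index_select(t, -1, torch.tensor([0]))     # negative dim, shape (4, 1)

print(rows.shape, cols.shape, last.shape)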
@ -11,16 +11,19 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulDot(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
@ -29,18 +32,22 @@ class MatmulDot(torch.nn.Module):
|
|||
def Matmul_dot(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), tu.rand(3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Matmul2D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
@ -49,18 +56,22 @@ class Matmul2D(torch.nn.Module):
|
|||
def Matmul_2d(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4), tu.rand(4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulVecMat(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
@ -69,18 +80,22 @@ class MatmulVecMat(torch.nn.Module):
|
|||
def Matmul_vecmat(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand(4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulMatVec(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
@ -89,18 +104,22 @@ class MatmulMatVec(torch.nn.Module):
|
|||
def Matmul_matvec(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Matmul3D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
@ -109,18 +128,22 @@ class Matmul3D(torch.nn.Module):
|
|||
def Matmul_3d(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Matmul4d(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
@ -129,18 +152,22 @@ class Matmul4d(torch.nn.Module):
|
|||
def Matmul_4d(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Matmul4dStatic(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([4, 5, 6, 7], torch.float32, True),
|
||||
([4, 5, 7, 6], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
@ -149,18 +176,22 @@ class Matmul4dStatic(torch.nn.Module):
|
|||
def Matmul4dStatic_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulStaticBroadcast(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([4, 1, 6, 7], torch.float32, True),
|
||||
([8, 1, 5, 7, 6], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
@ -169,18 +200,22 @@ class MatmulStaticBroadcast(torch.nn.Module):
|
|||
def MatmulStaticBroadcast_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulSingleDynamicBatchDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([4, -1, -1, -1], torch.float32, True),
|
||||
([4, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
@ -189,18 +224,22 @@ class MatmulSingleDynamicBatchDim(torch.nn.Module):
|
|||
def MatmulSingleDynamicBatchDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulBroadcastBatchDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([4, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
@ -209,16 +248,19 @@ class MatmulBroadcastBatchDim(torch.nn.Module):
|
|||
def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class Mv(torch.nn.Module):
|
||||
|
||||
class Mv(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, m, v):
|
||||
return torch.mv(m, v)
|
||||
|
||||
|
@ -227,16 +269,19 @@ class Mv(torch.nn.Module):
|
|||
def Mv_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 2), tu.rand(2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenMmFloatTypes(torch.nn.Module):
|
||||
|
||||
class AtenMmFloatTypes(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.ops.aten.mm(a, b)
|
||||
|
||||
|
@ -245,16 +290,19 @@ class AtenMmFloatTypes(torch.nn.Module):
|
|||
def AtenMmFloatTypes_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(8, 8), tu.rand(8, 8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenMmIntTypes(torch.nn.Module):
|
||||
|
||||
class AtenMmIntTypes(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.ops.aten.mm(a, b)
|
||||
|
||||
|
@ -266,17 +314,19 @@ def AtenMmIntTypes_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenMmQint8(torch.nn.Module):
|
||||
|
||||
class AtenMmQint8(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 4], torch.int8, True),
|
||||
([4, 3], torch.int8, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
||||
qx = torch.dequantize(qx)
|
||||
|
@ -285,24 +335,30 @@ class AtenMmQint8(torch.nn.Module):
|
|||
qz = torch.mm(qx, qy)
|
||||
return qz
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenMmQint8())
|
||||
def AtenMmQint8_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(4, 3, low=-128, high=127).to(torch.int8))
|
||||
module.forward(
|
||||
tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(4, 3, low=-128, high=127).to(torch.int8),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenMmQuint8(torch.nn.Module):
|
||||
|
||||
class AtenMmQuint8(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 4], torch.uint8, True),
|
||||
([4, 3], torch.uint8, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65)
|
||||
qx = torch.dequantize(qx)
|
||||
|
@ -311,24 +367,30 @@ class AtenMmQuint8(torch.nn.Module):
|
|||
qz = torch.mm(qx, qy)
|
||||
return qz
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenMmQuint8())
|
||||
def AtenMmQuint8_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=0, high=255).to(torch.uint8),
|
||||
tu.randint(4, 3, low=0, high=255).to(torch.uint8))
|
||||
module.forward(
|
||||
tu.randint(3, 4, low=0, high=255).to(torch.uint8),
|
||||
tu.randint(4, 3, low=0, high=255).to(torch.uint8),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenMmQMixedSigni8(torch.nn.Module):
|
||||
|
||||
class AtenMmQMixedSigni8(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 4], torch.int8, True),
|
||||
([4, 3], torch.uint8, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
||||
qx = torch.dequantize(qx)
|
||||
|
@ -337,24 +399,30 @@ class AtenMmQMixedSigni8(torch.nn.Module):
|
|||
qz = torch.mm(qx, qy)
|
||||
return qz
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenMmQMixedSigni8())
|
||||
def AtenMmQMixedSigni8_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(4, 3, low=0, high=255).to(torch.uint8))
|
||||
module.forward(
|
||||
tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(4, 3, low=0, high=255).to(torch.uint8),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenMatmulQint8VM(torch.nn.Module):
|
||||
|
||||
class AtenMatmulQint8VM(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.int8, True),
|
||||
([-1,-1], torch.int8, True),
|
||||
])
|
||||
([-1, -1], torch.int8, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
||||
qx = torch.dequantize(qx)
|
||||
|
@ -363,23 +431,28 @@ class AtenMatmulQint8VM(torch.nn.Module):
|
|||
qz = torch.matmul(qx, qy)
|
||||
return qz
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenMatmulQint8VM())
|
||||
def AtenMatmulQint8VM_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(9, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(9, 4, low=-128, high=127).to(torch.int8))
|
||||
module.forward(
|
||||
tu.randint(9, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(9, 4, low=-128, high=127).to(torch.int8),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class AtenMatmulQint8VV(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.int8, True),
|
||||
([-1], torch.int8, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
||||
qx = torch.dequantize(qx)
|
||||
|
@ -388,23 +461,28 @@ class AtenMatmulQint8VV(torch.nn.Module):
|
|||
qz = torch.matmul(qx, qy)
|
||||
return qz
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenMatmulQint8VV())
|
||||
def AtenMatmulQint8VV_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(9, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(9, low=-128, high=127).to(torch.int8))
|
||||
module.forward(
|
||||
tu.randint(9, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(9, low=-128, high=127).to(torch.int8),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class AtenMatmulQint8MV(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int8, True),
|
||||
([-1], torch.int8, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
||||
qx = torch.dequantize(qx)
|
||||
|
@ -413,23 +491,28 @@ class AtenMatmulQint8MV(torch.nn.Module):
|
|||
qz = torch.matmul(qx, qy)
|
||||
return qz
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenMatmulQint8MV())
|
||||
def AtenMatmulQint8MV_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(4, low=-128, high=127).to(torch.int8))
|
||||
module.forward(
|
||||
tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(4, low=-128, high=127).to(torch.int8),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class AtenMatmulQint8(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([4, -1, 3, 4], torch.int8, True),
|
||||
([-1, 4, 3], torch.int8, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
||||
qx = torch.dequantize(qx)
|
||||
|
@ -438,24 +521,30 @@ class AtenMatmulQint8(torch.nn.Module):
|
|||
qz = torch.matmul(qx, qy)
|
||||
return qz
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenMatmulQint8())
|
||||
def AtenMatmulQint8_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 7, 3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(7, 4, 3, low=-128, high=127).to(torch.int8))
|
||||
module.forward(
|
||||
tu.randint(4, 7, 3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(7, 4, 3, low=-128, high=127).to(torch.int8),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenMatmulQMixedSigni8(torch.nn.Module):
|
||||
|
||||
class AtenMatmulQMixedSigni8(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([7, -1, -1, -1], torch.int8, True),
|
||||
([-1, -1, -1], torch.uint8, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
||||
qx = torch.dequantize(qx)
|
||||
|
@ -464,24 +553,30 @@ class AtenMatmulQMixedSigni8(torch.nn.Module):
|
|||
qz = torch.matmul(qx, qy)
|
||||
return qz
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8())
|
||||
def AtenMatmulQMixedSigni8_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(2, 4, 3, low=0, high=255).to(torch.uint8))
|
||||
module.forward(
|
||||
tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(2, 4, 3, low=0, high=255).to(torch.uint8),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenMatmulQMixedSigni8Transpose(torch.nn.Module):
|
||||
|
||||
class AtenMatmulQMixedSigni8Transpose(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([7, -1, -1, -1], torch.int8, True),
|
||||
([-1, -1, -1], torch.uint8, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
||||
qx = torch.dequantize(qx)
|
||||
|
@ -491,21 +586,27 @@ class AtenMatmulQMixedSigni8Transpose(torch.nn.Module):
|
|||
qz = torch.matmul(qx, qy)
|
||||
return qz
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose())
|
||||
def AtenMatmulQMixedSigni8Transpose_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(2, 6, 4, low=0, high=255).to(torch.uint8))
|
||||
module.forward(
|
||||
tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(2, 6, 4, low=0, high=255).to(torch.uint8),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenLinalgCrossInt(torch.nn.Module):
|
||||
|
||||
class AtenLinalgCrossInt(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 3], torch.int64, True),
|
||||
([2, 3], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.ops.aten.linalg_cross(a, b)
|
||||
|
||||
|
@ -514,16 +615,19 @@ class AtenLinalgCrossInt(torch.nn.Module):
|
|||
def AtenLinalgCrossInt_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(2, 3), tu.randint(2, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenLinalgCrossFloat(torch.nn.Module):
|
||||
|
||||
class AtenLinalgCrossFloat(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 3], torch.float32, True),
|
||||
([2, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.ops.aten.linalg_cross(a, b)
|
||||
|
||||
|
@ -535,14 +639,16 @@ def AtenLinalgCrossFloat_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenLinalgCrossBroadcast(torch.nn.Module):
|
||||
|
||||
class AtenLinalgCrossBroadcast(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 4, 3], torch.float32, True),
|
||||
([5, 4, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.ops.aten.linalg_cross(a, b)
|
||||
|
||||
|
@ -551,16 +657,19 @@ class AtenLinalgCrossBroadcast(torch.nn.Module):
|
|||
def AtenLinalgCrossBroadcast_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 4, 3), tu.rand(5, 4, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenLinalgCrossCustomDim(torch.nn.Module):
|
||||
|
||||
class AtenLinalgCrossCustomDim(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 4, 3, 2, 2], torch.float32, True),
|
||||
([5, 4, 3, 2, 1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.ops.aten.linalg_cross(a, b, dim=2)
|
||||
|
||||
|
@ -569,16 +678,19 @@ class AtenLinalgCrossCustomDim(torch.nn.Module):
|
|||
def AtenLinalgCrossCustomDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenLinalgCrossNegativeDim(torch.nn.Module):
|
||||
|
||||
class AtenLinalgCrossNegativeDim(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 4, 3, 2, 2], torch.float32, True),
|
||||
([5, 4, 3, 2, 1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.ops.aten.linalg_cross(a, b, dim=-3)
|
||||
|
||||
|
@ -587,18 +699,22 @@ class AtenLinalgCrossNegativeDim(torch.nn.Module):
|
|||
def AtenLinalgCrossNegativeDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenLinalgCrossDynamic(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.ops.aten.linalg_cross(a, b, dim=1)
|
||||
|
||||
|
|
|
@@ -14,6 +14,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export

# Multi-layer perceptron (MLP) models.


class Mlp1LayerModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -21,18 +22,23 @@ class Mlp1LayerModule(torch.nn.Module):
        torch.manual_seed(0)
        self.fc0 = nn.Linear(3, 5)
        self.tanh0 = nn.Tanh()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        return self.tanh0(self.fc0(x))


@register_test_case(module_factory=lambda: Mlp1LayerModule())
def Mlp1LayerModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 3))


class Mlp2LayerModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -43,20 +49,25 @@ class Mlp2LayerModule(torch.nn.Module):
        self.tanh0 = nn.Tanh()
        self.fc1 = nn.Linear(N_HIDDEN, 2)
        self.tanh1 = nn.Tanh()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        x = self.tanh0(self.fc0(x))
        x = self.tanh1(self.fc1(x))
        return x


@register_test_case(module_factory=lambda: Mlp2LayerModule())
def Mlp2LayerModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 3))


class Mlp2LayerModuleNoBias(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -67,20 +78,25 @@ class Mlp2LayerModuleNoBias(torch.nn.Module):
        self.tanh0 = nn.Tanh()
        self.fc1 = nn.Linear(N_HIDDEN, 2, bias=False)
        self.tanh1 = nn.Tanh()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        x = self.tanh0(self.fc0(x))
        x = self.tanh1(self.fc1(x))
        return x


@register_test_case(module_factory=lambda: Mlp2LayerModuleNoBias())
def Mlp2LayerModuleNoBias_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 3))


class BatchMlpLayerModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -88,14 +104,18 @@ class BatchMlpLayerModule(torch.nn.Module):
        torch.manual_seed(0)
        self.fc0 = nn.Linear(3, 5)
        self.tanh0 = nn.Tanh()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        return self.tanh0(self.fc0(x))


@register_test_case(module_factory=lambda: BatchMlpLayerModule())
def BatchMlpLayerModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(7, 5, 3))
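`Mlp1LayerModule` above is just a single `nn.Linear` followed by `Tanh`, i.e. it computes `tanh(x @ W.T + b)` under standard `nn.Linear` semantics; nothing in this commit changes that, so the sketch below is only a reference check of the equivalence:

import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Linear(3, 5)
x = torch.rand(5, 3)

out_module = torch.tanh(fc(x))
out_manual = torch.tanh(x @ fc.weight.T + fc.bias)

assert torch.allclose(out_module, out_manual)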
@ -13,23 +13,22 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class NllLossModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=2)
|
||||
return torch.ops.aten.nll_loss_forward(
|
||||
x, target=y, weight=None, reduction=0, ignore_index=2
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule())
|
||||
|
@ -42,18 +41,18 @@ class NllLossModule_mean(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=1,
|
||||
ignore_index=2)
|
||||
return torch.ops.aten.nll_loss_forward(
|
||||
x, target=y, weight=None, reduction=1, ignore_index=2
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_mean())
|
||||
|
@ -66,18 +65,18 @@ class NllLossModule_sum(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=2,
|
||||
ignore_index=2)
|
||||
return torch.ops.aten.nll_loss_forward(
|
||||
x, target=y, weight=None, reduction=2, ignore_index=2
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_sum())
|
||||
|
@ -90,18 +89,18 @@ class NllLossModule_1D(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=2)
|
||||
return torch.ops.aten.nll_loss_forward(
|
||||
x, target=y, weight=None, reduction=0, ignore_index=2
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_1D())
|
||||
|
@ -110,409 +109,465 @@ def NllLossModule_1D_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
# None of the index is ignored here, since the ignored index is out of bounds.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=10)
|
||||
return torch.ops.aten.nll_loss_forward(
|
||||
x, target=y, weight=None, reduction=0, ignore_index=10
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
|
||||
def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
|
||||
class NllLossModule_backward(torch.nn.Module):
|
||||
|
||||
class NllLossModule_backward(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward())
|
||||
def NllLossModuleBackward_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(
|
||||
tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0)
|
||||
)
|
||||
|
||||
|
||||
class NllLossModule_backwardWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backwardWeight())
|
||||
def NllLossModuleBackwardWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
tu.rand(4), torch.tensor(3.))
|
||||
|
||||
module.forward(
|
||||
tu.rand(3),
|
||||
tu.rand(3, 4),
|
||||
torch.tensor([2, 3, 0]),
|
||||
tu.rand(4),
|
||||
torch.tensor(3.0),
|
||||
)
|
||||
|
||||
|
||||
class NllLossModule_backward_ignore_index(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: NllLossModule_backward_ignore_index())
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward_ignore_index())
|
||||
def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(
|
||||
tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0)
|
||||
)
|
||||
|
||||
|
||||
class NllLossModule_backwardMean(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backwardMean())
|
||||
def NllLossModuleBackwardMean_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(
|
||||
tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0)
|
||||
)
|
||||
|
||||
|
||||
class NllLossModule_backwardMeanWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backwardMeanWeight())
|
||||
def NllLossModuleBackwardMeanWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
tu.rand(4), torch.tensor(3.))
|
||||
module.forward(
|
||||
tu.rand(1),
|
||||
tu.rand(3, 4),
|
||||
torch.tensor([2, 3, 0]),
|
||||
tu.rand(4),
|
||||
torch.tensor(3.0),
|
||||
)
|
||||
|
||||
|
||||
class NllLossModule_backwardSum(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backwardSum())
|
||||
def NllLossModuleBackwardSum_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(
|
||||
tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0)
|
||||
)
|
||||
|
||||
|
||||
class NllLossModule_backwardSumWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backwardSumWeight())
|
||||
def NllLossModuleBackwardSumWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
tu.rand(4), torch.tensor(3.))
|
||||
module.forward(
|
||||
tu.rand(1),
|
||||
tu.rand(3, 4),
|
||||
torch.tensor([2, 3, 0]),
|
||||
tu.rand(4),
|
||||
torch.tensor(3.0),
|
||||
)
|
||||
|
||||
|
||||
class NllLossModule_backward1D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1D())
|
||||
def NllLossModuleBackward1D_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), torch.tensor(3.0))
|
||||
|
||||
|
||||
class NllLossModule_backward1DWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DWeight())
|
||||
def NllLossModuleBackward1DWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
tu.rand(3), torch.tensor(3.))
|
||||
module.forward(
|
||||
tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0)
|
||||
)
|
||||
|
||||
|
||||
class NllLossModule_backward1DMean(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DMean())
|
||||
def NllLossModuleBackward1DMean_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), torch.tensor(3.0))
|
||||
|
||||
|
||||
class NllLossModule_backward1DMeanWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight())
|
||||
def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
tu.rand(3), torch.tensor(3.))
|
||||
module.forward(
|
||||
tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0)
|
||||
)
|
||||
|
||||
|
||||
class NllLossModule_backward1DSum(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DSum())
|
||||
def NllLossModuleBackward1DSum_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), torch.tensor(3.0))
|
||||
|
||||
|
||||
class NllLossModule_backward1DSumWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
return torch.ops.aten.nll_loss_backward(
|
||||
grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
total_weight=total_weight,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DSumWeight())
|
||||
def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
tu.rand(3), torch.tensor(3.))
|
||||
module.forward(
|
||||
tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0)
|
||||
)
|
||||
|
|
|
@ -11,6 +11,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BatchNorm1DModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -18,15 +19,16 @@ class BatchNorm1DModule(torch.nn.Module):
|
|||
self.bn1d.eval()
|
||||
self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
|
||||
self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
|
||||
        self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0]))
|
||||
self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([10, 4, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.bn1d(x)
|
||||
|
||||
|
@ -35,8 +37,10 @@ class BatchNorm1DModule(torch.nn.Module):
|
|||
def BatchNorm1DModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 4, 3))
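Aside (illustrative sketch, not part of this reformat diff): in eval mode BatchNorm1d is just an affine map over the stored running statistics, y = (x - running_mean) / sqrt(running_var + eps) * weight + bias, which is what these fixed-statistics modules exercise. The tolerance below is an assumption.

import torch

bn = torch.nn.BatchNorm1d(4)
bn.eval()
bn.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
bn.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])

x = torch.rand(10, 4, 3)
# Broadcast the per-channel statistics over the (N, C, L) input.
expected = (x - bn.running_mean[:, None]) / torch.sqrt(
    bn.running_var[:, None] + bn.eps
) * bn.weight[:, None] + bn.bias[:, None]
assert torch.allclose(bn(x), expected, atol=1e-5)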
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BatchNorm1DWith2DInputModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -44,15 +48,16 @@ class BatchNorm1DWith2DInputModule(torch.nn.Module):
|
|||
self.bn1d.eval()
|
||||
self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
|
||||
self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
|
||||
        self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0]))
|
||||
self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([10, 4], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.bn1d(x)
|
||||
|
||||
|
@ -61,8 +66,10 @@ class BatchNorm1DWith2DInputModule(torch.nn.Module):
|
|||
def BatchNorm1DWith2DInputModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BatchNorm2DModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -74,10 +81,12 @@ class BatchNorm2DModule(torch.nn.Module):
|
|||
self.bn2d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4]))
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([10, 2, 3, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.bn2d(x)
|
||||
|
||||
|
@ -86,8 +95,10 @@ class BatchNorm2DModule(torch.nn.Module):
|
|||
def BatchNorm2DModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 2, 3, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BatchNorm3DModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -95,16 +106,16 @@ class BatchNorm3DModule(torch.nn.Module):
|
|||
self.bn3d.eval()
|
||||
self.bn3d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])
|
||||
self.bn3d.running_var = torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])
|
||||
        self.bn3d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
        self.bn3d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 5, 3, 6, 4], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.bn3d(x)
|
||||
|
||||
|
@ -113,274 +124,361 @@ class BatchNorm3DModule(torch.nn.Module):
|
|||
def BatchNorm3DModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 3, 6, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BatchNorm1DStaticShapeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 5], torch.float32, True),
|
||||
([5], torch.float32, True),
|
||||
([5], torch.float32, True),
|
||||
([5], torch.float32, True),
|
||||
([5], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, weight, bias, running_mean, running_var):
|
||||
return torch.ops.aten.batch_norm(
|
||||
x, weight, bias, running_mean, running_var, training=False,
|
||||
momentum=0.1, eps=0.00001, cudnn_enabled=False)
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=0.00001,
|
||||
cudnn_enabled=False,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BatchNorm1DStaticShapeModule())
|
||||
def BatchNorm1DStaticShapeModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(2, 5), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
module.forward(tu.rand(2, 5), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeBatchNorm1DModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, weight, bias, running_mean, running_var):
|
||||
return torch.ops.aten.native_batch_norm(
|
||||
x, weight, bias, running_mean, running_var, training=False,
|
||||
momentum=0.1, eps=0.00001)
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=0.00001,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeBatchNorm1DModule())
|
||||
def NativeBatchNorm1DModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(2, 5, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
module.forward(tu.rand(2, 5, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
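Aside (illustrative sketch, not part of this reformat diff): unlike aten.batch_norm, the native_batch_norm op returns a tuple (output, save_mean, save_invstd); in inference mode the two saved statistics are typically zero-sized on CPU, though that detail is an assumption here rather than something the test checks.

import torch

x, w, b = torch.rand(2, 5, 3), torch.rand(5), torch.rand(5)
rm, rv = torch.rand(5), torch.rand(5)
out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
    x, w, b, rm, rv, False, 0.1, 1e-5
)
print(out.shape, save_mean.shape, save_invstd.shape)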
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeBatchNorm2DModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, weight, bias, running_mean, running_var):
|
||||
return torch.ops.aten.native_batch_norm(
|
||||
x, weight, bias, running_mean, running_var, training=False,
|
||||
momentum=0.1, eps=0.00001)
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=0.00001,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeBatchNorm2DModule())
|
||||
def NativeBatchNorm2DModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(2, 5, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
module.forward(tu.rand(2, 5, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeBatchNorm3DModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, weight, bias, running_mean, running_var):
|
||||
return torch.ops.aten.native_batch_norm(
|
||||
x, weight, bias, running_mean, running_var, training=False,
|
||||
momentum=0.1, eps=0.00001)
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=0.00001,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeBatchNorm3DModule())
|
||||
def NativeBatchNorm3DModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5)
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeBatchNormNoneWeightModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, bias, running_mean, running_var):
|
||||
return torch.ops.aten.native_batch_norm(
|
||||
x, None, bias, running_mean, running_var, training=False,
|
||||
momentum=0.1, eps=0.00001)
|
||||
x,
|
||||
None,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=0.00001,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeBatchNormNoneWeightModule())
|
||||
def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class GroupNormModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 4, 6, 7], torch.float32, True),
|
||||
([4], torch.float32, True),
|
||||
([4], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, weight, bias):
|
||||
return torch.ops.aten.group_norm(x, 2, weight, bias, 1.0000000000000001e-05, False)
|
||||
return torch.ops.aten.group_norm(
|
||||
x, 2, weight, bias, 1.0000000000000001e-05, False
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: GroupNormModule())
|
||||
def GroupNormModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 6, 7), tu.rand(4), tu.rand(4))
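Aside (illustrative sketch, not part of this reformat diff): aten.group_norm with these arguments matches torch.nn.functional.group_norm(x, num_groups, weight, bias, eps); the trailing bool is the cudnn-enabled flag. The tolerance is an assumption.

import torch
import torch.nn.functional as F

x, w, b = torch.rand(2, 4, 6, 7), torch.rand(4), torch.rand(4)
a = torch.ops.aten.group_norm(x, 2, w, b, 1e-5, False)
assert torch.allclose(a, F.group_norm(x, 2, w, b, eps=1e-5), atol=1e-6)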
|
||||
|
||||
|
||||
class GroupNormNoWeightAndBiasModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 4, 6, 7], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.group_norm(x, 2, None, None, 1.0000000000000001e-05, False)
|
||||
return torch.ops.aten.group_norm(
|
||||
x, 2, None, None, 1.0000000000000001e-05, False
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: GroupNormNoWeightAndBiasModule())
|
||||
def GroupNormNoWeightAndBiasModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 6, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeGroupNormModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 6, 2, 2], torch.float32, True),
|
||||
([6], torch.float32, True),
|
||||
([6], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, weight, bias):
|
||||
return torch.ops.aten.native_group_norm(
|
||||
x, weight, bias,
|
||||
2, 6, 4, 3, 0.000001)
|
||||
return torch.ops.aten.native_group_norm(x, weight, bias, 2, 6, 4, 3, 0.000001)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeGroupNormModule())
|
||||
def NativeGroupNormModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 6, 2, 2), tu.rand(6), tu.rand(6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeGroupNormBackwardModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 6, 2, 2], torch.float32, True),
|
||||
([2, 6, 2, 2], torch.float32, True),
|
||||
([2, 3], torch.float32, True),
|
||||
([2, 3], torch.float32, True),
|
||||
([6], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, grad_out, x, mean, rstd, weight):
|
||||
return torch.ops.aten.native_group_norm_backward(
|
||||
grad_out, x, mean, rstd, weight,
|
||||
2, 6, 4, 3, [True, True, True])
|
||||
grad_out, x, mean, rstd, weight, 2, 6, 4, 3, [True, True, True]
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeGroupNormBackwardModule())
|
||||
def NativeGroupNormBackwardModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 6, 2, 2), tu.rand(2, 6, 2, 2), tu.rand(2, 3),
|
||||
tu.rand(2, 3), tu.rand(6))
|
||||
module.forward(
|
||||
tu.rand(2, 6, 2, 2),
|
||||
tu.rand(2, 6, 2, 2),
|
||||
tu.rand(2, 3),
|
||||
tu.rand(2, 3),
|
||||
tu.rand(6),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeLayerNormModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 5, 2, 2, 3], torch.float32, True),
|
||||
([2, 2, 3], torch.float32, True),
|
||||
([2, 2, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, weight, bias):
|
||||
list = [2, 2, 3]
|
||||
return torch.ops.aten.native_layer_norm(
|
||||
x, list, weight, bias, eps=0.5)
|
||||
return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeLayerNormModule())
|
||||
def NativeLayerNormModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
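Aside (illustrative sketch, not part of this reformat diff): aten.native_layer_norm normalizes over the trailing normalized_shape dims and returns (output, mean, rstd); the output should agree with torch.nn.functional.layer_norm for the same eps. The tolerance is an assumption.

import torch
import torch.nn.functional as F

x = torch.rand(2, 5, 2, 2, 3)
w = torch.rand(2, 2, 3)
b = torch.rand(2, 2, 3)
out, mean, rstd = torch.ops.aten.native_layer_norm(x, [2, 2, 3], w, b, 0.5)
assert torch.allclose(out, F.layer_norm(x, [2, 2, 3], w, b, eps=0.5), atol=1e-6)
print(mean.shape, rstd.shape)  # statistics per leading (2, 5) slice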
|
||||
|
||||
|
||||
class NativeLayerNormDynamicModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, weight, bias):
|
||||
list = [2, 2, 3]
|
||||
return torch.ops.aten.native_layer_norm(
|
||||
x, list, weight, bias, eps=0.5)
|
||||
return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeLayerNormDynamicModule())
|
||||
def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NormalizeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.normalize(x)
|
||||
|
||||
|
@ -389,48 +487,59 @@ class NormalizeModule(torch.nn.Module):
|
|||
def NormalizeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeLayerNormModule4D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([5, 2, 2, 3], torch.float32, True),
|
||||
([2, 2, 3], torch.float32, True),
|
||||
([2, 2, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, weight, bias):
|
||||
list = [2, 2, 3]
|
||||
return torch.ops.aten.native_layer_norm(
|
||||
x, list, weight, bias, eps=0.5)[0]
|
||||
return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)[0]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeLayerNormModule4D())
|
||||
def NativeLayerNormModule4D_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class LayerNormModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ly = torch.nn.LayerNorm([2, 2, 3])
|
||||
self.ly.eval()
|
||||
        self.ly.weight = torch.nn.Parameter(
            torch.tensor(
                [[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]], [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]
            )
        )
        self.ly.bias = torch.nn.Parameter(
            torch.tensor(
                [[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]], [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]
            )
        )
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 5, 2, 2, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.ly(x)
|
||||
|
||||
|
@ -439,8 +548,10 @@ class LayerNormModule(torch.nn.Module):
|
|||
def LayerNormModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class LayerNormLastDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -450,10 +561,12 @@ class LayerNormLastDimModule(torch.nn.Module):
|
|||
self.ly.bias = torch.nn.Parameter(torch.tensor([0.2, 0.4, 0.3]))
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 5, 2, 2, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.ly(x)
|
||||
|
||||
|
@ -462,25 +575,33 @@ class LayerNormLastDimModule(torch.nn.Module):
|
|||
def LayerNormLastDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ly = torch.nn.LayerNorm([2, 2, 3])
|
||||
self.ly.eval()
|
||||
        self.ly.weight = torch.nn.Parameter(
            torch.tensor(
                [[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]], [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]
            )
        )
        self.ly.bias = torch.nn.Parameter(
            torch.tensor(
                [[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]], [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]
            )
        )
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 2, 3], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.ly(x)
|
||||
|
||||
|
@ -489,20 +610,25 @@ class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
|
|||
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 2, 3))
|
||||
|
||||
|
||||
class AtenInstanceNormModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 2, 1, 3], torch.float32, True),
|
||||
([2], torch.float32, True),
|
||||
([2], torch.float32, True)
|
||||
])
|
||||
([2], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, w, b):
|
||||
return torch.ops.aten.instance_norm(x, w, b, None,
|
||||
None, True, 0.0, 1e-05, False)
|
||||
return torch.ops.aten.instance_norm(
|
||||
x, w, b, None, None, True, 0.0, 1e-05, False
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenInstanceNormModule())
|
||||
def AtenInstanceNormModule_basic(module, tu: TestUtils):
|
||||
|
|
|
@ -12,98 +12,112 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ReflectionPad2dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([1, 20, 20], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.reflection_pad2d(x, (10, 10, 10, 10))


@register_test_case(module_factory=lambda: ReflectionPad2dModule())
def ReflectionPad2dModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 20, 20, low=-1))


# ==============================================================================


class ReflectionPad2dModuleTop(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([1, 3, 4], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.reflection_pad2d(x, (0, 0, 2, 0))


@register_test_case(module_factory=lambda: ReflectionPad2dModuleTop())
def ReflectionPad2dModule_Top(module, tu: TestUtils):
    module.forward(tu.rand(1, 3, 4))


# ==============================================================================


class ReflectionPad2dModuleBottom(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 3, 10, 10], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.reflection_pad2d(x, (0, 0, 0, 5))


@register_test_case(module_factory=lambda: ReflectionPad2dModuleBottom())
def ReflectionPad2dModule_Bottom(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 10, 10))


# ==============================================================================


class ReflectionPad2dModuleLeft(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 3, 20, 20], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.reflection_pad2d(x, (15, 0, 0, 0))


@register_test_case(module_factory=lambda: ReflectionPad2dModuleLeft())
def ReflectionPad2dModule_Left(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 20, 20))
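Aside (illustrative sketch, not part of this reformat diff): the ATen op exercised by these modules is the same computation as torch.nn.functional.pad with mode="reflect"; padding is given as (left, right, top, bottom) over the last two dims, and each pad must be smaller than the corresponding input size.

import torch
import torch.nn.functional as F

x = torch.arange(12.0).reshape(1, 3, 4)
a = torch.ops.aten.reflection_pad2d(x, (1, 1, 2, 0))
b = F.pad(x, (1, 1, 2, 0), mode="reflect")
assert torch.equal(a, b)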
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReflectionPad2dModuleRight(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 3, 20, 20], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.reflection_pad2d(x, (0, 11, 0, 0))


@register_test_case(module_factory=lambda: ReflectionPad2dModuleRight())
|
||||
|
|
|
@ -12,12 +12,15 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def get_quant_model_input():
|
||||
return 2 * torch.rand((1, 16)) - 1
|
||||
|
||||
|
||||
def get_batched_quant_model_input():
|
||||
return 2 * torch.rand((1, 2, 16)) - 1
|
||||
|
||||
|
||||
class QuantizedNoLayer(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -26,15 +29,18 @@ class QuantizedNoLayer(nn.Module):
|
|||
self.dequantize = torch.quantization.DeQuantStub()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 16], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
x = self.quantize(x)
|
||||
x = self.dequantize(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_quantized_no_layer():
|
||||
model = QuantizedNoLayer()
|
||||
model.eval()
|
||||
|
@ -46,10 +52,12 @@ def get_quantized_no_layer():
|
|||
torch.quantization.convert(model, inplace=True)
|
||||
return model
|
||||
|
||||
|
||||
@register_test_case(module_factory=get_quantized_no_layer)
|
||||
def QuantizedNoLayer_basic(module, tu: TestUtils):
|
||||
module.forward(get_quant_model_input())
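Aside (illustrative sketch, not part of this reformat diff): the model factories in this file follow the classic eager-mode static quantization recipe, attach a qconfig, prepare, run calibration data, then convert. The qconfig/backend choice below ("fbgemm") is an assumption of this sketch; the actual helpers configure their own.

import torch

m = torch.nn.Sequential(
    torch.quantization.QuantStub(),
    torch.nn.Linear(16, 8),
    torch.quantization.DeQuantStub(),
)
m.eval()
m.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(m, inplace=True)
m(2 * torch.rand(1, 16) - 1)  # one calibration pass to collect activation ranges
torch.quantization.convert(m, inplace=True)
print(m(2 * torch.rand(1, 16) - 1).shape)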
|
||||
|
||||
|
||||
class QuantizedSingleLayer(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -61,16 +69,19 @@ class QuantizedSingleLayer(nn.Module):
|
|||
self.dequantize = torch.quantization.DeQuantStub()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 16], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
x = self.quantize(x)
|
||||
x = self.layers(x)
|
||||
x = self.dequantize(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_quantized_single_layer():
|
||||
model = QuantizedSingleLayer()
|
||||
model.eval()
|
||||
|
@ -82,10 +93,12 @@ def get_quantized_single_layer():
|
|||
torch.quantization.convert(model, inplace=True)
|
||||
return model
|
||||
|
||||
|
||||
@register_test_case(module_factory=get_quantized_single_layer)
|
||||
def QuantizedSingleLayer_basic(module, tu: TestUtils):
|
||||
module.forward(get_quant_model_input())
|
||||
|
||||
|
||||
class QuantizedBatchedInputSingleLayer(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -97,16 +110,19 @@ class QuantizedBatchedInputSingleLayer(nn.Module):
|
|||
self.dequantize = torch.quantization.DeQuantStub()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 2, 16], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
x = self.quantize(x)
|
||||
x = self.layers(x)
|
||||
x = self.dequantize(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_batched_quantized_single_layer():
|
||||
model = QuantizedBatchedInputSingleLayer()
|
||||
model.eval()
|
||||
|
@ -118,10 +134,12 @@ def get_batched_quantized_single_layer():
|
|||
torch.quantization.convert(model, inplace=True)
|
||||
return model
|
||||
|
||||
|
||||
@register_test_case(module_factory=get_batched_quantized_single_layer)
|
||||
def QuantizedBatchedInputSingleLayer_basic(module, tu: TestUtils):
|
||||
module.forward(get_batched_quant_model_input())
|
||||
|
||||
|
||||
class QuantizedMLP(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -135,16 +153,19 @@ class QuantizedMLP(nn.Module):
|
|||
self.dequantize = torch.quantization.DeQuantStub()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 16], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
x = self.quantize(x)
|
||||
x = self.layers(x)
|
||||
x = self.dequantize(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_quantized_mlp():
|
||||
model = QuantizedMLP()
|
||||
model.eval()
|
||||
|
@ -156,6 +177,7 @@ def get_quantized_mlp():
|
|||
torch.quantization.convert(model, inplace=True)
|
||||
return model
|
||||
|
||||
|
||||
@register_test_case(module_factory=get_quantized_mlp)
|
||||
def QuantizedMLP_basic(module, tu: TestUtils):
|
||||
module.forward(get_quant_model_input())
|
||||
|
|
|
@ -13,19 +13,20 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class TestMultipleTensorReturn(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b, c, d, e):
|
||||
return a, b, c, d, e
|
||||
|
||||
|
@ -37,54 +38,56 @@ def TestMultipleTensorReturn_basic(module, tu: TestUtils):
|
|||
tu.rand(2, 3).to(torch.float64),
|
||||
tu.rand(2, 3).to(torch.int32),
|
||||
tu.rand(2, 3).to(torch.int64),
|
||||
tu.rand(2, 3).to(torch.bool))
|
||||
tu.rand(2, 3).to(torch.bool),
|
||||
)
|
||||
|
||||
|
||||
class TestMultipleTensorAndPrimitiveTypesReturn(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b, c):
|
||||
d = 1
|
||||
e = 2.3
|
||||
return a, b, c, d, e
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: TestMultipleTensorAndPrimitiveTypesReturn())
|
||||
@register_test_case(module_factory=lambda: TestMultipleTensorAndPrimitiveTypesReturn())
|
||||
def TestMultipleTensorAndPrimitiveTypesReturn_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(3, 4).to(torch.int32),
|
||||
tu.rand(2, 3).to(torch.float64),
|
||||
tu.rand(2, 3).to(torch.bool))
|
||||
tu.rand(2, 3).to(torch.bool),
|
||||
)
|
||||
|
||||
|
||||
class TestF16Return(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float16, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return a
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: TestF16Return())
|
||||
@register_test_case(module_factory=lambda: TestF16Return())
|
||||
def TestF16Return_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(3, 4).to(torch.float16))
|
||||
module.forward(tu.rand(3, 4).to(torch.float16))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
|
@ -6,16 +6,13 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class RandModule(torch.nn.Module):
|
||||
|
||||
class RandModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1024, 512], torch.float, True)
|
||||
])
|
||||
@annotate_args([None, ([1024, 512], torch.float, True)])
|
||||
def forward(self, x):
|
||||
size = x.size()
|
||||
a = torch.rand(size)
|
||||
|
@ -26,34 +23,41 @@ class RandModule(torch.nn.Module):
|
|||
def RandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1024, 512))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class UniformModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float64, True),
            ([-1, -1, -1], torch.float64, True),
            ([-1, -1, -1], torch.float64, True),
        ]
    )
    def forward(self, x, y, z):
        a = torch.ops.aten.uniform_(x, 1.0, 10.0)
        b = torch.ops.aten.uniform_(y, -20.0, -5.0)
        c = torch.ops.aten.uniform_(z, -15.0, 3.0)
        std = torch.cat(
            [
                torch.flatten(torch.std(a)),
                torch.flatten(torch.std(b)),
                torch.flatten(torch.std(c)),
            ]
        )
        mean = torch.cat(
            [
                torch.flatten(torch.mean(a)),
                torch.flatten(torch.mean(b)),
                torch.flatten(torch.mean(c)),
            ]
        )
        return std, mean
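Aside (illustrative sketch, not part of this reformat diff): the mean/std reductions above work as a correctness signal because for U(low, high) the mean is (low + high) / 2 and the std is (high - low) / sqrt(12), so a large sample should land near those values.

import math
import torch

x = torch.empty(1_000_000, dtype=torch.float64)
torch.ops.aten.uniform_(x, 1.0, 10.0)
print(x.mean().item())                            # ~ (1 + 10) / 2 = 5.5
print(x.std().item(), (10 - 1) / math.sqrt(12))   # ~ 2.598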
|
||||
|
||||
|
||||
|
@ -62,36 +66,44 @@ def UniformModule_basic(module, tu: TestUtils):
|
|||
module.forward(
|
||||
tu.rand(256, 512, 12).double(),
|
||||
tu.rand(512, 1024, 12).double(),
|
||||
tu.rand(512, 256, 12).double())
|
||||
tu.rand(512, 256, 12).double(),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class UniformStaticShapeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([256, 512, 12], torch.float64, True),
            ([512, 1024, 12], torch.float64, True),
            ([512, 256, 12], torch.float64, True),
        ]
    )
    def forward(self, x, y, z):
        a = torch.ops.aten.uniform_(x, 1.0, 10.0)
        b = torch.ops.aten.uniform_(y, -20.0, -5.0)
        c = torch.ops.aten.uniform_(z, -15.0, 3.0)
        std = torch.cat(
            [
                torch.flatten(torch.std(a)),
                torch.flatten(torch.std(b)),
                torch.flatten(torch.std(c)),
            ]
        )
        mean = torch.cat(
            [
                torch.flatten(torch.mean(a)),
                torch.flatten(torch.mean(b)),
                torch.flatten(torch.mean(c)),
            ]
        )
        return std, mean
|
||||
|
||||
|
||||
|
@ -100,12 +112,14 @@ def UniformStaticShapeModule_basic(module, tu: TestUtils):
|
|||
module.forward(
|
||||
tu.rand(256, 512, 12).double(),
|
||||
tu.rand(512, 1024, 12).double(),
|
||||
tu.rand(512, 256, 12).double())
|
||||
tu.rand(512, 256, 12).double(),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class UniformNoCorrelationModule(torch.nn.Module):
|
||||
|
||||
class UniformNoCorrelationModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -119,10 +133,12 @@ class UniformNoCorrelationModule(torch.nn.Module):
|
|||
return cov[0, 1] / torch.sqrt(cov[0, 0] * cov[1, 1])
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1000], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
# Correlation of two independent uniforms
|
||||
a = torch.ops.aten.uniform(x)
|
||||
|
@ -145,27 +161,32 @@ class UniformNoCorrelationModule(torch.nn.Module):
|
|||
# than `atol + rtol * correlation = 1E-6`, which is too strict.
|
||||
# Instead, the correlations are explicitly required to be less than
|
||||
# 0.001.
|
||||
return torch.where(torch.abs(corr_a_b) < 0.001, 1, 2), \
|
||||
torch.where(torch.abs(corr_major) < 0.001, 1, 2), \
|
||||
torch.where(torch.abs(corr_minor) < 0.001, 1, 2)
|
||||
return (
|
||||
torch.where(torch.abs(corr_a_b) < 0.001, 1, 2),
|
||||
torch.where(torch.abs(corr_major) < 0.001, 1, 2),
|
||||
torch.where(torch.abs(corr_minor) < 0.001, 1, 2),
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: UniformNoCorrelationModule())
|
||||
def UniformNoCorrelationModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(1000).double())
|
||||
module.forward(tu.rand(1000).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ExponentialModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
a = torch.ops.aten.exponential(x, 3.0)
|
||||
mean = torch.mean(a)
|
||||
|
@ -175,20 +196,23 @@ class ExponentialModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ExponentialModule())
|
||||
def ExponentialModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(512, 512, 16).double())
|
||||
module.forward(tu.rand(512, 512, 16).double())
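Aside (illustrative sketch, not part of this reformat diff): aten.exponential(x, lambd) draws from an exponential distribution with rate lambd, so the sample mean should be close to 1 / lambd, here 1 / 3.

import torch

x = torch.empty(512, 512, 16, dtype=torch.float64)
a = torch.ops.aten.exponential(x, 3.0)
print(a.mean().item())  # ~ 0.333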
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BernoulliModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
a = torch.bernoulli(x)
|
||||
mean = torch.mean(a)
|
||||
|
@ -198,20 +222,23 @@ class BernoulliModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: BernoulliModule())
|
||||
def BernoulliModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(512, 512, 16).double())
|
||||
module.forward(tu.rand(512, 512, 16).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BernoulliZerosModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.bernoulli(x)
|
||||
|
||||
|
@ -220,17 +247,21 @@ class BernoulliZerosModule(torch.nn.Module):
|
|||
def BernoulliZerosModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.zeros(4, 8).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BernoulliOnesModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.bernoulli(x)
|
||||
|
||||
|
@ -239,50 +270,60 @@ class BernoulliOnesModule(torch.nn.Module):
|
|||
def BernoulliOnesModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.ones(4, 8).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BernoulliFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float64, True),
            ([-1, -1, -1], torch.float64, True),
        ]
    )
    def forward(self, x, y):
        a = torch.ops.aten.bernoulli_(x, 0.4)
        b = torch.ops.aten.bernoulli_(y, 0.7)
        mean = torch.cat(
            [
                torch.flatten(torch.mean(a)),
                torch.flatten(torch.mean(b)),
            ]
        )
        std = torch.cat(
            [
                torch.flatten(torch.std(a)),
                torch.flatten(torch.std(b)),
            ]
        )
        return mean, std
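Aside (illustrative sketch, not part of this reformat diff): for Bernoulli(p) the mean is p and the std is sqrt(p * (1 - p)), which is what makes the mean/std reduction above a usable check for an in-place bernoulli_ fill.

import math
import torch

y = torch.empty(1_000_000, dtype=torch.float64)
torch.ops.aten.bernoulli_(y, 0.4)
print(y.mean().item())                        # ~ 0.4
print(y.std().item(), math.sqrt(0.4 * 0.6))   # ~ 0.49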
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BernoulliFloatModule())
|
||||
def BernoulliFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(512, 512, 16).double(),
|
||||
tu.rand(512, 512, 16).double())
|
||||
module.forward(tu.rand(512, 512, 16).double(), tu.rand(512, 512, 16).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BernoulliTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, px):
|
||||
a = torch.ops.aten.bernoulli_(x, px)
|
||||
mean = torch.mean(a)
|
||||
|
@ -292,53 +333,61 @@ class BernoulliTensorModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: BernoulliTensorModule())
|
||||
def BernoulliTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(512, 512).double(),
|
||||
tu.rand(512, 512).double())
|
||||
module.forward(tu.rand(512, 512).double(), tu.rand(512, 512).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BernoulliPModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float64, True),
            ([-1, -1, -1], torch.float64, True),
        ]
    )
    def forward(self, x, y):
        a = torch.ops.aten.bernoulli(x, 0.4)
        b = torch.ops.aten.bernoulli(y, 0.7)
        mean = torch.cat(
            [
                torch.flatten(torch.mean(a)),
                torch.flatten(torch.mean(b)),
            ]
        )
        std = torch.cat(
            [
                torch.flatten(torch.std(a)),
                torch.flatten(torch.std(b)),
            ]
        )
        return mean, std
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BernoulliPModule())
|
||||
def BernoulliPModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(512, 512, 16).double(),
|
||||
tu.rand(512, 512, 16).double())
|
||||
module.forward(tu.rand(512, 512, 16).double(), tu.rand(512, 512, 16).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandLikeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
a = torch.ops.aten.rand_like(x)
|
||||
mean = torch.mean(a)
|
||||
|
@ -349,17 +398,21 @@ class RandLikeModule(torch.nn.Module):
|
|||
def RandLikeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1024, 1024).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandLikeDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
a = torch.ops.aten.rand_like(x, dtype=torch.float32)
|
||||
mean = torch.mean(a)
|
||||
|
@ -370,16 +423,20 @@ class RandLikeDtypeModule(torch.nn.Module):
|
|||
def RandLikeDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1024, 1024).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandIntLowModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
a = torch.ops.aten.randint(low=1, high=1000, size=[1024, 1024])
|
||||
mean = torch.mean(a.to(torch.float32))
|
||||
|
@ -390,18 +447,24 @@ class RandIntLowModule(torch.nn.Module):
|
|||
def RandIntLowModule_basic(module, tu: TestUtils):
|
||||
module.forward()
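Aside (illustrative sketch, not part of this reformat diff): randint(low, high) samples integers uniformly in [low, high), so the mean converges to (low + high - 1) / 2, roughly 500 here, which is what the reduction in the module above is probing.

import torch

a = torch.ops.aten.randint(low=1, high=1000, size=[1024, 1024])
print(a.to(torch.float32).mean().item())  # ~ 500.0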
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandIntLowDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
a = torch.ops.aten.randint(low=1, high=1000, size=[128, 256, 512], dtype=torch.float64)
|
||||
a = torch.ops.aten.randint(
|
||||
low=1, high=1000, size=[128, 256, 512], dtype=torch.float64
|
||||
)
|
||||
mean = torch.mean(a)
|
||||
return mean
|
||||
|
||||
|
@ -410,16 +473,20 @@ class RandIntLowDtypeModule(torch.nn.Module):
|
|||
def RandIntLowDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
a = torch.ops.aten.randint(high=1000, size=[1024, 1024])
|
||||
mean = torch.mean(a.to(torch.float32))
|
||||
|
@ -430,16 +497,20 @@ class RandIntModule(torch.nn.Module):
|
|||
def RandIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandIntDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], dtype=torch.float64)
|
||||
mean = torch.mean(a.to(torch.float32))
|
||||
|
@ -450,16 +521,20 @@ class RandIntDtypeModule(torch.nn.Module):
|
|||
def RandIntDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandIntPinMemoryModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], pin_memory=False)
|
||||
mean = torch.mean(a.to(torch.float32))
|
||||
|
@ -470,17 +545,20 @@ class RandIntPinMemoryModule(torch.nn.Module):
|
|||
def RandIntPinMemoryModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class RandnModule(torch.nn.Module):
|
||||
|
||||
class RandnModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
a = torch.ops.aten.randn([4, 512, 1024])
|
||||
std = torch.std(a.to(dtype=torch.float64))
|
||||
|
@ -496,18 +574,19 @@ def RandnModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class RandnDtypeDeviceModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
a = torch.ops.aten.randn([4, 512, 1024],
|
||||
dtype=torch.float64,
|
||||
device=torch.device("cpu"))
|
||||
a = torch.ops.aten.randn(
|
||||
[4, 512, 1024], dtype=torch.float64, device=torch.device("cpu")
|
||||
)
|
||||
std = torch.std(a)
|
||||
return std
|
||||
|
||||
|
@ -521,14 +600,15 @@ def RandnDtypeDeviceModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class RandnGeneratorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
a = torch.ops.aten.randn([4, 512, 1024], generator=None)
|
||||
std = torch.std(a.to(dtype=torch.float64))
|
||||
|
@ -544,14 +624,15 @@ def RandnGeneratorModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class RandnGeneratorF64Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
a = torch.ops.aten.randn([4, 512, 1024], generator=None, dtype=torch.float64)
|
||||
std = torch.std(a)
|
||||
|
@ -571,10 +652,12 @@ class RandnLikeModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
a = torch.ops.aten.randn_like(x)
|
||||
std = torch.std(a)
|
||||
|
@ -585,17 +668,21 @@ class RandnLikeModule(torch.nn.Module):
|
|||
def RandnLikeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 512, 1024).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandnLikeDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
a = torch.ops.aten.randn_like(x, dtype=torch.float32)
|
||||
std = torch.std(a)
|
||||
|
@ -605,17 +692,22 @@ class RandnLikeDtypeModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: RandnLikeDtypeModule())
|
||||
def RandnLikeDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(256, 1024).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NormalFunctionalModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)
|
||||
mean = torch.mean(a)
|
||||
|
|
|
@ -13,16 +13,17 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class AddIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return int(lhs) + int(rhs)
|
||||
|
||||
|
@ -36,16 +37,17 @@ def AddIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class SubIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return int(lhs) - int(rhs)
|
||||
|
||||
|
@ -59,16 +61,17 @@ def SubIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class SubFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return float(lhs) - float(rhs)
|
||||
|
||||
|
@ -80,17 +83,19 @@ def SubFloatModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class MulFloatModule(torch.nn.Module):
|
||||
|
||||
class MulFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return float(lhs) * float(rhs)
|
||||
|
||||
|
@ -104,16 +109,17 @@ def MulFloatModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class MulIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return int(lhs) * int(rhs)
|
||||
|
||||
|
@ -127,16 +133,17 @@ def MulIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class DivIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
# Cast the result to float to make e2e test baseline result to be a float.
|
||||
# Without the cast, baseline result is a Tensor which is unexpected.
|
||||
|
@ -152,16 +159,17 @@ def DivIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class DivFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return float(lhs) / float(rhs)
|
||||
|
||||
|
@ -175,16 +183,17 @@ def DivFloatModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class CeilFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
sub = float(lhs) - float(rhs)
|
||||
# Cast the result to int to make e2e test baseline result to be an int.
|
||||
|
@ -204,15 +213,16 @@ def CeilFloatModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class SqrtIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return float(torch.ops.aten.sqrt(int(a)))
|
||||
|
||||
|
@ -223,14 +233,15 @@ def SqrtIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class SqrtIntConstantModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return float(torch.ops.aten.sqrt(5))
|
||||
|
||||
|
@ -244,15 +255,16 @@ def SqrtIntConstantModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class BoolFloatFalseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
sub = float(a) - float(a)
|
||||
return bool(torch.ops.aten.Bool(float(sub)))
|
||||
|
@ -264,15 +276,16 @@ def BoolFloatFalseModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class BoolFloatTrueModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return bool(torch.ops.aten.Bool(float(a)))
|
||||
|
||||
|
@ -283,14 +296,15 @@ def BoolFloatTrueModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class BoolFloatConstantModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return bool(torch.ops.aten.Bool(5.0))
|
||||
|
||||
|
@ -304,15 +318,16 @@ def BoolFloatConstantModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class BoolIntFalseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
sub = int(a) - int(a)
|
||||
return bool(torch.ops.aten.Bool(int(sub)))
|
||||
|
@ -324,15 +339,16 @@ def BoolIntFalseModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class BoolIntTrueModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return bool(torch.ops.aten.Bool(int(a)))
|
||||
|
||||
|
@ -343,14 +359,15 @@ def BoolIntTrueModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class BoolIntConstantModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return bool(torch.ops.aten.Bool(5))
|
||||
|
||||
|
@ -359,17 +376,21 @@ class BoolIntConstantModule(torch.nn.Module):
|
|||
def BoolIntConstantModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenIntBoolOpModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.bool, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return int(torch.ops.aten.Int(x))
|
||||
|
||||
|
@ -384,9 +405,11 @@ class AtenIntBoolOpConstTrueModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return int(torch.ops.aten.Int(True))
|
||||
|
||||
|
@ -401,9 +424,11 @@ class AtenIntBoolOpConstFalseModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return int(torch.ops.aten.Int(False))
|
||||
|
||||
|
@ -412,21 +437,25 @@ class AtenIntBoolOpConstFalseModule(torch.nn.Module):
|
|||
def AtenIntBoolOpConstFalseModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenIntTensorByteDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.uint8, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, val):
|
||||
return int(val)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenIntTensorByteDtypeModule())
|
||||
def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.uint8))
|
||||
|
@ -434,57 +463,68 @@ def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenIntTensorCharDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int8, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, val):
|
||||
return int(val)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule())
|
||||
def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenItemIntOpModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int8, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, val):
|
||||
return int(val)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenItemIntOpModule())
|
||||
def AtenItemIntOpModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenItemFpOpModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, val):
|
||||
return float(val)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenItemFpOpModule())
|
||||
def AtenItemFpOpModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1))
|
||||
|
|
|
@ -13,16 +13,17 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class NeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return int(lhs) != int(rhs)
|
||||
|
||||
|
@ -36,16 +37,17 @@ def NeIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class EqIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return int(lhs) == int(rhs)
|
||||
|
||||
|
@ -59,16 +61,17 @@ def EqIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class GtIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return int(lhs) > int(rhs)
|
||||
|
||||
|
@ -82,16 +85,17 @@ def GtIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class GeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.ops.aten.ge(int(lhs), int(rhs))
|
||||
|
||||
|
@ -105,16 +109,17 @@ def GeIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class GeFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return float(lhs) >= float(rhs)
|
||||
|
||||
|
@ -128,16 +133,17 @@ def GeFloatModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class GeFloatIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return float(lhs) >= int(rhs)
|
||||
|
||||
|
@ -151,16 +157,17 @@ def GeFloatIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class NeFloatIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return float(lhs) != int(rhs)
|
||||
|
||||
|
@ -174,16 +181,17 @@ def NeFloatIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class GtFloatIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, lhs, rhs):
|
||||
return float(lhs) > int(rhs)
|
||||
|
||||
|
|
File diff suppressed because it is too large
File diff suppressed because it is too large
|
@ -17,16 +17,17 @@ class SqueezeStaticModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 7, 1, 3, 1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeStaticModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeStaticModule())
|
||||
def SqueezeModule_static(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 7, 1, 3, 1))
|
||||
|
||||
|
@ -39,16 +40,17 @@ class SqueezeAllUnitDimModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeAllUnitDimModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeAllUnitDimModule())
|
||||
def SqueezeModule_allUnitDim(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1))
|
||||
|
||||
|
@ -61,17 +63,18 @@ class SqueezeBroadcastModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return a * b.squeeze()
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeBroadcastModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeBroadcastModule())
|
||||
def SqueezeModule_broadcast(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 3), tu.rand())
|
||||
|
||||
|
@ -84,16 +87,17 @@ class SqueezeDimStaticModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 7], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, 0)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimStaticModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeDimStaticModule())
|
||||
def SqueezeDimModule_static(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 7))
|
||||
|
||||
|
@ -106,16 +110,17 @@ class SqueezeDimDynamicModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, 1, 384, -1, 1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, 4)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimDynamicModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeDimDynamicModule())
|
||||
def SqueezeDimModule_dynamic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(8, 1, 384, 12, 1))
|
||||
|
||||
|
@ -128,16 +133,17 @@ class SqueezeDimNegDimModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, -1, 1, 384, -1, 1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, -6)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimNegDimModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeDimNegDimModule())
|
||||
def SqueezeDimModule_negDim(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 8, 1, 384, 12, 1))
|
||||
|
||||
|
@ -150,16 +156,17 @@ class SqueezeDimIdentityModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([4, 1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, 0)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimIdentityModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeDimIdentityModule())
|
||||
def SqueezeDimModule_identity(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 1, 3))
|
||||
|
||||
|
@ -172,16 +179,17 @@ class SqueezeDimUnitDimModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, 0)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimUnitDimModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeDimUnitDimModule())
|
||||
def SqueezeDimModule_unitDim(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1))
|
||||
|
||||
|
@ -194,16 +202,17 @@ class PrimsSqueezeModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 1, 2, 3, 1, 4], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.prims.squeeze(a, dimensions=[0, 4, 1])
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: PrimsSqueezeModule())
|
||||
@register_test_case(module_factory=lambda: PrimsSqueezeModule())
|
||||
def PrimsSqueezeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 2, 3, 1, 4))
|
||||
|
||||
|
@ -213,15 +222,16 @@ class PrimsSqueezeEmptyDimensionsModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 2, 1, 4], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.prims.squeeze(a, dimensions=[])
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: PrimsSqueezeEmptyDimensionsModule())
|
||||
@register_test_case(module_factory=lambda: PrimsSqueezeEmptyDimensionsModule())
|
||||
def PrimsSqueezeEmptyDimensionsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 2, 1, 4))
|
||||
|
|
File diff suppressed because it is too large
|
@ -17,14 +17,16 @@ class Threshold1dIntI32Module(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.int32, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold1dIntI32Module())
|
||||
def Threshold1dIntI32Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, high=10).to(torch.int32))
|
||||
|
@ -35,14 +37,16 @@ class Threshold1dIntModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold1dIntModule())
|
||||
def Threshold1dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, high=10))
|
||||
|
@ -53,14 +57,16 @@ class Threshold2dIntModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 0.5, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold2dIntModule())
|
||||
def Threshold2dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, high=10))
|
||||
|
@ -71,14 +77,16 @@ class Threshold3dIntModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2.2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold3dIntModule())
|
||||
def Threshold3dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, 6, high=10))
|
||||
|
@ -89,14 +97,16 @@ class Threshold1dFloatModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold1dFloatModule())
|
||||
def Threshold1dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4))
|
||||
|
@ -107,14 +117,16 @@ class Threshold2dFloatModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 0.5, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold2dFloatModule())
|
||||
def Threshold2dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5))
|
||||
|
@ -125,14 +137,16 @@ class Threshold3dFloatModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1.4, 2.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold3dFloatModule())
|
||||
def Threshold3dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6))
|
||||
|
@ -143,15 +157,17 @@ class ThresholdBackward1dIntModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
|
||||
def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, high=10), tu.randint(4, high=8))
|
||||
|
@ -162,15 +178,17 @@ class ThresholdBackward2dIntModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
|
||||
def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, high=10), tu.randint(4, 5, high=8))
|
||||
|
@ -181,15 +199,17 @@ class ThresholdBackward3dIntModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
|
||||
def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, 6, high=10), tu.randint(4, 5, 6, high=8))
|
||||
|
@ -200,15 +220,17 @@ class ThresholdBackward1dFloatModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule())
|
||||
def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand(4))
|
||||
|
@ -219,15 +241,17 @@ class ThresholdBackward2dFloatModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule())
|
||||
def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5), tu.rand(4, 5))
|
||||
|
@ -238,15 +262,17 @@ class ThresholdBackward3dFloatModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule())
|
||||
def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6), tu.rand(4, 5, 6))
|
||||
|
@ -257,15 +283,17 @@ class ThresholdBackward1dMixedModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
|
||||
def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.randint(4, high=10))
|
||||
|
@ -276,15 +304,17 @@ class ThresholdBackward2dMixedModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
|
||||
def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, high=20), tu.rand(4, 5))
|
||||
|
@ -295,15 +325,17 @@ class ThresholdBackward3dMixedModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
|
||||
]
|
||||
)
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
|
||||
def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6), tu.randint(4, 5, 6, high=10))
|
||||
|
|
|
@ -13,7 +13,6 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class TypeConversionF32ToF64Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -29,7 +28,6 @@ def TypeConversionF32ToF64Module_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class TypeConversionF64ToF32Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -45,7 +43,6 @@ def TypeConversionF64ToF32Module_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class TypeConversionI32ToI64Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -61,7 +58,6 @@ def TypeConversionI32ToI64Module_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class TypeConversionI64ToI32Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -77,7 +73,6 @@ def TypeConversionI64ToI32Module_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class TypeConversionI1ToI32Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -94,7 +89,6 @@ def TypeConversionI1ToI32Module_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class TypeConversionI1ToI64Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -111,7 +105,6 @@ def TypeConversionI1ToI64Module_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class TypeConversionI1ToF32Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -128,7 +121,6 @@ def TypeConversionI1ToF32Module_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class TypeConversionI1ToF64Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -148,43 +140,46 @@ def TypeConversionI1ToF64Module_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class ToDtypeLayoutNoneModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.to(x,
|
||||
return torch.ops.aten.to(
|
||||
x,
|
||||
dtype=torch.float64,
|
||||
layout=None,
|
||||
device=None,
|
||||
pin_memory=None,
|
||||
non_blocking=False,
|
||||
copy=False,
|
||||
memory_format=None)
|
||||
memory_format=None,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ToDtypeLayoutNoneModule())
|
||||
def ToDtypeLayoutNoneModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
class ToDtypeLayoutCPUModule(torch.nn.Module):
|
||||
|
||||
class ToDtypeLayoutCPUModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.to(x,
|
||||
return torch.ops.aten.to(
|
||||
x,
|
||||
dtype=torch.float64,
|
||||
layout=None,
|
||||
device="cpu",
|
||||
pin_memory=None,
|
||||
non_blocking=False,
|
||||
copy=False,
|
||||
memory_format=None)
|
||||
memory_format=None,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ToDtypeLayoutCPUModule())
|
||||
|
@ -193,21 +188,22 @@ def ToDtypeLayoutCPUModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class ToDtypeLayoutStridedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.to(x,
|
||||
return torch.ops.aten.to(
|
||||
x,
|
||||
dtype=torch.float64,
|
||||
layout=torch.strided,
|
||||
device=None,
|
||||
pin_memory=None,
|
||||
non_blocking=False,
|
||||
copy=False,
|
||||
memory_format=None)
|
||||
memory_format=None,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ToDtypeLayoutStridedModule())
|
||||
|
@ -216,21 +212,22 @@ def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class ToDtypeBoolLayoutNoneStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([3, 5], torch.int64, True)])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.to(x,
|
||||
return torch.ops.aten.to(
|
||||
x,
|
||||
dtype=torch.bool,
|
||||
layout=None,
|
||||
device=None,
|
||||
pin_memory=None,
|
||||
non_blocking=False,
|
||||
copy=False,
|
||||
memory_format=None)
|
||||
memory_format=None,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneStaticModule())
|
||||
|
@ -239,16 +236,17 @@ def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class TypeAsSameModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.type_as(x, y)
|
||||
|
||||
|
@ -257,17 +255,19 @@ class TypeAsSameModule(torch.nn.Module):
|
|||
def TypeAsSameModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(3, 5))
|
||||
|
||||
class TypeAsDifferentModule(torch.nn.Module):
|
||||
|
||||
class TypeAsDifferentModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.type_as(x, y)
|
||||
|
||||
|
@ -276,14 +276,14 @@ class TypeAsDifferentModule(torch.nn.Module):
|
|||
def TypeAsDifferentModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.randint(3, 5, low=0, high=10, dtype=torch.int),
|
||||
tu.randint(3, 5, low=0, high=10, dtype=torch.int64)
|
||||
tu.randint(3, 5, low=0, high=10, dtype=torch.int64),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class PrimsConvertElementTypeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
|
|
@ -17,21 +17,22 @@ class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.int32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.add(a, b, alpha=3)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule())
|
||||
module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule()
|
||||
)
|
||||
def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.randint(4, high=10).type(torch.int32),
|
||||
tu.randint(4, high=10))
|
||||
module.forward(tu.randint(4, high=10).type(torch.int32), tu.randint(4, high=10))
|
||||
|
||||
|
||||
class TypePromotionDifferentCategoryModule(torch.nn.Module):
|
||||
|
@ -39,17 +40,18 @@ class TypePromotionDifferentCategoryModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.add(a, b, alpha=3)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: TypePromotionDifferentCategoryModule())
|
||||
@register_test_case(module_factory=lambda: TypePromotionDifferentCategoryModule())
|
||||
def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, high=10), tu.rand(4))
|
||||
|
||||
|
@ -59,17 +61,20 @@ class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.add(a, b, alpha=2.3)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: TypePromotionSameCategoryZeroRankWiderModule())
|
||||
module_factory=lambda: TypePromotionSameCategoryZeroRankWiderModule()
|
||||
)
|
||||
def TypePromotionSameCategoryZeroRankWider_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand().type(torch.float64))
|
||||
|
||||
|
@ -79,17 +84,18 @@ class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.add(a, b, alpha=2)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: TypePromotionZeroRankHigherCategoryModule())
|
||||
@register_test_case(module_factory=lambda: TypePromotionZeroRankHigherCategoryModule())
|
||||
def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, high=10), tu.rand())
|
||||
|
||||
|
@ -99,11 +105,13 @@ class TypePromotionAlphaWiderModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.add(a, b, alpha=2.3)
|
||||
|
||||
|
|
|
@ -22,10 +22,12 @@ class ResNet18Module(torch.nn.Module):
        self.train(False)

    @export
    @annotate_args([
    @annotate_args(
        [
            None,
            ([-1, 3, -1, -1], torch.float32, True),
    ])
        ]
    )
    def forward(self, img):
        return self.resnet.forward(img)


@ -44,10 +46,12 @@ class ResNet18StaticModule(torch.nn.Module):
        self.train(False)

    @export
    @annotate_args([
    @annotate_args(
        [
            None,
            ([1, 3, 224, 224], torch.float32, True),
    ])
        ]
    )
    def forward(self, img):
        return self.resnet.forward(img)


@ -62,11 +66,13 @@ class IouOfModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
            ([-1, -1], torch.float32, True),
    ])
        ]
    )
    def forward(self, bbox1, bbox2):
        area1 = (bbox1[:, 2] - bbox1[:, 0]) * (bbox1[:, 3] - bbox1[:, 1])
        area2 = (bbox2[:, 2] - bbox2[:, 0]) * (bbox2[:, 3] - bbox2[:, 1])

@ -94,10 +100,12 @@ class MobilenetV3Module(torch.nn.Module):
        self.train(False)

    @export
    @annotate_args([
    @annotate_args(
        [
            None,
            ([-1, 3, -1, -1], torch.float32, True),
    ])
        ]
    )
    def forward(self, img):
        return self.mobilenetv3.forward(img)

@ -13,12 +13,12 @@ from torch_mlir.ir import Module

# A type shared between the result of `TosaBackend.compile` and the
# input to `TosaBackend.load`. Each backend will likely have a
# different definition of this type.
CompiledArtifact = TypeVar('CompiledArtifact')
CompiledArtifact = TypeVar("CompiledArtifact")

# A wrapper around a backend-specific loaded program representation
# that uniformly translates the `x.method(...)` interface expected of
# Torch modules into appropriate lower-level operations.
Invoker = TypeVar('Invoker')
Invoker = TypeVar("Invoker")


class TosaBackend(abc.ABC):

@ -27,6 +27,7 @@ class TosaBackend(abc.ABC):
    Backends are recommended to raise meaningful exceptions in case of error,
    ideally with easy reproduction instructions.
    """

    @abc.abstractmethod
    def compile(self, module: Module) -> CompiledArtifact:
        """Compile the provided MLIR module into a compiled artifact.

@ -7,7 +7,9 @@ from torch_mlir.ir import *
from torch_mlir.passmanager import *
from torch_mlir.compiler_utils import run_pipeline_with_repro_report

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
    RefBackendLinalgOnTensorsBackend,
)

from .abc import TosaBackend


@ -17,7 +19,8 @@ __all__ = [

# The pipeline of func.func passes that lower the TOSA backend contract to the
# Linalg-on-Tensors backend contract accepted by RefBackend.
TOSA_TO_LINALG_FUNC_PIPELINE = ",".join([
TOSA_TO_LINALG_FUNC_PIPELINE = ",".join(
    [
        # TOSA legalization may emit tosa.const() ops. These are legalized
        # by tosa-to-arith to arith.constants. This mechanical transformation
        # must be done prior to TOSA-to-LinAlg so that the latter does not fail.

@ -33,7 +36,8 @@ TOSA_TO_LINALG_FUNC_PIPELINE = ",".join([
        "tosa-to-linalg",
        "tosa-to-tensor",
        "tosa-to-arith",
])
    ]
)


class LinalgOnTensorsTosaBackend(TosaBackend):

@ -60,7 +64,8 @@ class LinalgOnTensorsTosaBackend(TosaBackend):
        run_pipeline_with_repro_report(
            imported_module,
            f"builtin.module(func.func({TOSA_TO_LINALG_FUNC_PIPELINE}))",
            "Lowering TOSA backend contract to Linalg-on-Tensors backend contract")
            "Lowering TOSA backend contract to Linalg-on-Tensors backend contract",
        )

        return self.refbackend.compile(imported_module)

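Taken together with the TosaBackend ABC above, a typical end-to-end use of this backend is roughly the following sketch. It assumes `tosa_module` is a torch-mlir module already lowered to the TOSA backend contract and that `load` returns an invoker exposing the module's methods, as the CompiledArtifact/Invoker comments in abc.py describe; the variable names are illustrative, not part of this diff.

backend = LinalgOnTensorsTosaBackend()
compiled = backend.compile(tosa_module)  # runs TOSA_TO_LINALG_FUNC_PIPELINE, then RefBackend
invoker = backend.load(compiled)         # per the CompiledArtifact/Invoker contract above
result = invoker.forward(input_array)    # call methods as on the original torch.nn.Module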
@ -6,6 +6,7 @@
from torch_mlir.torchscript import TensorPlaceholder
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME


def convert_annotations_to_placeholders(forward_method):
    """Converts the annotations on a forward method into tensor placeholders.

@ -17,6 +18,7 @@ def convert_annotations_to_placeholders(forward_method):
    for annotation in annotations[1:]:
        if not annotation[2]:
            raise ValueError(
                "Can only compile inputs annotated as having value semantics.")
                "Can only compile inputs annotated as having value semantics."
            )
        placeholders.append(TensorPlaceholder(annotation[0], annotation[1]))
    return placeholders

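A hedged usage sketch for the helper above: `MyAnnotatedModule` is hypothetical and assumed to annotate its `forward` with `@export`/`@annotate_args` from `torch_mlir_e2e_test.annotations`.

module = MyAnnotatedModule()
# Yields one TensorPlaceholder per annotated tensor argument (skipping the leading None entry).
placeholders = convert_annotations_to_placeholders(module.forward)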
@ -19,45 +19,52 @@ from lit.llvm.subst import FindTool

# Configuration file for the 'lit' test runner.

# name: The name of this test suite.
config.name = 'TORCH_MLIR_PT1'
config.name = "TORCH_MLIR_PT1"

config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)

# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir', '.py']
config.suffixes = [".mlir", ".py"]

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)

# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test")

config.substitutions.append(('%PATH%', config.environment['PATH']))
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
config.substitutions.append(("%PATH%", config.environment["PATH"]))
config.substitutions.append(("%shlibext", config.llvm_shlib_ext))

llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])

#llvm_config.use_default_substitutions()
# llvm_config.use_default_substitutions()

# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = [
    'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
    'lit.cfg.py', 'lit.site.cfg.py'
    "Inputs",
    "Examples",
    "CMakeLists.txt",
    "README.txt",
    "LICENSE.txt",
    "lit.cfg.py",
    "lit.site.cfg.py",
]

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)

# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test")
config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, "bin")

# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
# Tweak the PATH to include the binary build dir, in order to pick up CAPI tests during out-of-tree.
llvm_config.with_environment('PATH', os.path.join(config.llvm_build_dir, 'bin'), append_path=True)
llvm_config.with_environment(
    "PATH", os.path.join(config.llvm_build_dir, "bin"), append_path=True
)

# On Windows the path to python could contains spaces in which case it needs to
# be provided in quotes. This is the equivalent of how %python is setup in

@ -65,16 +72,23 @@ llvm_config.with_environment('PATH', os.path.join(config.llvm_build_dir, 'bin'),
if "Windows" in config.host_os:
    config.python_executable = '"%s"' % (config.python_executable)

tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir, config.torch_mlir_obj_root]
tool_dirs = [
    config.standalone_tools_dir,
    config.llvm_tools_dir,
    config.torch_mlir_obj_root,
]
tools = [
    'torch-mlir-opt',
    ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
    "torch-mlir-opt",
    ToolSubst("%PYTHON", config.python_executable, unresolved="ignore"),
]

llvm_config.add_tool_substitutions(tools, tool_dirs)

if config.enable_bindings_python:
    llvm_config.with_environment('PYTHONPATH', [
        os.path.join(config.torch_mlir_python_packages_dir, 'torch_mlir'),
    llvm_config.with_environment(
        "PYTHONPATH",
        [
            os.path.join(config.torch_mlir_python_packages_dir, "torch_mlir"),
        ],
        append_path=True)
        append_path=True,
    )

@ -20,18 +20,26 @@ goofy_lib = torch.library.Library("goofy", "DEF")
goofy_lib.define("identity(Tensor t) -> Tensor")
goofy_lib.impl("identity", identity)


def goofy〇identity〡shape(t: List[int]) -> List[int]:
    return t


def goofy〇identity〡dtype(t_rank_dtype: Tuple[int, int]) -> int:
    t_rank, t_dtype = t_rank_dtype
    return t_dtype


def goofy〇identity〡has_value_semantics() -> None:
    return


extra_library = [
    goofy〇identity〡shape, goofy〇identity〡dtype, goofy〇identity〡has_value_semantics]
    goofy〇identity〡shape,
    goofy〇identity〡dtype,
    goofy〇identity〡has_value_semantics,
]


class CustomOpExampleModule(torch.nn.Module):
    def __init__(self):

@ -52,6 +60,7 @@ class CustomOpExampleModule(torch.nn.Module):
mod = CustomOpExampleModule()
mod.eval()


def run():
    mod = CustomOpExampleModule()
    mod.eval()

@ -66,6 +75,7 @@ def run():

    print(module)


run()

# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {

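Given the `goofy_lib.define`/`impl` registration above, the custom op is reachable through `torch.ops`; a minimal eager-mode smoke test might look like the following sketch (tensor values are illustrative, not part of this diff).

t = torch.ones(2, 3)
# identity() simply returns its input, so the result should compare equal.
assert torch.equal(torch.ops.goofy.identity(t), t)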
Some files were not shown because too many files have changed in this diff