mirror of https://github.com/llvm/torch-mlir
[NFC reformat] Applies pre-commit formatting to Python files. (#3244)
This is a large change because prior to this point, Python files in the project were not consistently formatted. This reformats them all with black defaults. Based on experience with prior projects, if you have a dev/long-term branch with Python patches, you can minimize merge conflicts prior to rebasing to include this commit by running `black` on your modified Python files, squashing, and then rebasing/merging.pull/3245/head
parent
5d4b803914
commit
6877302504
|
@ -30,6 +30,7 @@ if not TORCH_INCLUDE_DIR.is_dir():
|
||||||
TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
|
TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
|
||||||
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
|
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
|
|
||||||
def reindent(text, prefix=""):
|
def reindent(text, prefix=""):
|
||||||
return indent(dedent(text), prefix)
|
return indent(dedent(text), prefix)
|
||||||
|
|
||||||
|
@ -75,7 +76,11 @@ class GenMlirLazyIr(torchgen.dest.GenLazyIR):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only create this variable if it's used to avoid Wunused-variable
|
# Only create this variable if it's used to avoid Wunused-variable
|
||||||
operand_idx_counter = "size_t i = 0;" if "i++" in (emplace_arguments_str + emplace_kwarguments) else ""
|
operand_idx_counter = (
|
||||||
|
"size_t i = 0;"
|
||||||
|
if "i++" in (emplace_arguments_str + emplace_kwarguments)
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
return reindent(
|
return reindent(
|
||||||
f"""
|
f"""
|
||||||
|
@ -111,12 +116,16 @@ class GenTorchMlirLTC:
|
||||||
)
|
)
|
||||||
assert self.torch_ops_file.exists()
|
assert self.torch_ops_file.exists()
|
||||||
self.binary_dir = Path(binary_dir)
|
self.binary_dir = Path(binary_dir)
|
||||||
assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}"
|
assert (
|
||||||
|
self.binary_dir.is_dir()
|
||||||
|
), f"Binary directory not found: {self.binary_dir}"
|
||||||
self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml")
|
self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml")
|
||||||
self.backend_path = TORCH_MLIR_DIR.joinpath(
|
self.backend_path = TORCH_MLIR_DIR.joinpath(
|
||||||
"projects", "ltc", "csrc", "base_lazy_backend"
|
"projects", "ltc", "csrc", "base_lazy_backend"
|
||||||
)
|
)
|
||||||
assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}"
|
assert (
|
||||||
|
self.backend_path.is_dir()
|
||||||
|
), f"Backend path not found: {self.backend_path}"
|
||||||
self.generated_path = self.binary_dir.joinpath(
|
self.generated_path = self.binary_dir.joinpath(
|
||||||
"projects", "ltc", "csrc", "base_lazy_backend", "generated"
|
"projects", "ltc", "csrc", "base_lazy_backend", "generated"
|
||||||
)
|
)
|
||||||
|
@ -168,8 +177,9 @@ class GenTorchMlirLTC:
|
||||||
if ts_native_yaml_path.exists():
|
if ts_native_yaml_path.exists():
|
||||||
ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), yaml.CLoader)
|
ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), yaml.CLoader)
|
||||||
else:
|
else:
|
||||||
logging.warning(f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}")
|
logging.warning(
|
||||||
|
f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}"
|
||||||
|
)
|
||||||
|
|
||||||
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
|
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
|
||||||
self.native_functions = parsed_yaml.native_functions
|
self.native_functions = parsed_yaml.native_functions
|
||||||
|
|
|
@ -9,19 +9,20 @@ import requests
|
||||||
|
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('owner', type=str)
|
parser.add_argument("owner", type=str)
|
||||||
parser.add_argument('repo', type=str)
|
parser.add_argument("repo", type=str)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Get releases
|
# Get releases
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"https://api.github.com/repos/{args.owner}/{args.repo}/releases")
|
f"https://api.github.com/repos/{args.owner}/{args.repo}/releases"
|
||||||
|
)
|
||||||
body = json.loads(response.content)
|
body = json.loads(response.content)
|
||||||
|
|
||||||
# Parse releases
|
# Parse releases
|
||||||
releases = []
|
releases = []
|
||||||
for row in body:
|
for row in body:
|
||||||
for asset in row['assets']:
|
for asset in row["assets"]:
|
||||||
releases.append((asset["name"], asset["browser_download_url"]))
|
releases.append((asset["name"], asset["browser_download_url"]))
|
||||||
|
|
||||||
# Output HTML
|
# Output HTML
|
||||||
|
|
|
@ -25,10 +25,18 @@ from torch_mlir_e2e_test.configs import (
|
||||||
FxImporterTestConfig,
|
FxImporterTestConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
||||||
from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import LinalgOnTensorsOnnxBackend
|
RefBackendLinalgOnTensorsBackend,
|
||||||
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
)
|
||||||
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
|
from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import (
|
||||||
|
LinalgOnTensorsOnnxBackend,
|
||||||
|
)
|
||||||
|
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import (
|
||||||
|
LinalgOnTensorsTosaBackend,
|
||||||
|
)
|
||||||
|
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import (
|
||||||
|
LinalgOnTensorsStablehloBackend,
|
||||||
|
)
|
||||||
|
|
||||||
from .xfail_sets import (
|
from .xfail_sets import (
|
||||||
LINALG_XFAIL_SET,
|
LINALG_XFAIL_SET,
|
||||||
|
@ -51,13 +59,28 @@ from .xfail_sets import (
|
||||||
|
|
||||||
# Import tests to register them in the global registry.
|
# Import tests to register them in the global registry.
|
||||||
from torch_mlir_e2e_test.test_suite import register_all_tests
|
from torch_mlir_e2e_test.test_suite import register_all_tests
|
||||||
|
|
||||||
register_all_tests()
|
register_all_tests()
|
||||||
|
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core",
|
config_choices = [
|
||||||
"torchdynamo", "onnx", "fx_importer", "fx_importer_stablehlo"]
|
"native_torch",
|
||||||
|
"torchscript",
|
||||||
|
"linalg",
|
||||||
|
"stablehlo",
|
||||||
|
"make_fx_tosa",
|
||||||
|
"tosa",
|
||||||
|
"lazy_tensor_core",
|
||||||
|
"torchdynamo",
|
||||||
|
"onnx",
|
||||||
|
"fx_importer",
|
||||||
|
"fx_importer_stablehlo",
|
||||||
|
]
|
||||||
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
|
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
|
||||||
parser.add_argument("-c", "--config",
|
parser.add_argument(
|
||||||
|
"-c",
|
||||||
|
"--config",
|
||||||
choices=config_choices,
|
choices=config_choices,
|
||||||
default="linalg",
|
default="linalg",
|
||||||
help=f"""
|
help=f"""
|
||||||
|
@ -72,34 +95,52 @@ Meaning of options:
|
||||||
"onnx": export to the model via onnx and reimport using the torch-onnx-to-torch path.
|
"onnx": export to the model via onnx and reimport using the torch-onnx-to-torch path.
|
||||||
"fx_importer": run the model through the fx importer frontend and execute the graph using Linalg-on-Tensors.
|
"fx_importer": run the model through the fx importer frontend and execute the graph using Linalg-on-Tensors.
|
||||||
"fx_importer_stablehlo": run the model through the fx importer frontend and execute the graph using Stablehlo backend.
|
"fx_importer_stablehlo": run the model through the fx importer frontend and execute the graph using Stablehlo backend.
|
||||||
""")
|
""",
|
||||||
parser.add_argument("-f", "--filter", default=".*", help="""
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-f",
|
||||||
|
"--filter",
|
||||||
|
default=".*",
|
||||||
|
help="""
|
||||||
Regular expression specifying which tests to include in this run.
|
Regular expression specifying which tests to include in this run.
|
||||||
""")
|
""",
|
||||||
parser.add_argument("-v", "--verbose",
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-v",
|
||||||
|
"--verbose",
|
||||||
default=False,
|
default=False,
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="report test results with additional detail")
|
help="report test results with additional detail",
|
||||||
parser.add_argument("-s", "--sequential",
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-s",
|
||||||
|
"--sequential",
|
||||||
default=False,
|
default=False,
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="""Run tests sequentially rather than in parallel.
|
help="""Run tests sequentially rather than in parallel.
|
||||||
This can be useful for debugging, since it runs the tests in the same process,
|
This can be useful for debugging, since it runs the tests in the same process,
|
||||||
which make it easier to attach a debugger or get a stack trace.""")
|
which make it easier to attach a debugger or get a stack trace.""",
|
||||||
parser.add_argument("--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed",
|
)
|
||||||
metavar="TEST", type=str, nargs="+",
|
parser.add_argument(
|
||||||
help="A set of tests to not attempt to run, since they crash and cannot be XFAILed.")
|
"--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed",
|
||||||
parser.add_argument("--ignore_failures",
|
metavar="TEST",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="A set of tests to not attempt to run, since they crash and cannot be XFAILed.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ignore_failures",
|
||||||
default=False,
|
default=False,
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="return exit code 0 even if the test fails to unblock pipeline")
|
help="return exit code 0 even if the test fails to unblock pipeline",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = _get_argparse().parse_args()
|
args = _get_argparse().parse_args()
|
||||||
|
|
||||||
all_test_unique_names = set(
|
all_test_unique_names = set(test.unique_name for test in GLOBAL_TEST_REGISTRY)
|
||||||
test.unique_name for test in GLOBAL_TEST_REGISTRY)
|
|
||||||
|
|
||||||
# Find the selected config.
|
# Find the selected config.
|
||||||
if args.config == "linalg":
|
if args.config == "linalg":
|
||||||
|
@ -147,23 +188,26 @@ def main():
|
||||||
xfail_set = ONNX_XFAIL_SET
|
xfail_set = ONNX_XFAIL_SET
|
||||||
crashing_set = ONNX_CRASHING_SET
|
crashing_set = ONNX_CRASHING_SET
|
||||||
|
|
||||||
do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []).union(crashing_set)
|
do_not_attempt = set(
|
||||||
available_tests = [test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt]
|
args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []
|
||||||
|
).union(crashing_set)
|
||||||
|
available_tests = [
|
||||||
|
test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt
|
||||||
|
]
|
||||||
if args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed is not None:
|
if args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed is not None:
|
||||||
for arg in args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed:
|
for arg in args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed:
|
||||||
if arg not in all_test_unique_names:
|
if arg not in all_test_unique_names:
|
||||||
print(f"ERROR: --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument '{arg}' is not a valid test name")
|
print(
|
||||||
|
f"ERROR: --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument '{arg}' is not a valid test name"
|
||||||
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Find the selected tests, and emit a diagnostic if none are found.
|
# Find the selected tests, and emit a diagnostic if none are found.
|
||||||
tests = [
|
tests = [
|
||||||
test for test in available_tests
|
test for test in available_tests if re.match(args.filter, test.unique_name)
|
||||||
if re.match(args.filter, test.unique_name)
|
|
||||||
]
|
]
|
||||||
if len(tests) == 0:
|
if len(tests) == 0:
|
||||||
print(
|
print(f"ERROR: the provided filter {args.filter!r} does not match any tests")
|
||||||
f"ERROR: the provided filter {args.filter!r} does not match any tests"
|
|
||||||
)
|
|
||||||
print("The available tests are:")
|
print("The available tests are:")
|
||||||
for test in available_tests:
|
for test in available_tests:
|
||||||
print(test.unique_name)
|
print(test.unique_name)
|
||||||
|
@ -175,18 +219,25 @@ def main():
|
||||||
# Report the test results.
|
# Report the test results.
|
||||||
failed = report_results(results, xfail_set, args.verbose, args.config)
|
failed = report_results(results, xfail_set, args.verbose, args.config)
|
||||||
if args.config == "torchdynamo":
|
if args.config == "torchdynamo":
|
||||||
print("\033[91mWarning: the TorchScript based dynamo support is deprecated. "
|
print(
|
||||||
"The config for torchdynamo is planned to be removed in the future.\033[0m")
|
"\033[91mWarning: the TorchScript based dynamo support is deprecated. "
|
||||||
|
"The config for torchdynamo is planned to be removed in the future.\033[0m"
|
||||||
|
)
|
||||||
if args.ignore_failures:
|
if args.ignore_failures:
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
sys.exit(1 if failed else 0)
|
sys.exit(1 if failed else 0)
|
||||||
|
|
||||||
|
|
||||||
def _suppress_warnings():
|
def _suppress_warnings():
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
# Ignore warning due to Python bug:
|
# Ignore warning due to Python bug:
|
||||||
# https://stackoverflow.com/questions/4964101/pep-3118-warning-when-using-ctypes-array-as-numpy-array
|
# https://stackoverflow.com/questions/4964101/pep-3118-warning-when-using-ctypes-array-as-numpy-array
|
||||||
warnings.filterwarnings("ignore",
|
warnings.filterwarnings(
|
||||||
message="A builtin ctypes object gave a PEP3118 format string that does not match its itemsize")
|
"ignore",
|
||||||
|
message="A builtin ctypes object gave a PEP3118 format string that does not match its itemsize",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
_suppress_warnings()
|
_suppress_warnings()
|
||||||
|
|
|
@ -31,21 +31,17 @@ LINALG_CRASHING_SET = {
|
||||||
|
|
||||||
TORCHDYNAMO_XFAIL_SET = {
|
TORCHDYNAMO_XFAIL_SET = {
|
||||||
#### General TorchDynamo/PyTorch errors
|
#### General TorchDynamo/PyTorch errors
|
||||||
|
|
||||||
# torch._dynamo.exc.Unsupported: Tensor.item
|
# torch._dynamo.exc.Unsupported: Tensor.item
|
||||||
"CumsumModule_basic",
|
"CumsumModule_basic",
|
||||||
|
|
||||||
# TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
|
# TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
|
||||||
# RuntimeError: Failed running call_function aten.convolution_backward(...
|
# RuntimeError: Failed running call_function aten.convolution_backward(...
|
||||||
# https://github.com/pytorch/pytorch/issues/89629
|
# https://github.com/pytorch/pytorch/issues/89629
|
||||||
"ConvolutionBackwardModule2DPadded_basic",
|
"ConvolutionBackwardModule2DPadded_basic",
|
||||||
"ConvolutionBackwardModule2D_basic",
|
"ConvolutionBackwardModule2D_basic",
|
||||||
|
|
||||||
# Size result mismatch (exposed by downstream canonicalizer
|
# Size result mismatch (exposed by downstream canonicalizer
|
||||||
# on incompatabile casts).
|
# on incompatabile casts).
|
||||||
# https://github.com/pytorch/pytorch/issues/119407
|
# https://github.com/pytorch/pytorch/issues/119407
|
||||||
"ConvolutionBackwardModule2DStrided_basic",
|
"ConvolutionBackwardModule2DStrided_basic",
|
||||||
|
|
||||||
# RuntimeError: Index tensor must have the same number of dimensions as self tensor
|
# RuntimeError: Index tensor must have the same number of dimensions as self tensor
|
||||||
# RuntimeError: Failed running call_function aten.nll_loss_backward(...
|
# RuntimeError: Failed running call_function aten.nll_loss_backward(...
|
||||||
# https://github.com/pytorch/pytorch/issues/89630
|
# https://github.com/pytorch/pytorch/issues/89630
|
||||||
|
@ -59,196 +55,159 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
# RuntimeError: Failed running call_function aten.uniform(...
|
# RuntimeError: Failed running call_function aten.uniform(...
|
||||||
# https://github.com/pytorch/torchdynamo/issues/1954
|
# https://github.com/pytorch/torchdynamo/issues/1954
|
||||||
"UniformNoCorrelationModule_basic",
|
"UniformNoCorrelationModule_basic",
|
||||||
|
|
||||||
#### Torch-MLIR internal compiler errors
|
#### Torch-MLIR internal compiler errors
|
||||||
|
|
||||||
# These are probably due to slightly different ops being recorded by
|
# These are probably due to slightly different ops being recorded by
|
||||||
# torchdynamo vs. torchscript.
|
# torchdynamo vs. torchscript.
|
||||||
|
|
||||||
# No upstream decompositions.
|
# No upstream decompositions.
|
||||||
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
|
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
|
||||||
# See also: https://github.com/pytorch/torchdynamo/issues/327
|
# See also: https://github.com/pytorch/torchdynamo/issues/327
|
||||||
"AtenEmbeddingBagSumExample_basic",
|
"AtenEmbeddingBagSumExample_basic",
|
||||||
|
|
||||||
# error: unsupported by backend contract: tensor with unknown rank
|
# error: unsupported by backend contract: tensor with unknown rank
|
||||||
# note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
|
# note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
|
||||||
"ElementwisePreluModule_basic",
|
"ElementwisePreluModule_basic",
|
||||||
# error: torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised: AssertionError: Unregistered operation: torch.aten._prelu_kernel
|
# error: torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised: AssertionError: Unregistered operation: torch.aten._prelu_kernel
|
||||||
"ElementwisePreluStaticModule_basic",
|
"ElementwisePreluStaticModule_basic",
|
||||||
|
# ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777)
|
||||||
#ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777)
|
|
||||||
"UpSampleNearest2dDynamicFactor_basic",
|
"UpSampleNearest2dDynamicFactor_basic",
|
||||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||||
"ReduceMinAlongDimUnsignedInt_basic",
|
"ReduceMinAlongDimUnsignedInt_basic",
|
||||||
#ERROR: value (-56) is not equal to golden value (200)
|
# ERROR: value (-56) is not equal to golden value (200)
|
||||||
"AtenIntTensorByteDtypeModule_basic",
|
"AtenIntTensorByteDtypeModule_basic",
|
||||||
# ERROR: assert isinstance(e, FakeTensor)
|
# ERROR: assert isinstance(e, FakeTensor)
|
||||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||||
# ERROR: assert isinstance(e, FakeTensor)
|
# ERROR: assert isinstance(e, FakeTensor)
|
||||||
"RsubInt0d_NumToTensor_Module_basic",
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
|
|
||||||
# ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::squeeze.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa.
|
# ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::squeeze.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa.
|
||||||
"PrimsSqueezeModule_basic",
|
"PrimsSqueezeModule_basic",
|
||||||
"PrimsSqueezeEmptyDimensionsModule_basic",
|
"PrimsSqueezeEmptyDimensionsModule_basic",
|
||||||
"SplitDimStaticModule_basic",
|
"SplitDimStaticModule_basic",
|
||||||
"SplitDimDynamicModule_basic",
|
"SplitDimDynamicModule_basic",
|
||||||
|
|
||||||
# ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::view_of.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa.
|
# ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::view_of.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa.
|
||||||
"PrimsViewOfModule_basic",
|
"PrimsViewOfModule_basic",
|
||||||
"PrimsViewOfZeroRankModule_basic",
|
"PrimsViewOfZeroRankModule_basic",
|
||||||
|
|
||||||
# See https://github.com/llvm/torch-mlir/pull/2040 and corresponding upstream issue
|
# See https://github.com/llvm/torch-mlir/pull/2040 and corresponding upstream issue
|
||||||
# https://github.com/pytorch/pytorch/issues/99752.
|
# https://github.com/pytorch/pytorch/issues/99752.
|
||||||
# torch._dynamo.exc.Unsupported: call_function BuiltinVariable(bool) [TensorVariable()] {}
|
# torch._dynamo.exc.Unsupported: call_function BuiltinVariable(bool) [TensorVariable()] {}
|
||||||
'TensorToBoolZeroRank_basic',
|
"TensorToBoolZeroRank_basic",
|
||||||
'TensorToBool_basic',
|
"TensorToBool_basic",
|
||||||
|
|
||||||
# START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
# START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
||||||
'AtenSubFloatModule_basic',
|
"AtenSubFloatModule_basic",
|
||||||
'AtenMulFloatModule_basic',
|
"AtenMulFloatModule_basic",
|
||||||
'BoolFloatFalseModule_basic',
|
"BoolFloatFalseModule_basic",
|
||||||
'BoolFloatTrueModule_basic',
|
"BoolFloatTrueModule_basic",
|
||||||
'CeilFloatModule_basic',
|
"CeilFloatModule_basic",
|
||||||
'DivFloatModule_basic',
|
"DivFloatModule_basic",
|
||||||
'GeFloatIntModule_basic',
|
"GeFloatIntModule_basic",
|
||||||
'GeFloatModule_basic',
|
"GeFloatModule_basic",
|
||||||
'GtFloatIntModule_basic',
|
"GtFloatIntModule_basic",
|
||||||
'NeFloatIntModule_basic',
|
"NeFloatIntModule_basic",
|
||||||
'SubFloatModule_basic',
|
"SubFloatModule_basic",
|
||||||
'MulFloatModule_basic',
|
"MulFloatModule_basic",
|
||||||
'TensorToFloatZeroRank_basic',
|
"TensorToFloatZeroRank_basic",
|
||||||
'TensorToFloat_basic',
|
"TensorToFloat_basic",
|
||||||
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
||||||
|
|
||||||
|
|
||||||
# START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
# START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
||||||
'AddIntModule_basic',
|
"AddIntModule_basic",
|
||||||
'AtenIntTensorCharDtypeModule_basic',
|
"AtenIntTensorCharDtypeModule_basic",
|
||||||
'BoolIntFalseModule_basic',
|
"BoolIntFalseModule_basic",
|
||||||
'BoolIntTrueModule_basic',
|
"BoolIntTrueModule_basic",
|
||||||
'DivIntModule_basic',
|
"DivIntModule_basic",
|
||||||
'EqIntModule_basic',
|
"EqIntModule_basic",
|
||||||
'GeIntModule_basic',
|
"GeIntModule_basic",
|
||||||
'GtIntModule_basic',
|
"GtIntModule_basic",
|
||||||
'MulIntModule_basic',
|
"MulIntModule_basic",
|
||||||
'NeIntModule_basic',
|
"NeIntModule_basic",
|
||||||
'SqrtIntModule_basic',
|
"SqrtIntModule_basic",
|
||||||
'SubIntModule_basic',
|
"SubIntModule_basic",
|
||||||
'TensorToIntZeroRank_basic',
|
"TensorToIntZeroRank_basic",
|
||||||
'TensorToInt_basic',
|
"TensorToInt_basic",
|
||||||
'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic',
|
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
'ViewCollapseDynamicWithAtenSizeIntModule_basic',
|
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: Tensor.item
|
# ERROR: torch._dynamo.exc.Unsupported: Tensor.item
|
||||||
'AtenItemIntOpModule_basic',
|
"AtenItemIntOpModule_basic",
|
||||||
'AtenItemFpOpModule_basic',
|
"AtenItemFpOpModule_basic",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)}
|
# ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)}
|
||||||
'SortIntListReverse_basic',
|
"SortIntListReverse_basic",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {}
|
# ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {}
|
||||||
'SortIntList_basic',
|
"SortIntList_basic",
|
||||||
|
|
||||||
# START tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
|
# START tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
|
||||||
'AtenFloatScalarModule_basic',
|
"AtenFloatScalarModule_basic",
|
||||||
'AtenIntBoolOpModule_basic',
|
"AtenIntBoolOpModule_basic",
|
||||||
'QuantizedMLP_basic',
|
"QuantizedMLP_basic",
|
||||||
'QuantizedSingleLayer_basic',
|
"QuantizedSingleLayer_basic",
|
||||||
'QuantizedBatchedInputSingleLayer_basic',
|
"QuantizedBatchedInputSingleLayer_basic",
|
||||||
'QuantizedNoLayer_basic',
|
"QuantizedNoLayer_basic",
|
||||||
'ScalarImplicitFloatModule_basic',
|
"ScalarImplicitFloatModule_basic",
|
||||||
'ScalarImplicitIntModule_basic',
|
"ScalarImplicitIntModule_basic",
|
||||||
# END tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
|
# END tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
|
||||||
|
|
||||||
# START tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
|
# START tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
|
||||||
'BincountMinlengthModule_basic',
|
"BincountMinlengthModule_basic",
|
||||||
'BincountModule_basic',
|
"BincountModule_basic",
|
||||||
'BincountStaticSizeModule_basic',
|
"BincountStaticSizeModule_basic",
|
||||||
# END tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
|
# END tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool
|
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool
|
||||||
'BoolFloatConstantModule_basic',
|
"BoolFloatConstantModule_basic",
|
||||||
'BoolIntConstantModule_basic',
|
"BoolIntConstantModule_basic",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
|
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
|
||||||
"ViewSizeFromOtherTensor_basic",
|
"ViewSizeFromOtherTensor_basic",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__
|
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__
|
||||||
'ContainsIntList_False',
|
"ContainsIntList_False",
|
||||||
'ContainsIntList_True',
|
"ContainsIntList_True",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all
|
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all
|
||||||
'AllBoolFalseModule_basic',
|
"AllBoolFalseModule_basic",
|
||||||
'AllBoolTrueModule_basic',
|
"AllBoolTrueModule_basic",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any
|
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any
|
||||||
'AnyBoolFalseModule_basic',
|
"AnyBoolFalseModule_basic",
|
||||||
'AnyBoolTrueModule_basic',
|
"AnyBoolTrueModule_basic",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt
|
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt
|
||||||
'SqrtIntConstantModule_basic',
|
"SqrtIntConstantModule_basic",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
|
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
|
||||||
'BroadcastDynamicDimModule_basic',
|
"BroadcastDynamicDimModule_basic",
|
||||||
|
|
||||||
# START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
|
# START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
|
||||||
'AtenIntBoolOpConstFalseModule_basic',
|
"AtenIntBoolOpConstFalseModule_basic",
|
||||||
'AtenIntBoolOpConstTrueModule_basic',
|
"AtenIntBoolOpConstTrueModule_basic",
|
||||||
'IntFloatModule_basic',
|
"IntFloatModule_basic",
|
||||||
'PowIntFloatModule_basic',
|
"PowIntFloatModule_basic",
|
||||||
# END tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
|
# END tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len
|
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len
|
||||||
'LenStrModule_basic',
|
"LenStrModule_basic",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel
|
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel
|
||||||
'NumelModule_basic',
|
"NumelModule_basic",
|
||||||
'NumelZeroRankModule_basic',
|
"NumelZeroRankModule_basic",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max
|
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max
|
||||||
'PrimMaxIntModule_basic',
|
"PrimMaxIntModule_basic",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min
|
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min
|
||||||
'PrimMinIntModule_basic',
|
"PrimMinIntModule_basic",
|
||||||
'PrimMinIntDynamicModule_basic',
|
"PrimMinIntDynamicModule_basic",
|
||||||
|
|
||||||
# START tests failing due to: empty graph in dynamo
|
# START tests failing due to: empty graph in dynamo
|
||||||
'IsFloatingPointFloat_True',
|
"IsFloatingPointFloat_True",
|
||||||
'IsFloatingPointInt_False',
|
"IsFloatingPointInt_False",
|
||||||
'TorchPrimLoopForLikeModule_basic',
|
"TorchPrimLoopForLikeModule_basic",
|
||||||
'TorchPrimLoopWhileLikeModule_basic',
|
"TorchPrimLoopWhileLikeModule_basic",
|
||||||
"ScalarConstantTupleModule_basic",
|
"ScalarConstantTupleModule_basic",
|
||||||
# END tests failing due to: empty graph in dynamo
|
# END tests failing due to: empty graph in dynamo
|
||||||
|
|
||||||
# ERROR due to: backend never runs because of empty frame
|
# ERROR due to: backend never runs because of empty frame
|
||||||
'ConstantBoolParameterModule_basic',
|
"ConstantBoolParameterModule_basic",
|
||||||
|
|
||||||
# START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
# START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||||
"UpSampleNearest2dDynamicSize_basic",
|
"UpSampleNearest2dDynamicSize_basic",
|
||||||
"UpSampleNearest2dStaticFactor_basic",
|
"UpSampleNearest2dStaticFactor_basic",
|
||||||
"UpSampleNearest2dStaticSize_basic",
|
"UpSampleNearest2dStaticSize_basic",
|
||||||
"UpSampleNearest2d_basic",
|
"UpSampleNearest2d_basic",
|
||||||
# END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
# END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||||
|
|
||||||
# START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
# START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||||
"ElementwiseAddScalarFloatModule_basic",
|
"ElementwiseAddScalarFloatModule_basic",
|
||||||
# END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
# END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||||
|
|
||||||
# ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
|
# ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
|
||||||
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
||||||
"HBC_basic",
|
"HBC_basic",
|
||||||
|
|
||||||
# ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
# ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||||
"ElementwiseDivScalarModule_basic",
|
"ElementwiseDivScalarModule_basic",
|
||||||
|
|
||||||
# ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
|
# ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
|
||||||
"ElementwiseAtenDivIntScalarModule_basic",
|
"ElementwiseAtenDivIntScalarModule_basic",
|
||||||
|
|
||||||
# ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
# ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
|
||||||
"ElementwiseSubScalarFloatModule_basic",
|
"ElementwiseSubScalarFloatModule_basic",
|
||||||
"ElementwiseSubScalarIntModule_basic",
|
"ElementwiseSubScalarIntModule_basic",
|
||||||
|
|
||||||
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
|
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
|
||||||
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||||
"ElementwiseAtenFloorDivideScalarModule_basic",
|
"ElementwiseAtenFloorDivideScalarModule_basic",
|
||||||
|
@ -258,57 +217,43 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||||
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
||||||
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
||||||
|
|
||||||
# ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
|
# ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
|
||||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||||
"AdaptiveAvgPool2dDynamic_basic",
|
"AdaptiveAvgPool2dDynamic_basic",
|
||||||
"AdaptiveAvgPool2dDynamicNoBatch_basic",
|
"AdaptiveAvgPool2dDynamicNoBatch_basic",
|
||||||
|
|
||||||
# ERROR: Exception: Unsupported op: get_attr
|
# ERROR: Exception: Unsupported op: get_attr
|
||||||
"NumToTensorFloatModule_basic",
|
"NumToTensorFloatModule_basic",
|
||||||
"NumToTensorIntModule_basic",
|
"NumToTensorIntModule_basic",
|
||||||
"TensorFloatModule_basic",
|
"TensorFloatModule_basic",
|
||||||
"TensorIntModule_basic",
|
"TensorIntModule_basic",
|
||||||
|
|
||||||
# START tests failing due to: complex floating point ops
|
# START tests failing due to: complex floating point ops
|
||||||
# END tests failing due to: complex floating point ops
|
# END tests failing due to: complex floating point ops
|
||||||
|
|
||||||
# ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int
|
# ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int
|
||||||
"UnbindIntListUnpack_Module_basic",
|
"UnbindIntListUnpack_Module_basic",
|
||||||
"UnbindIntGetItem_Module_basic",
|
"UnbindIntGetItem_Module_basic",
|
||||||
|
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
||||||
"ScatterValueFloatModule_basic",
|
"ScatterValueFloatModule_basic",
|
||||||
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
||||||
"ScatterValueIntModule_basic",
|
"ScatterValueIntModule_basic",
|
||||||
|
|
||||||
# AssertionError: Unregistered operation: torch.aten._unsafe_index_put
|
# AssertionError: Unregistered operation: torch.aten._unsafe_index_put
|
||||||
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||||
|
|
||||||
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
|
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
|
||||||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||||
|
|
||||||
# AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
|
# AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
|
||||||
"ScaledDotProductAttentionDifferentModule_basic",
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
|
|
||||||
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
|
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
|
||||||
"AtenEmbeddingBagStaticModule_basic",
|
"AtenEmbeddingBagStaticModule_basic",
|
||||||
|
|
||||||
# Lowering not present for this case
|
# Lowering not present for this case
|
||||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||||
|
|
||||||
# torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method add of type object at 0x7f4f8b05a720>(*(FakeTensor(..., size=(3, 4), dtype=torch.int8), 3, 2), **{}): Tensor with dtype torch.int64 is not the expected dtype of torch.int8!
|
# torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method add of type object at 0x7f4f8b05a720>(*(FakeTensor(..., size=(3, 4), dtype=torch.int8), 3, 2), **{}): Tensor with dtype torch.int64 is not the expected dtype of torch.int8!
|
||||||
"ElementwiseAddScalarInt8Module_basic",
|
"ElementwiseAddScalarInt8Module_basic",
|
||||||
|
|
||||||
# ERROR: dtype (torch.int64) is not equal to golden dtype (torch.float32)
|
# ERROR: dtype (torch.int64) is not equal to golden dtype (torch.float32)
|
||||||
"ThresholdBackward2dMixedModule_basic",
|
"ThresholdBackward2dMixedModule_basic",
|
||||||
|
|
||||||
# ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
|
# ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
|
||||||
"ArangeStartOutViewModule_basic",
|
"ArangeStartOutViewModule_basic",
|
||||||
|
|
||||||
# Dynamo does not support tracing quantized tensors
|
# Dynamo does not support tracing quantized tensors
|
||||||
"ElementwiseDequantizePerChannelModule_basic",
|
"ElementwiseDequantizePerChannelModule_basic",
|
||||||
"ElementwiseDequantizePerTensorModule_basic",
|
"ElementwiseDequantizePerTensorModule_basic",
|
||||||
|
@ -327,13 +272,10 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
"QuantizedReluInt8_basic",
|
"QuantizedReluInt8_basic",
|
||||||
"QuantizedReluUint8_basic",
|
"QuantizedReluUint8_basic",
|
||||||
"Conv2dQInt8Module_basic",
|
"Conv2dQInt8Module_basic",
|
||||||
|
|
||||||
# Dynamo not supporting conv_tbc
|
# Dynamo not supporting conv_tbc
|
||||||
"ConvTbcModule_basic",
|
"ConvTbcModule_basic",
|
||||||
|
|
||||||
"FloatImplicitModule_basic",
|
"FloatImplicitModule_basic",
|
||||||
"IntImplicitModule_basic",
|
"IntImplicitModule_basic",
|
||||||
|
|
||||||
# Others
|
# Others
|
||||||
"ExponentialModule_basic",
|
"ExponentialModule_basic",
|
||||||
"GridSamplerBasic1_basic",
|
"GridSamplerBasic1_basic",
|
||||||
|
@ -383,142 +325,141 @@ TORCHDYNAMO_CRASHING_SET = {
|
||||||
"MaxPool3dModule_basic",
|
"MaxPool3dModule_basic",
|
||||||
"MaxPool3dStaticCeilModeTrueModule_basic",
|
"MaxPool3dStaticCeilModeTrueModule_basic",
|
||||||
"MaxPool3dStaticModule_basic",
|
"MaxPool3dStaticModule_basic",
|
||||||
|
|
||||||
# Looks like incorrect fx graph conversion
|
# Looks like incorrect fx graph conversion
|
||||||
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_XFAIL_SET = {
|
FX_IMPORTER_XFAIL_SET = {
|
||||||
'AllBoolFalseModule_basic',
|
"AllBoolFalseModule_basic",
|
||||||
'AllBoolTrueModule_basic',
|
"AllBoolTrueModule_basic",
|
||||||
'AnyBoolFalseModule_basic',
|
"AnyBoolFalseModule_basic",
|
||||||
'AnyBoolTrueModule_basic',
|
"AnyBoolTrueModule_basic",
|
||||||
'ArangeStartOutViewModule_basic',
|
"ArangeStartOutViewModule_basic",
|
||||||
'AtenEmbeddingBagStaticModule_basic',
|
"AtenEmbeddingBagStaticModule_basic",
|
||||||
'AtenEmbeddingBagSumExample_basic',
|
"AtenEmbeddingBagSumExample_basic",
|
||||||
'AtenFloatScalarModule_basic',
|
"AtenFloatScalarModule_basic",
|
||||||
'AtenIntBoolOpConstFalseModule_basic',
|
"AtenIntBoolOpConstFalseModule_basic",
|
||||||
'AtenIntBoolOpConstTrueModule_basic',
|
"AtenIntBoolOpConstTrueModule_basic",
|
||||||
'AtenIntBoolOpModule_basic',
|
"AtenIntBoolOpModule_basic",
|
||||||
'AtenItemFpOpModule_basic',
|
"AtenItemFpOpModule_basic",
|
||||||
'AtenMatmulQMixedSigni8Transpose_basic',
|
"AtenMatmulQMixedSigni8Transpose_basic",
|
||||||
'AtenMatmulQMixedSigni8_basic',
|
"AtenMatmulQMixedSigni8_basic",
|
||||||
'AtenMatmulQint8MV_basic',
|
"AtenMatmulQint8MV_basic",
|
||||||
'AtenMatmulQint8_basic',
|
"AtenMatmulQint8_basic",
|
||||||
'AtenMatmulQint8VM_basic',
|
"AtenMatmulQint8VM_basic",
|
||||||
'AtenMatmulQint8VV_basic',
|
"AtenMatmulQint8VV_basic",
|
||||||
'AtenMmQMixedSigni8_basic',
|
"AtenMmQMixedSigni8_basic",
|
||||||
'AtenMmQint8_basic',
|
"AtenMmQint8_basic",
|
||||||
'AtenMmQuint8_basic',
|
"AtenMmQuint8_basic",
|
||||||
"QuantizedReluInt32_basic",
|
"QuantizedReluInt32_basic",
|
||||||
"QuantizedReluInt8_basic",
|
"QuantizedReluInt8_basic",
|
||||||
"QuantizedReluUint8_basic",
|
"QuantizedReluUint8_basic",
|
||||||
'AtenSubFloatModule_basic',
|
"AtenSubFloatModule_basic",
|
||||||
'BincountMinlengthModule_basic',
|
"BincountMinlengthModule_basic",
|
||||||
'BincountModule_basic',
|
"BincountModule_basic",
|
||||||
'BincountStaticSizeModule_basic',
|
"BincountStaticSizeModule_basic",
|
||||||
'BoolFloatConstantModule_basic',
|
"BoolFloatConstantModule_basic",
|
||||||
'BoolFloatFalseModule_basic',
|
"BoolFloatFalseModule_basic",
|
||||||
'BoolFloatTrueModule_basic',
|
"BoolFloatTrueModule_basic",
|
||||||
'BoolIntConstantModule_basic',
|
"BoolIntConstantModule_basic",
|
||||||
'BoolIntFalseModule_basic',
|
"BoolIntFalseModule_basic",
|
||||||
'BoolIntTrueModule_basic',
|
"BoolIntTrueModule_basic",
|
||||||
'BroadcastDynamicDimModule_basic',
|
"BroadcastDynamicDimModule_basic",
|
||||||
'CeilFloatModule_basic',
|
"CeilFloatModule_basic",
|
||||||
'ConstantBoolParameterModule_basic',
|
"ConstantBoolParameterModule_basic",
|
||||||
'ContainsIntList_False',
|
"ContainsIntList_False",
|
||||||
'ContainsIntList_True',
|
"ContainsIntList_True",
|
||||||
'Conv2dQInt8Module_basic',
|
"Conv2dQInt8Module_basic",
|
||||||
'Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier',
|
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||||
'ConvTbcModule_basic',
|
"ConvTbcModule_basic",
|
||||||
'ConvolutionBackwardModule2DPadded_basic',
|
"ConvolutionBackwardModule2DPadded_basic",
|
||||||
'ConvolutionBackwardModule2DStrided_basic',
|
"ConvolutionBackwardModule2DStrided_basic",
|
||||||
'ConvolutionBackwardModule2D_basic',
|
"ConvolutionBackwardModule2D_basic",
|
||||||
'CumsumModule_basic',
|
"CumsumModule_basic",
|
||||||
'DivFloatModule_basic',
|
"DivFloatModule_basic",
|
||||||
'DivIntModule_basic',
|
"DivIntModule_basic",
|
||||||
'ElementwiseAddScalar_NumToTensorFloat_Module_basic',
|
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||||
'ElementwiseDequantizePerChannelModule_basic',
|
"ElementwiseDequantizePerChannelModule_basic",
|
||||||
'ElementwiseDequantizePerTensorModule_basic',
|
"ElementwiseDequantizePerTensorModule_basic",
|
||||||
'ElementwiseQuantizePerTensorModule_basic',
|
"ElementwiseQuantizePerTensorModule_basic",
|
||||||
'ElementwiseQuantizePerTensorUIntModule_basic',
|
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||||
'ElementwiseToDtypeI64ToUI8Module_basic',
|
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||||
'EqIntModule_basic',
|
"EqIntModule_basic",
|
||||||
'FakeQuantizePerTensorAffineDynamicShapeModule_basic',
|
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||||
'FakeQuantizePerTensorAffineModule_basic',
|
"FakeQuantizePerTensorAffineModule_basic",
|
||||||
'FakeQuantizePerTensorAffineRoundToEvenModule_basic',
|
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||||
'FloatImplicitModule_basic',
|
"FloatImplicitModule_basic",
|
||||||
'GeFloatIntModule_basic',
|
"GeFloatIntModule_basic",
|
||||||
'GeFloatModule_basic',
|
"GeFloatModule_basic",
|
||||||
'GeIntModule_basic',
|
"GeIntModule_basic",
|
||||||
'GtFloatIntModule_basic',
|
"GtFloatIntModule_basic",
|
||||||
'GtIntModule_basic',
|
"GtIntModule_basic",
|
||||||
'IntFloatModule_basic',
|
"IntFloatModule_basic",
|
||||||
'IntImplicitModule_basic',
|
"IntImplicitModule_basic",
|
||||||
'IsFloatingPointFloat_True',
|
"IsFloatingPointFloat_True",
|
||||||
'IsFloatingPointInt_False',
|
"IsFloatingPointInt_False",
|
||||||
'LenStrModule_basic',
|
"LenStrModule_basic",
|
||||||
'MaxPool3dCeilModeTrueModule_basic',
|
"MaxPool3dCeilModeTrueModule_basic",
|
||||||
'MaxPool3dEmptyStrideStaticModule_basic',
|
"MaxPool3dEmptyStrideStaticModule_basic",
|
||||||
'MaxPool3dLargeDatadModule_basic',
|
"MaxPool3dLargeDatadModule_basic",
|
||||||
'MaxPool3dModuleRandomSimple_basic',
|
"MaxPool3dModuleRandomSimple_basic",
|
||||||
'MaxPool3dModule_basic',
|
"MaxPool3dModule_basic",
|
||||||
'MaxPool3dStaticCeilModeTrueModule_basic',
|
"MaxPool3dStaticCeilModeTrueModule_basic",
|
||||||
'MaxPool3dStaticModule_basic',
|
"MaxPool3dStaticModule_basic",
|
||||||
'MulFloatModule_basic',
|
"MulFloatModule_basic",
|
||||||
'NativeGroupNormBackwardModule_basic',
|
"NativeGroupNormBackwardModule_basic",
|
||||||
'NeFloatIntModule_basic',
|
"NeFloatIntModule_basic",
|
||||||
'NeIntModule_basic',
|
"NeIntModule_basic",
|
||||||
'NllLossModuleBackward1DMeanWeight_basic',
|
"NllLossModuleBackward1DMeanWeight_basic",
|
||||||
'NllLossModuleBackward1DMean_basic',
|
"NllLossModuleBackward1DMean_basic",
|
||||||
'NllLossModuleBackward1DSumWeight_basic',
|
"NllLossModuleBackward1DSumWeight_basic",
|
||||||
'NllLossModuleBackward1DSum_basic',
|
"NllLossModuleBackward1DSum_basic",
|
||||||
'NllLossModuleBackward1DWeight_basic',
|
"NllLossModuleBackward1DWeight_basic",
|
||||||
'NllLossModuleBackward1D_basic',
|
"NllLossModuleBackward1D_basic",
|
||||||
'NumToTensorFloatModule_basic',
|
"NumToTensorFloatModule_basic",
|
||||||
'NumToTensorIntModule_basic',
|
"NumToTensorIntModule_basic",
|
||||||
'NumelModule_basic',
|
"NumelModule_basic",
|
||||||
'NumelZeroRankModule_basic',
|
"NumelZeroRankModule_basic",
|
||||||
'PowIntFloatModule_basic',
|
"PowIntFloatModule_basic",
|
||||||
'PrimMaxIntModule_basic',
|
"PrimMaxIntModule_basic",
|
||||||
'PrimMinIntDynamicModule_basic',
|
"PrimMinIntDynamicModule_basic",
|
||||||
'PrimMinIntModule_basic',
|
"PrimMinIntModule_basic",
|
||||||
'PrimsSqueezeEmptyDimensionsModule_basic',
|
"PrimsSqueezeEmptyDimensionsModule_basic",
|
||||||
'PrimsSqueezeModule_basic',
|
"PrimsSqueezeModule_basic",
|
||||||
'PrimsViewOfModule_basic',
|
"PrimsViewOfModule_basic",
|
||||||
'PrimsViewOfZeroRankModule_basic',
|
"PrimsViewOfZeroRankModule_basic",
|
||||||
'QuantizedBatchedInputSingleLayer_basic',
|
"QuantizedBatchedInputSingleLayer_basic",
|
||||||
'QuantizedMLP_basic',
|
"QuantizedMLP_basic",
|
||||||
'QuantizedNoLayer_basic',
|
"QuantizedNoLayer_basic",
|
||||||
'QuantizedSingleLayer_basic',
|
"QuantizedSingleLayer_basic",
|
||||||
'ReduceMaxAlongDimUnsignedInt_basic',
|
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||||
'ReduceMinAlongDimUnsignedInt_basic',
|
"ReduceMinAlongDimUnsignedInt_basic",
|
||||||
'RsubInt0d_NumToTensor_Module_basic',
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
'ScalarConstantTupleModule_basic',
|
"ScalarConstantTupleModule_basic",
|
||||||
'ScalarImplicitFloatModule_basic',
|
"ScalarImplicitFloatModule_basic",
|
||||||
'SortIntListReverse_basic',
|
"SortIntListReverse_basic",
|
||||||
'SortIntList_basic',
|
"SortIntList_basic",
|
||||||
'SplitDimDynamicModule_basic',
|
"SplitDimDynamicModule_basic",
|
||||||
'SplitDimStaticModule_basic',
|
"SplitDimStaticModule_basic",
|
||||||
'SqrtIntConstantModule_basic',
|
"SqrtIntConstantModule_basic",
|
||||||
'SqrtIntModule_basic',
|
"SqrtIntModule_basic",
|
||||||
'SubFloatModule_basic',
|
"SubFloatModule_basic",
|
||||||
'TModuleRank0_basic',
|
"TModuleRank0_basic",
|
||||||
'TensorToBoolZeroRank_basic',
|
"TensorToBoolZeroRank_basic",
|
||||||
'TensorToBool_basic',
|
"TensorToBool_basic",
|
||||||
'TensorToFloatZeroRank_basic',
|
"TensorToFloatZeroRank_basic",
|
||||||
'TensorToFloat_basic',
|
"TensorToFloat_basic",
|
||||||
'TestMultipleTensorAndPrimitiveTypesReturn_basic',
|
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
||||||
'ThresholdBackward2dMixedModule_basic',
|
"ThresholdBackward2dMixedModule_basic",
|
||||||
'TorchPrimLoopForLikeModule_basic',
|
"TorchPrimLoopForLikeModule_basic",
|
||||||
'TorchPrimLoopWhileLikeModule_basic',
|
"TorchPrimLoopWhileLikeModule_basic",
|
||||||
'UnbindIntGetItem_Module_basic',
|
"UnbindIntGetItem_Module_basic",
|
||||||
'UnbindIntListUnpack_Module_basic',
|
"UnbindIntListUnpack_Module_basic",
|
||||||
'UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic',
|
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||||
'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic',
|
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
'UpSampleNearest2dDynamicFactor_basic',
|
"UpSampleNearest2dDynamicFactor_basic",
|
||||||
'ViewCollapseDynamicWithAtenSizeIntModule_basic',
|
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
'ViewSizeFromOtherTensor_basic',
|
"ViewSizeFromOtherTensor_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_CRASHING_SET = {
|
FX_IMPORTER_CRASHING_SET = {
|
||||||
|
@ -1941,11 +1882,13 @@ TOSA_PASS_SET = {
|
||||||
"LinspaceModule_basic",
|
"LinspaceModule_basic",
|
||||||
"LinspaceOneSizeModule_basic",
|
"LinspaceOneSizeModule_basic",
|
||||||
"LinspaceTwoSizeModule_basic",
|
"LinspaceTwoSizeModule_basic",
|
||||||
"TorchPrimLoopForLikeTensorArgModule_basic"
|
"TorchPrimLoopForLikeTensorArgModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
MAKE_FX_TOSA_PASS_SET = (
|
||||||
### Tests additionally passing in make_fx_tosa
|
TOSA_PASS_SET
|
||||||
|
| {
|
||||||
|
### Tests additionally passing in make_fx_tosa
|
||||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||||
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
||||||
|
@ -1965,19 +1908,17 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
||||||
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
|
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
|
||||||
"ViewSizeDimLedByCollapsedOnesModule_basic",
|
"ViewSizeDimLedByCollapsedOnesModule_basic",
|
||||||
"ViewSizeFromOtherTensor_basic",
|
"ViewSizeFromOtherTensor_basic",
|
||||||
}) - {
|
}
|
||||||
### Test failing in make_fx_tosa but not in tosa
|
) - {
|
||||||
|
### Test failing in make_fx_tosa but not in tosa
|
||||||
# Dynamic shape, has extra unsupported broadcast ops
|
# Dynamic shape, has extra unsupported broadcast ops
|
||||||
"Matmul_3d",
|
"Matmul_3d",
|
||||||
"MatmulStaticBroadcast_basic",
|
"MatmulStaticBroadcast_basic",
|
||||||
|
|
||||||
# failed to legalize operation 'torch.aten.max_pool2d_with_indices
|
# failed to legalize operation 'torch.aten.max_pool2d_with_indices
|
||||||
"MaxPool2dEmptyStrideStaticModule_basic",
|
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||||
"MaxPool2dStaticCeilModeTrueModule_basic",
|
"MaxPool2dStaticCeilModeTrueModule_basic",
|
||||||
"MaxPool2dStaticModule_basic",
|
"MaxPool2dStaticModule_basic",
|
||||||
"ResNet18StaticModule_basic",
|
"ResNet18StaticModule_basic",
|
||||||
|
|
||||||
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
||||||
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||||
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
||||||
|
@ -1986,18 +1927,14 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
||||||
# failed to legalize operation 'torch.aten.to.dtype' that was explicitly marked illegal
|
# failed to legalize operation 'torch.aten.to.dtype' that was explicitly marked illegal
|
||||||
"AtenEyeModuleInt2D_basic",
|
"AtenEyeModuleInt2D_basic",
|
||||||
"AtenEyeMModuleInt2D_basic",
|
"AtenEyeMModuleInt2D_basic",
|
||||||
|
|
||||||
"Conv2dBiasNoPaddingModule_basic",
|
"Conv2dBiasNoPaddingModule_basic",
|
||||||
"Conv2dNoPaddingModule_basic",
|
"Conv2dNoPaddingModule_basic",
|
||||||
"Conv2dWithPaddingDilationStrideModule_basic",
|
"Conv2dWithPaddingDilationStrideModule_basic",
|
||||||
"Conv2dWithPaddingModule_basic",
|
"Conv2dWithPaddingModule_basic",
|
||||||
|
|
||||||
"AtenInstanceNormModule_basic",
|
"AtenInstanceNormModule_basic",
|
||||||
|
|
||||||
# failed to legalize operation 'torch.operator'
|
# failed to legalize operation 'torch.operator'
|
||||||
"ElementwisePreluModule_basic",
|
"ElementwisePreluModule_basic",
|
||||||
"ElementwisePreluStaticModule_basic",
|
"ElementwisePreluStaticModule_basic",
|
||||||
|
|
||||||
# Shape Related failures
|
# Shape Related failures
|
||||||
"PrimListUnpackNumMismatchModule_basic",
|
"PrimListUnpackNumMismatchModule_basic",
|
||||||
"ReshapeExpandModule_basic",
|
"ReshapeExpandModule_basic",
|
||||||
|
@ -2019,8 +1956,7 @@ LTC_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
LTC_XFAIL_SET = {
|
LTC_XFAIL_SET = {
|
||||||
"TorchPrimLoopForLikeTensorArgModule_basic"
|
"TorchPrimLoopForLikeTensorArgModule_basic" "CollapseAllDimensionsModule_basic",
|
||||||
"CollapseAllDimensionsModule_basic",
|
|
||||||
"CollapseRank1DynamicModule_basic",
|
"CollapseRank1DynamicModule_basic",
|
||||||
"CollapseStaticModule_basic",
|
"CollapseStaticModule_basic",
|
||||||
"CollapsePartialDynamicModule_basic",
|
"CollapsePartialDynamicModule_basic",
|
||||||
|
@ -2162,7 +2098,6 @@ LTC_XFAIL_SET = {
|
||||||
ONNX_XFAIL_SET = {
|
ONNX_XFAIL_SET = {
|
||||||
# Failure - cast error
|
# Failure - cast error
|
||||||
"PermuteNegativeIndexModule_basic",
|
"PermuteNegativeIndexModule_basic",
|
||||||
|
|
||||||
# Failure - expand multiple dynamic dims
|
# Failure - expand multiple dynamic dims
|
||||||
"EmbeddingModuleF16_basic",
|
"EmbeddingModuleF16_basic",
|
||||||
"EmbeddingModuleI32_basic",
|
"EmbeddingModuleI32_basic",
|
||||||
|
@ -2174,7 +2109,6 @@ ONNX_XFAIL_SET = {
|
||||||
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
||||||
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
|
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
|
||||||
"IndexTensorSelectDimModule_basic",
|
"IndexTensorSelectDimModule_basic",
|
||||||
|
|
||||||
# Failure - incorrect numerics
|
# Failure - incorrect numerics
|
||||||
"AvgPool2dDivisorOverrideModule_basic",
|
"AvgPool2dDivisorOverrideModule_basic",
|
||||||
"BroadcastDynamicDimModule_basic",
|
"BroadcastDynamicDimModule_basic",
|
||||||
|
@ -2211,14 +2145,12 @@ ONNX_XFAIL_SET = {
|
||||||
"StdCorrectionLargeInputModule_basic",
|
"StdCorrectionLargeInputModule_basic",
|
||||||
"TupleModule_basic",
|
"TupleModule_basic",
|
||||||
"VarCorrectionLargeInputModule_basic",
|
"VarCorrectionLargeInputModule_basic",
|
||||||
|
|
||||||
# Failure - incorrect shape
|
# Failure - incorrect shape
|
||||||
"ArangeStartOutDtypeModule_basic",
|
"ArangeStartOutDtypeModule_basic",
|
||||||
"ArangeStartOutViewModule_basic",
|
"ArangeStartOutViewModule_basic",
|
||||||
"MoveDimIntNegativeIndexModule_basic",
|
"MoveDimIntNegativeIndexModule_basic",
|
||||||
"ReduceL3NormKeepDimModule_basic",
|
"ReduceL3NormKeepDimModule_basic",
|
||||||
"ViewSizeFromOtherTensor_basic",
|
"ViewSizeFromOtherTensor_basic",
|
||||||
|
|
||||||
# Failure - onnx_export
|
# Failure - onnx_export
|
||||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||||
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
|
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
|
||||||
|
@ -2619,10 +2551,8 @@ ONNX_XFAIL_SET = {
|
||||||
"_ConvolutionDeprecated2DCudnnModule_basic",
|
"_ConvolutionDeprecated2DCudnnModule_basic",
|
||||||
"_ConvolutionDeprecated2DDeterministicModule_basic",
|
"_ConvolutionDeprecated2DDeterministicModule_basic",
|
||||||
"_SoftmaxModule_basic",
|
"_SoftmaxModule_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.AveragePool
|
# Failure - onnx_lowering: onnx.AveragePool
|
||||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.If
|
# Failure - onnx_lowering: onnx.If
|
||||||
"DiagonalModule_basic",
|
"DiagonalModule_basic",
|
||||||
"DiagonalModule_nonsquare",
|
"DiagonalModule_nonsquare",
|
||||||
|
@ -2633,12 +2563,10 @@ ONNX_XFAIL_SET = {
|
||||||
"DiagonalModule_with_offset",
|
"DiagonalModule_with_offset",
|
||||||
"TileBigDimsSizeModule_basic",
|
"TileBigDimsSizeModule_basic",
|
||||||
"TileSmallDimsSizeModule_basic",
|
"TileSmallDimsSizeModule_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.MaxPool
|
# Failure - onnx_lowering: onnx.MaxPool
|
||||||
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
||||||
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
|
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
|
||||||
"MaxPool2dWithIndicesStaticModule_basic",
|
"MaxPool2dWithIndicesStaticModule_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.ReduceProd
|
# Failure - onnx_lowering: onnx.ReduceProd
|
||||||
"ReduceProdFloatModule_basic",
|
"ReduceProdFloatModule_basic",
|
||||||
"ReduceProdDtypeFloatModule_basic",
|
"ReduceProdDtypeFloatModule_basic",
|
||||||
|
@ -2646,7 +2574,6 @@ ONNX_XFAIL_SET = {
|
||||||
"ReduceProdUnsignedIntModule_basic",
|
"ReduceProdUnsignedIntModule_basic",
|
||||||
"ReduceProdSignedIntModule_basic",
|
"ReduceProdSignedIntModule_basic",
|
||||||
"ReduceProdDtypeIntModule_basic",
|
"ReduceProdDtypeIntModule_basic",
|
||||||
|
|
||||||
# ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64)
|
# ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64)
|
||||||
"RandnDtypeDeviceModule_basic",
|
"RandnDtypeDeviceModule_basic",
|
||||||
"RandnGeneratorF64Module_basic",
|
"RandnGeneratorF64Module_basic",
|
||||||
|
@ -2656,21 +2583,17 @@ ONNX_XFAIL_SET = {
|
||||||
"BernoulliFloatModule_basic",
|
"BernoulliFloatModule_basic",
|
||||||
"BernoulliPModule_basic",
|
"BernoulliPModule_basic",
|
||||||
"BernoulliTensorModule_basic",
|
"BernoulliTensorModule_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.ReduceProd
|
# Failure - onnx_lowering: onnx.ReduceProd
|
||||||
"ReduceProdDimIntFloatModule_basic",
|
"ReduceProdDimIntFloatModule_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.Resize
|
# Failure - onnx_lowering: onnx.Resize
|
||||||
"UpSampleNearest2dDynamicSize_basic",
|
"UpSampleNearest2dDynamicSize_basic",
|
||||||
"UpSampleNearest2dStaticSize_basic",
|
"UpSampleNearest2dStaticSize_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.ScatterElements
|
# Failure - onnx_lowering: onnx.ScatterElements
|
||||||
"ScatterReduceFloatMaxModuleIncludeSelf",
|
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||||
"ScatterReduceFloatMinModuleIncludeSelf",
|
"ScatterReduceFloatMinModuleIncludeSelf",
|
||||||
"ScatterReduceIntMaxModuleIncludeSelf",
|
"ScatterReduceIntMaxModuleIncludeSelf",
|
||||||
"ScatterReduceIntMinModuleIncludeSelf",
|
"ScatterReduceIntMinModuleIncludeSelf",
|
||||||
"ScatterValueFloatModule_basic",
|
"ScatterValueFloatModule_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.ScatterND
|
# Failure - onnx_lowering: onnx.ScatterND
|
||||||
"IndexPut1DFloatAccumulateModule_basic",
|
"IndexPut1DFloatAccumulateModule_basic",
|
||||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||||
|
@ -2696,14 +2619,11 @@ ONNX_XFAIL_SET = {
|
||||||
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
|
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
|
||||||
"IndexPutHackedTwin3DIntAccumulateModule_basic",
|
"IndexPutHackedTwin3DIntAccumulateModule_basic",
|
||||||
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
|
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
|
# Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
|
||||||
"CrossEntropyLossModule_basic",
|
"CrossEntropyLossModule_basic",
|
||||||
"CrossEntropyLossNoReductionModule_basic",
|
"CrossEntropyLossNoReductionModule_basic",
|
||||||
|
|
||||||
# RuntimeError: unsupported input type: Device
|
# RuntimeError: unsupported input type: Device
|
||||||
"PrimsIotaModule_basic",
|
"PrimsIotaModule_basic",
|
||||||
|
|
||||||
# Failure - unknown
|
# Failure - unknown
|
||||||
"BernoulliModule_basic",
|
"BernoulliModule_basic",
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||||
|
@ -2741,7 +2661,7 @@ if torch_version_for_comparison() >= version.parse("2.4.0.dev"):
|
||||||
"ReduceL1NormWithDTypeModule_basic",
|
"ReduceL1NormWithDTypeModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
if torch_version_for_comparison() < version.parse('2.3.0.dev'):
|
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
|
||||||
ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
|
ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
|
||||||
# ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
|
# ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
|
||||||
"RepeatInterleaveSelfIntNoDimModule_basic",
|
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||||
|
|
|
@ -23,34 +23,43 @@ import torch._lazy
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from datasets.dataset_dict import DatasetDict
|
from datasets.dataset_dict import DatasetDict
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from transformers import BertForSequenceClassification, \
|
from transformers import (
|
||||||
BertConfig, BertTokenizer, AdamW, get_scheduler
|
BertForSequenceClassification,
|
||||||
|
BertConfig,
|
||||||
|
BertTokenizer,
|
||||||
|
AdamW,
|
||||||
|
get_scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def tokenize_dataset(dataset: DatasetDict) -> DatasetDict:
|
def tokenize_dataset(dataset: DatasetDict) -> DatasetDict:
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
|
||||||
|
|
||||||
def tokenize_function(examples):
|
def tokenize_function(examples):
|
||||||
return tokenizer(examples["text"], padding="max_length",
|
return tokenizer(examples["text"], padding="max_length", truncation=True)
|
||||||
truncation=True)
|
|
||||||
|
|
||||||
tokenized_datasets = dataset.map(tokenize_function, batched=True)
|
tokenized_datasets = dataset.map(tokenize_function, batched=True)
|
||||||
tokenized_datasets = tokenized_datasets.remove_columns(['text'])
|
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
|
||||||
tokenized_datasets = tokenized_datasets.rename_column('label', 'labels')
|
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
|
||||||
tokenized_datasets.set_format('torch')
|
tokenized_datasets.set_format("torch")
|
||||||
|
|
||||||
return tokenized_datasets
|
return tokenized_datasets
|
||||||
|
|
||||||
|
|
||||||
def train(model: BertForSequenceClassification,
|
def train(
|
||||||
|
model: BertForSequenceClassification,
|
||||||
num_epochs: int,
|
num_epochs: int,
|
||||||
num_training_steps: int,
|
num_training_steps: int,
|
||||||
train_dataloader: DataLoader,
|
train_dataloader: DataLoader,
|
||||||
device: torch.device) -> List[torch.Tensor]:
|
device: torch.device,
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
optimizer = AdamW(model.parameters(), lr=5e-5)
|
optimizer = AdamW(model.parameters(), lr=5e-5)
|
||||||
lr_scheduler = get_scheduler('linear', optimizer=optimizer,
|
lr_scheduler = get_scheduler(
|
||||||
|
"linear",
|
||||||
|
optimizer=optimizer,
|
||||||
num_warmup_steps=0,
|
num_warmup_steps=0,
|
||||||
num_training_steps=num_training_steps)
|
num_training_steps=num_training_steps,
|
||||||
|
)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
losses = []
|
losses = []
|
||||||
|
@ -66,14 +75,14 @@ def train(model: BertForSequenceClassification,
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if 'lazy' in str(model.device):
|
if "lazy" in str(model.device):
|
||||||
print("Calling Mark Step")
|
print("Calling Mark Step")
|
||||||
torch._lazy.mark_step()
|
torch._lazy.mark_step()
|
||||||
|
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
|
|
||||||
def main(device='lazy', full_size=False):
|
def main(device="lazy", full_size=False):
|
||||||
"""
|
"""
|
||||||
Load model to specified device. Ensure that any backends have been initialized by this point.
|
Load model to specified device. Ensure that any backends have been initialized by this point.
|
||||||
|
|
||||||
|
@ -82,15 +91,14 @@ def main(device='lazy', full_size=False):
|
||||||
"""
|
"""
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
|
||||||
tokenized_datasets = tokenize_dataset(load_dataset('imdb'))
|
tokenized_datasets = tokenize_dataset(load_dataset("imdb"))
|
||||||
small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \
|
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(2))
|
||||||
.select(range(2))
|
|
||||||
|
|
||||||
train_dataloader = DataLoader(small_train_dataset, shuffle=True,
|
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
|
||||||
batch_size=8)
|
|
||||||
if full_size:
|
if full_size:
|
||||||
model = BertForSequenceClassification.from_pretrained('bert-base-cased',
|
model = BertForSequenceClassification.from_pretrained(
|
||||||
num_labels=2)
|
"bert-base-cased", num_labels=2
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
configuration = BertConfig(
|
configuration = BertConfig(
|
||||||
vocab_size=28996,
|
vocab_size=28996,
|
||||||
|
@ -98,7 +106,7 @@ def main(device='lazy', full_size=False):
|
||||||
num_hidden_layers=1,
|
num_hidden_layers=1,
|
||||||
num_attention_heads=2,
|
num_attention_heads=2,
|
||||||
intermediate_size=32,
|
intermediate_size=32,
|
||||||
hidden_act='gelu',
|
hidden_act="gelu",
|
||||||
hidden_dropout_prob=0.0,
|
hidden_dropout_prob=0.0,
|
||||||
attention_probs_dropout_prob=0.0,
|
attention_probs_dropout_prob=0.0,
|
||||||
max_position_embeddings=512,
|
max_position_embeddings=512,
|
||||||
|
@ -113,12 +121,12 @@ def main(device='lazy', full_size=False):
|
||||||
losses = train(model, num_epochs, num_training_steps, train_dataloader, device)
|
losses = train(model, num_epochs, num_training_steps, train_dataloader, device)
|
||||||
|
|
||||||
# Get debug information from LTC
|
# Get debug information from LTC
|
||||||
if 'torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND' in sys.modules:
|
if "torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND" in sys.modules:
|
||||||
computation = lazy_backend.get_latest_computation()
|
computation = lazy_backend.get_latest_computation()
|
||||||
if computation:
|
if computation:
|
||||||
print(computation.debug_string())
|
print(computation.debug_string())
|
||||||
|
|
||||||
print('Loss: ', losses)
|
print("Loss: ", losses)
|
||||||
|
|
||||||
return model, losses
|
return model, losses
|
||||||
|
|
||||||
|
@ -136,7 +144,7 @@ if __name__ == "__main__":
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-f",
|
"-f",
|
||||||
"--full_size",
|
"--full_size",
|
||||||
action='store_true',
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help="Use full sized BERT model instead of one with smaller parameterization",
|
help="Use full sized BERT model instead of one with smaller parameterization",
|
||||||
)
|
)
|
||||||
|
@ -145,6 +153,7 @@ if __name__ == "__main__":
|
||||||
if args.device in ("TS", "MLIR_EXAMPLE"):
|
if args.device in ("TS", "MLIR_EXAMPLE"):
|
||||||
if args.device == "TS":
|
if args.device == "TS":
|
||||||
import torch._lazy.ts_backend
|
import torch._lazy.ts_backend
|
||||||
|
|
||||||
torch._lazy.ts_backend.init()
|
torch._lazy.ts_backend.init()
|
||||||
|
|
||||||
elif args.device == "MLIR_EXAMPLE":
|
elif args.device == "MLIR_EXAMPLE":
|
||||||
|
|
|
@ -13,7 +13,7 @@ import torch._lazy
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
def main(device='lazy'):
|
def main(device="lazy"):
|
||||||
"""
|
"""
|
||||||
Load model to specified device. Ensure that any backends have been initialized by this point.
|
Load model to specified device. Ensure that any backends have been initialized by this point.
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ def main(device='lazy'):
|
||||||
torch._lazy.mark_step()
|
torch._lazy.mark_step()
|
||||||
|
|
||||||
# Get debug information from LTC
|
# Get debug information from LTC
|
||||||
if 'torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND' in sys.modules:
|
if "torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND" in sys.modules:
|
||||||
computation = lazy_backend.get_latest_computation()
|
computation = lazy_backend.get_latest_computation()
|
||||||
if computation:
|
if computation:
|
||||||
print(computation.debug_string())
|
print(computation.debug_string())
|
||||||
|
@ -90,6 +90,7 @@ if __name__ == "__main__":
|
||||||
if args.device in ("TS", "MLIR_EXAMPLE"):
|
if args.device in ("TS", "MLIR_EXAMPLE"):
|
||||||
if args.device == "TS":
|
if args.device == "TS":
|
||||||
import torch._lazy.ts_backend
|
import torch._lazy.ts_backend
|
||||||
|
|
||||||
torch._lazy.ts_backend.init()
|
torch._lazy.ts_backend.init()
|
||||||
|
|
||||||
elif args.device == "MLIR_EXAMPLE":
|
elif args.device == "MLIR_EXAMPLE":
|
||||||
|
|
|
@ -21,19 +21,18 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||||
|
|
||||||
def load_and_preprocess_image(url: str):
|
def load_and_preprocess_image(url: str):
|
||||||
headers = {
|
headers = {
|
||||||
'User-Agent':
|
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
|
||||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
|
|
||||||
}
|
}
|
||||||
img = Image.open(requests.get(url, headers=headers,
|
img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB")
|
||||||
stream=True).raw).convert("RGB")
|
|
||||||
# preprocessing pipeline
|
# preprocessing pipeline
|
||||||
preprocess = transforms.Compose([
|
preprocess = transforms.Compose(
|
||||||
|
[
|
||||||
transforms.Resize(256),
|
transforms.Resize(256),
|
||||||
transforms.CenterCrop(224),
|
transforms.CenterCrop(224),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
std=[0.229, 0.224, 0.225]),
|
]
|
||||||
])
|
)
|
||||||
img_preprocessed = preprocess(img)
|
img_preprocessed = preprocess(img)
|
||||||
return torch.unsqueeze(img_preprocessed, 0)
|
return torch.unsqueeze(img_preprocessed, 0)
|
||||||
|
|
||||||
|
@ -62,17 +61,23 @@ def predictions(torch_func, jit_func, img, labels):
|
||||||
print("torch-mlir prediction")
|
print("torch-mlir prediction")
|
||||||
print(prediction)
|
print(prediction)
|
||||||
|
|
||||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
|
||||||
|
image_url = (
|
||||||
|
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
||||||
|
)
|
||||||
|
|
||||||
print("load image from " + image_url, file=sys.stderr)
|
print("load image from " + image_url, file=sys.stderr)
|
||||||
img = load_and_preprocess_image(image_url)
|
img = load_and_preprocess_image(image_url)
|
||||||
labels = load_labels()
|
labels = load_labels()
|
||||||
|
|
||||||
|
|
||||||
@make_simple_dynamo_backend
|
@make_simple_dynamo_backend
|
||||||
def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
def refbackend_torchdynamo_backend(
|
||||||
example_inputs: List[torch.Tensor]):
|
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||||
|
):
|
||||||
mlir_module = torchscript.compile(
|
mlir_module = torchscript.compile(
|
||||||
fx_graph, example_inputs, output_type="linalg-on-tensors")
|
fx_graph, example_inputs, output_type="linalg-on-tensors"
|
||||||
|
)
|
||||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||||
compiled = backend.compile(mlir_module)
|
compiled = backend.compile(mlir_module)
|
||||||
loaded = backend.load(compiled)
|
loaded = backend.load(compiled)
|
||||||
|
@ -85,10 +90,17 @@ def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
||||||
else:
|
else:
|
||||||
result = tuple(torch.from_numpy(x) for x in result)
|
result = tuple(torch.from_numpy(x) for x in result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return compiled_callable
|
return compiled_callable
|
||||||
|
|
||||||
|
|
||||||
resnet18 = models.resnet18(pretrained=True)
|
resnet18 = models.resnet18(pretrained=True)
|
||||||
resnet18.train(False)
|
resnet18.train(False)
|
||||||
dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18)
|
dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18)
|
||||||
|
|
||||||
predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).detach().numpy(), img, labels)
|
predictions(
|
||||||
|
resnet18.forward,
|
||||||
|
lambda x: dynamo_callable(torch.from_numpy(x)).detach().numpy(),
|
||||||
|
img,
|
||||||
|
labels,
|
||||||
|
)
|
||||||
|
|
|
@ -18,19 +18,18 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||||
|
|
||||||
def load_and_preprocess_image(url: str):
|
def load_and_preprocess_image(url: str):
|
||||||
headers = {
|
headers = {
|
||||||
'User-Agent':
|
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
|
||||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
|
|
||||||
}
|
}
|
||||||
img = Image.open(requests.get(url, headers=headers,
|
img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB")
|
||||||
stream=True).raw).convert("RGB")
|
|
||||||
# preprocessing pipeline
|
# preprocessing pipeline
|
||||||
preprocess = transforms.Compose([
|
preprocess = transforms.Compose(
|
||||||
|
[
|
||||||
transforms.Resize(256),
|
transforms.Resize(256),
|
||||||
transforms.CenterCrop(224),
|
transforms.CenterCrop(224),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
std=[0.229, 0.224, 0.225]),
|
]
|
||||||
])
|
)
|
||||||
img_preprocessed = preprocess(img)
|
img_preprocessed = preprocess(img)
|
||||||
return torch.unsqueeze(img_preprocessed, 0)
|
return torch.unsqueeze(img_preprocessed, 0)
|
||||||
|
|
||||||
|
@ -59,7 +58,10 @@ def predictions(torch_func, jit_func, img, labels):
|
||||||
print("torch-mlir prediction")
|
print("torch-mlir prediction")
|
||||||
print(prediction)
|
print(prediction)
|
||||||
|
|
||||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
|
||||||
|
image_url = (
|
||||||
|
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
||||||
|
)
|
||||||
|
|
||||||
print("load image from " + image_url, file=sys.stderr)
|
print("load image from " + image_url, file=sys.stderr)
|
||||||
img = load_and_preprocess_image(image_url)
|
img = load_and_preprocess_image(image_url)
|
||||||
|
@ -67,7 +69,9 @@ labels = load_labels()
|
||||||
|
|
||||||
resnet18 = models.resnet18(pretrained=True)
|
resnet18 = models.resnet18(pretrained=True)
|
||||||
resnet18.train(False)
|
resnet18.train(False)
|
||||||
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
|
module = torchscript.compile(
|
||||||
|
resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors"
|
||||||
|
)
|
||||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||||
compiled = backend.compile(module)
|
compiled = backend.compile(module)
|
||||||
jit_module = backend.load(compiled)
|
jit_module = backend.load(compiled)
|
||||||
|
|
|
@ -13,8 +13,12 @@ resnet18.eval()
|
||||||
|
|
||||||
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
||||||
print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
||||||
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
|
module = torchscript.compile(
|
||||||
print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10)
|
||||||
|
)
|
||||||
# TODO: Debug why this is so slow.
|
# TODO: Debug why this is so slow.
|
||||||
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
|
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
|
||||||
print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
||||||
|
|
|
@ -4,10 +4,12 @@ from torch_mlir import torchscript
|
||||||
|
|
||||||
model = models.resnet18(pretrained=True)
|
model = models.resnet18(pretrained=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
data = torch.randn(2,3,200,200)
|
data = torch.randn(2, 3, 200, 200)
|
||||||
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
|
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
|
||||||
|
|
||||||
module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False)
|
module = torchscript.compile(
|
||||||
|
model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False
|
||||||
|
)
|
||||||
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||||
outf.write(str(module))
|
outf.write(str(module))
|
||||||
|
|
||||||
|
|
|
@ -7,17 +7,22 @@ from transformers import BertForMaskedLM
|
||||||
class BertTinyWrapper(torch.nn.Module):
|
class BertTinyWrapper(torch.nn.Module):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False)
|
self.bert = BertForMaskedLM.from_pretrained(
|
||||||
|
"prajjwal1/bert-tiny", return_dict=False
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, data):
|
def forward(self, data):
|
||||||
return self.bert(data)[0]
|
return self.bert(data)[0]
|
||||||
|
|
||||||
|
|
||||||
model = BertTinyWrapper()
|
model = BertTinyWrapper()
|
||||||
model.eval()
|
model.eval()
|
||||||
data = torch.randint(30522, (2, 128))
|
data = torch.randint(30522, (2, 128))
|
||||||
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
|
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
|
||||||
|
|
||||||
module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True)
|
module = torchscript.compile(
|
||||||
|
model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True
|
||||||
|
)
|
||||||
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||||
outf.write(str(module))
|
outf.write(str(module))
|
||||||
|
|
||||||
|
|
|
@ -11,18 +11,23 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
from torch_mlir.jit_ir_importer import ClassAnnotator
|
from torch_mlir.jit_ir_importer import ClassAnnotator
|
||||||
from torch_mlir.jit_ir_importer.torchscript_annotations import extract_annotations
|
from torch_mlir.jit_ir_importer.torchscript_annotations import extract_annotations
|
||||||
|
|
||||||
|
|
||||||
class MmModule(torch.nn.Module):
|
class MmModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([3, 4], torch.float32, False),
|
([3, 4], torch.float32, False),
|
||||||
([4, 5], torch.float32, True),
|
([4, 5], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.mm(lhs, rhs)
|
return torch.mm(lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
module = MmModule()
|
module = MmModule()
|
||||||
annotator = ClassAnnotator()
|
annotator = ClassAnnotator()
|
||||||
extract_annotations(module, torch.jit.script(module), annotator)
|
extract_annotations(module, torch.jit.script(module), annotator)
|
||||||
|
|
|
@ -26,6 +26,8 @@ print(torchscript.compile(scripted, example_args))
|
||||||
scripted = torch.jit.script(BasicModule())
|
scripted = torch.jit.script(BasicModule())
|
||||||
try:
|
try:
|
||||||
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
||||||
torchscript.compile(scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
|
torchscript.compile(
|
||||||
|
scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3))
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
|
@ -8,10 +8,12 @@
|
||||||
import torch
|
import torch
|
||||||
from torch_mlir import torchscript
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
|
|
||||||
class BasicModule(torch.nn.Module):
|
class BasicModule(torch.nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.sin(x)
|
return torch.ops.aten.sin(x)
|
||||||
|
|
||||||
|
|
||||||
example_arg = torch.ones(2, 3)
|
example_arg = torch.ones(2, 3)
|
||||||
example_args = torchscript.ExampleArgs.get(example_arg)
|
example_args = torchscript.ExampleArgs.get(example_arg)
|
||||||
|
|
||||||
|
@ -23,6 +25,8 @@ print(torchscript.compile(traced, example_args))
|
||||||
traced = torch.jit.trace(BasicModule(), example_arg)
|
traced = torch.jit.trace(BasicModule(), example_arg)
|
||||||
try:
|
try:
|
||||||
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
||||||
torchscript.compile(traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg))
|
torchscript.compile(
|
||||||
|
traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
|
@ -9,15 +9,24 @@ import torch
|
||||||
|
|
||||||
from torch_mlir import torchscript
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
|
|
||||||
class AddmmModule(torch.nn.Module):
|
class AddmmModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x, y, z):
|
def forward(self, x, y, z):
|
||||||
return torch.ops.aten.addmm(x, y, z)
|
return torch.ops.aten.addmm(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)]
|
example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)]
|
||||||
|
|
||||||
print(torchscript.compile(AddmmModule(), example_args,
|
print(
|
||||||
output_type="torch", backend_legal_ops=["aten.addmm"]))
|
torchscript.compile(
|
||||||
|
AddmmModule(),
|
||||||
|
example_args,
|
||||||
|
output_type="torch",
|
||||||
|
backend_legal_ops=["aten.addmm"],
|
||||||
|
)
|
||||||
|
)
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.addmm
|
# CHECK: torch.aten.addmm
|
||||||
|
|
|
@ -9,12 +9,15 @@ import torch
|
||||||
|
|
||||||
from torch_mlir import torchscript
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
|
|
||||||
class TanhModule(torch.nn.Module):
|
class TanhModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.tanh(x)
|
return torch.ops.aten.tanh(x)
|
||||||
|
|
||||||
|
|
||||||
tanh_example_input = torch.ones(2, 3)
|
tanh_example_input = torch.ones(2, 3)
|
||||||
|
|
||||||
# Simplest case: One example argument.
|
# Simplest case: One example argument.
|
||||||
|
@ -35,16 +38,23 @@ print(torchscript.compile(TanhModule(), placeholder))
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
|
||||||
|
|
||||||
# Basic smoke test for the raw output type.
|
# Basic smoke test for the raw output type.
|
||||||
print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW))
|
print(
|
||||||
|
torchscript.compile(
|
||||||
|
TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW
|
||||||
|
)
|
||||||
|
)
|
||||||
# CHECK: torch.nn_module {
|
# CHECK: torch.nn_module {
|
||||||
# CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule">
|
# CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule">
|
||||||
|
|
||||||
|
|
||||||
class MmModule(torch.nn.Module):
|
class MmModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
def forward(self, lhs, rhs ):
|
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
return torch.ops.aten.mm(lhs, rhs)
|
return torch.ops.aten.mm(lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
# N > 1 inputs.
|
# N > 1 inputs.
|
||||||
mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)]
|
mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)]
|
||||||
print(torchscript.compile(MmModule(), mm_example_inputs))
|
print(torchscript.compile(MmModule(), mm_example_inputs))
|
||||||
|
@ -52,7 +62,10 @@ print(torchscript.compile(MmModule(), mm_example_inputs))
|
||||||
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
|
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
|
||||||
|
|
||||||
# Mixes Tensor's and TensorPlaceholder's.
|
# Mixes Tensor's and TensorPlaceholder's.
|
||||||
mm_dynamic_inputs = [mm_example_inputs[0], torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
|
mm_dynamic_inputs = [
|
||||||
|
mm_example_inputs[0],
|
||||||
|
torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1]),
|
||||||
|
]
|
||||||
print(torchscript.compile(MmModule(), mm_dynamic_inputs))
|
print(torchscript.compile(MmModule(), mm_dynamic_inputs))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32>
|
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32>
|
||||||
|
|
|
@ -10,11 +10,19 @@ import torch
|
||||||
|
|
||||||
from torch_mlir import torchscript
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
|
|
||||||
def simple(x):
|
def simple(x):
|
||||||
return x * x
|
return x * x
|
||||||
|
|
||||||
example_input = torch.randn(1,)
|
|
||||||
graph = functorch.make_fx(simple)(torch.randn(1,))
|
example_input = torch.randn(
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
graph = functorch.make_fx(simple)(
|
||||||
|
torch.randn(
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Simplest case: One example argument.
|
# Simplest case: One example argument.
|
||||||
print(torchscript.compile(graph, example_input))
|
print(torchscript.compile(graph, example_input))
|
||||||
|
|
|
@ -9,15 +9,22 @@ import torch
|
||||||
|
|
||||||
from torch_mlir import torchscript
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
|
|
||||||
class TanhModule(torch.nn.Module):
|
class TanhModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.tanh(x)
|
return torch.ops.aten.tanh(x)
|
||||||
|
|
||||||
|
|
||||||
tanh_example_input = torch.ones(2, 3)
|
tanh_example_input = torch.ones(2, 3)
|
||||||
|
|
||||||
print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH))
|
print(
|
||||||
|
torchscript.compile(
|
||||||
|
TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH
|
||||||
|
)
|
||||||
|
)
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||||
print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch"))
|
print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch"))
|
||||||
|
|
|
@ -14,6 +14,7 @@ class TanhModule(torch.nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.tanh(x)
|
return torch.ops.aten.tanh(x)
|
||||||
|
|
||||||
|
|
||||||
tanh_example_input = torch.ones(2, 3)
|
tanh_example_input = torch.ones(2, 3)
|
||||||
|
|
||||||
# Simplest case: One example argument.
|
# Simplest case: One example argument.
|
||||||
|
@ -32,10 +33,12 @@ print(torchscript.compile(TanhModule(), [tanh_example_input], use_tracing=True))
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||||
|
|
||||||
# TensorPlaceholder support.
|
# TensorPlaceholder support.
|
||||||
placeholder = torchscript.TensorPlaceholder.like(
|
placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
|
||||||
tanh_example_input, dynamic_axes=[1])
|
print(
|
||||||
print(torchscript.compile(TanhModule(), [placeholder],
|
torchscript.compile(
|
||||||
use_tracing=True, ignore_traced_shapes=True))
|
TanhModule(), [placeholder], use_tracing=True, ignore_traced_shapes=True
|
||||||
|
)
|
||||||
|
)
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
|
||||||
|
|
||||||
|
@ -55,18 +58,18 @@ except Exception as e:
|
||||||
|
|
||||||
class DictModule(torch.nn.Module):
|
class DictModule(torch.nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x['a'] * 2.0
|
return x["a"] * 2.0
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
||||||
torchscript.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
|
torchscript.compile(DictModule(), {"a": torch.tensor(3.0)}, use_tracing=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
||||||
torchscript.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
|
torchscript.compile(DictModule(), [{"a": torch.tensor(3.0)}], use_tracing=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
|
@ -15,8 +15,9 @@ from torch_mlir_e2e_test.debug.lockstep import make_lockstep_debug_backend
|
||||||
|
|
||||||
@make_simple_dynamo_backend
|
@make_simple_dynamo_backend
|
||||||
@make_lockstep_debug_backend()
|
@make_lockstep_debug_backend()
|
||||||
def miscompile_div_as_mul_backend(gm: torch.fx.GraphModule,
|
def miscompile_div_as_mul_backend(
|
||||||
example_inputs: List[torch.Tensor]):
|
gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||||
|
):
|
||||||
# Copy `gm` and rewrite `div` to `mul`.
|
# Copy `gm` and rewrite `div` to `mul`.
|
||||||
new_g = torch.fx.Graph()
|
new_g = torch.fx.Graph()
|
||||||
new_g.output(new_g.graph_copy(gm.graph, {}))
|
new_g.output(new_g.graph_copy(gm.graph, {}))
|
||||||
|
@ -41,7 +42,7 @@ def f(x, y):
|
||||||
return a, b, c
|
return a, b, c
|
||||||
|
|
||||||
|
|
||||||
args = (torch.tensor([1., 2., 3.]), torch.tensor([4., 5., 6.]))
|
args = (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]))
|
||||||
try:
|
try:
|
||||||
print(f(*args))
|
print(f(*args))
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
|
|
|
@ -11,7 +11,11 @@ import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
import torch._dynamo as dynamo
|
import torch._dynamo as dynamo
|
||||||
from torch._dynamo.backends.common import aot_autograd
|
from torch._dynamo.backends.common import aot_autograd
|
||||||
from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_compilation_context, set_model_name
|
from torch._functorch.aot_autograd import (
|
||||||
|
make_boxed_compiler,
|
||||||
|
get_aot_compilation_context,
|
||||||
|
set_model_name,
|
||||||
|
)
|
||||||
|
|
||||||
from torch_mlir.compiler_utils import TorchMlirCompilerError
|
from torch_mlir.compiler_utils import TorchMlirCompilerError
|
||||||
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
|
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
|
||||||
|
@ -19,8 +23,9 @@ from torch_mlir_e2e_test.configs.torchdynamo import jit
|
||||||
|
|
||||||
|
|
||||||
@make_boxed_compiler
|
@make_boxed_compiler
|
||||||
def my_aot_autograd_backend(gm: torch.fx.GraphModule,
|
def my_aot_autograd_backend(
|
||||||
example_inputs: List[torch.Tensor]):
|
gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||||
|
):
|
||||||
print(gm.graph)
|
print(gm.graph)
|
||||||
*_, model_name, nth_graph = get_aot_compilation_context()
|
*_, model_name, nth_graph = get_aot_compilation_context()
|
||||||
mlir_module = import_fx_graph_as_func(gm.graph, model_name)
|
mlir_module = import_fx_graph_as_func(gm.graph, model_name)
|
||||||
|
@ -59,9 +64,7 @@ basic(torch.randn(3, 4))
|
||||||
# CHECK: return %[[RANDN]] : !torch.vtensor<[3,4],f16>
|
# CHECK: return %[[RANDN]] : !torch.vtensor<[3,4],f16>
|
||||||
@dynamo.optimize(my_backend)
|
@dynamo.optimize(my_backend)
|
||||||
def literals_list_device_int_none_dtype():
|
def literals_list_device_int_none_dtype():
|
||||||
return torch.ops.aten.randn([3, 4],
|
return torch.ops.aten.randn([3, 4], device=torch.device("cpu"), dtype=torch.float16)
|
||||||
device=torch.device("cpu"),
|
|
||||||
dtype=torch.float16)
|
|
||||||
|
|
||||||
|
|
||||||
set_model_name("literals_list_device_int_none_dtype")
|
set_model_name("literals_list_device_int_none_dtype")
|
||||||
|
|
|
@ -19,23 +19,23 @@ from lit.llvm.subst import FindTool
|
||||||
# Configuration file for the 'lit' test runner.
|
# Configuration file for the 'lit' test runner.
|
||||||
|
|
||||||
# name: The name of this test suite.
|
# name: The name of this test suite.
|
||||||
config.name = 'TORCH_MLIR_PYTHON'
|
config.name = "TORCH_MLIR_PYTHON"
|
||||||
|
|
||||||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||||
if 'TEST_SRC_PATH' in os.environ:
|
if "TEST_SRC_PATH" in os.environ:
|
||||||
config.environment['TEST_SRC_PATH'] = os.environ['TEST_SRC_PATH']
|
config.environment["TEST_SRC_PATH"] = os.environ["TEST_SRC_PATH"]
|
||||||
|
|
||||||
# path to our python operation library
|
# path to our python operation library
|
||||||
config.environment['TEST_BUILD_PATH'] = os.path.join(config.torch_mlir_obj_root)
|
config.environment["TEST_BUILD_PATH"] = os.path.join(config.torch_mlir_obj_root)
|
||||||
|
|
||||||
# suffixes: A list of file extensions to treat as test files.
|
# suffixes: A list of file extensions to treat as test files.
|
||||||
config.suffixes = ['.py']
|
config.suffixes = [".py"]
|
||||||
|
|
||||||
# test_source_root: The root path where tests are located.
|
# test_source_root: The root path where tests are located.
|
||||||
config.test_source_root = os.path.dirname(__file__)
|
config.test_source_root = os.path.dirname(__file__)
|
||||||
|
|
||||||
# test_exec_root: The root path where tests should be run.
|
# test_exec_root: The root path where tests should be run.
|
||||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test")
|
||||||
|
|
||||||
# On Windows the path to python could contains spaces in which case it needs to
|
# On Windows the path to python could contains spaces in which case it needs to
|
||||||
# be provided in quotes. This is the equivalent of how %python is setup in
|
# be provided in quotes. This is the equivalent of how %python is setup in
|
||||||
|
@ -43,19 +43,25 @@ config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
||||||
if "Windows" in config.host_os:
|
if "Windows" in config.host_os:
|
||||||
config.python_executable = '"%s"' % (config.python_executable)
|
config.python_executable = '"%s"' % (config.python_executable)
|
||||||
|
|
||||||
config.substitutions.append(('%PATH%', config.environment['PATH']))
|
config.substitutions.append(("%PATH%", config.environment["PATH"]))
|
||||||
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
|
config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
|
||||||
config.substitutions.append(('%PYTHON', config.python_executable))
|
config.substitutions.append(("%PYTHON", config.python_executable))
|
||||||
|
|
||||||
llvm_config.with_system_environment(
|
llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
|
||||||
['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
|
|
||||||
|
|
||||||
llvm_config.use_default_substitutions()
|
llvm_config.use_default_substitutions()
|
||||||
|
|
||||||
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
|
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
|
||||||
# subdirectories contain auxiliary inputs for various tests in their parent
|
# subdirectories contain auxiliary inputs for various tests in their parent
|
||||||
# directories.
|
# directories.
|
||||||
config.excludes = ['lit.cfg.py', 'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
|
config.excludes = [
|
||||||
|
"lit.cfg.py",
|
||||||
|
"Inputs",
|
||||||
|
"Examples",
|
||||||
|
"CMakeLists.txt",
|
||||||
|
"README.txt",
|
||||||
|
"LICENSE.txt",
|
||||||
|
]
|
||||||
|
|
||||||
if not bool(int(os.environ.get("TORCH_MLIR_ENABLE_LTC", 0))):
|
if not bool(int(os.environ.get("TORCH_MLIR_ENABLE_LTC", 0))):
|
||||||
config.excludes.append("lazy_backend")
|
config.excludes.append("lazy_backend")
|
||||||
|
@ -64,20 +70,23 @@ if not bool(int(os.environ.get("TORCH_MLIR_ENABLE_LTC", 0))):
|
||||||
config.test_source_root = os.path.dirname(__file__)
|
config.test_source_root = os.path.dirname(__file__)
|
||||||
|
|
||||||
# test_exec_root: The root path where tests should be run.
|
# test_exec_root: The root path where tests should be run.
|
||||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test")
|
||||||
config.torch_mlir_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')
|
config.torch_mlir_tools_dir = os.path.join(config.torch_mlir_obj_root, "bin")
|
||||||
|
|
||||||
# Tweak the PATH to include the tools dir.
|
# Tweak the PATH to include the tools dir.
|
||||||
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
|
||||||
llvm_config.with_environment('PYTHONPATH', [
|
llvm_config.with_environment(
|
||||||
os.path.join(config.torch_mlir_python_packages_dir, 'torch_mlir'),
|
"PYTHONPATH",
|
||||||
|
[
|
||||||
|
os.path.join(config.torch_mlir_python_packages_dir, "torch_mlir"),
|
||||||
],
|
],
|
||||||
append_path=True)
|
append_path=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
tool_dirs = [config.torch_mlir_tools_dir, config.llvm_tools_dir]
|
tool_dirs = [config.torch_mlir_tools_dir, config.llvm_tools_dir]
|
||||||
tools = [
|
tools = [
|
||||||
'torch-mlir-opt',
|
"torch-mlir-opt",
|
||||||
]
|
]
|
||||||
|
|
||||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||||
|
|
|
@ -40,5 +40,5 @@ def main():
|
||||||
report_results(results, set(), verbose=True)
|
report_results(results, set(), verbose=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -43,5 +43,5 @@ def main():
|
||||||
report_results(results, set(), verbose=True, config="myconfig")
|
report_results(results, set(), verbose=True, config="myconfig")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -122,7 +122,7 @@ class ErroneousModule(torch.nn.Module):
|
||||||
@torch.jit.export
|
@torch.jit.export
|
||||||
def test_tensor_value_mismatch(self):
|
def test_tensor_value_mismatch(self):
|
||||||
if torch.jit.is_scripting():
|
if torch.jit.is_scripting():
|
||||||
return torch.tensor([1., 2., 3.])
|
return torch.tensor([1.0, 2.0, 3.0])
|
||||||
else:
|
else:
|
||||||
return torch.tensor([1.5, 2.5, 3.5])
|
return torch.tensor([1.5, 2.5, 3.5])
|
||||||
|
|
||||||
|
@ -132,9 +132,9 @@ class ErroneousModule(torch.nn.Module):
|
||||||
@torch.jit.export
|
@torch.jit.export
|
||||||
def test_tensor_shape_mismatch(self):
|
def test_tensor_shape_mismatch(self):
|
||||||
if torch.jit.is_scripting():
|
if torch.jit.is_scripting():
|
||||||
return torch.tensor([1., 2.])
|
return torch.tensor([1.0, 2.0])
|
||||||
else:
|
else:
|
||||||
return torch.tensor([1., 2., 3.])
|
return torch.tensor([1.0, 2.0, 3.0])
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ErroneousModule())
|
@register_test_case(module_factory=lambda: ErroneousModule())
|
||||||
|
@ -157,5 +157,5 @@ def main():
|
||||||
report_results(results, set(), verbose=True)
|
report_results(results, set(), verbose=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -51,5 +51,5 @@ def main():
|
||||||
report_results(results, set())
|
report_results(results, set())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -39,5 +39,5 @@ def main():
|
||||||
report_results(results, set(), verbose=True)
|
report_results(results, set(), verbose=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -12,6 +12,7 @@ from torch_mlir_e2e_test.reporting import report_results
|
||||||
from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||||
from torch_mlir_e2e_test.configs import TorchScriptTestConfig
|
from torch_mlir_e2e_test.configs import TorchScriptTestConfig
|
||||||
|
|
||||||
|
|
||||||
class Submodule2(torch.nn.Module):
|
class Submodule2(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -19,6 +20,7 @@ class Submodule2(torch.nn.Module):
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.mm(lhs, rhs)
|
return torch.mm(lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
class Submodule(torch.nn.Module):
|
class Submodule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -43,5 +45,5 @@ def main():
|
||||||
report_results(results, set())
|
report_results(results, set())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
import pdb
|
import pdb
|
||||||
|
|
||||||
# This file implements a pure-Python importer from a restricted subset of
|
# This file implements a pure-Python importer from a restricted subset of
|
||||||
# FX IR into MLIR.
|
# FX IR into MLIR.
|
||||||
#
|
#
|
||||||
|
@ -77,8 +78,12 @@ def _verify_fx_graph_conforms_to_subset(g: torch.fx.Graph):
|
||||||
if len(node.args) != len(node.target._schema.arguments):
|
if len(node.args) != len(node.target._schema.arguments):
|
||||||
assert len(node.args) < len(node.target._schema.arguments)
|
assert len(node.args) < len(node.target._schema.arguments)
|
||||||
for i, argument in enumerate(
|
for i, argument in enumerate(
|
||||||
node.target._schema.arguments[len(node.args):]):
|
node.target._schema.arguments[len(node.args) :]
|
||||||
if not argument.has_default_value() and argument.name not in node.kwargs:
|
):
|
||||||
|
if (
|
||||||
|
not argument.has_default_value()
|
||||||
|
and argument.name not in node.kwargs
|
||||||
|
):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Unsupported: missing default value for argument {i} in schema for {node.target}"
|
f"Unsupported: missing default value for argument {i} in schema for {node.target}"
|
||||||
)
|
)
|
||||||
|
@ -152,12 +157,12 @@ def _convert_dtype_to_mlir_type(dtype: torch.dtype) -> str:
|
||||||
if dtype == torch.complex128:
|
if dtype == torch.complex128:
|
||||||
return "complex<f64>"
|
return "complex<f64>"
|
||||||
|
|
||||||
|
|
||||||
raise Exception(f"Unsupported dtype: {dtype}")
|
raise Exception(f"Unsupported dtype: {dtype}")
|
||||||
|
|
||||||
|
|
||||||
def _import_fake_tensor_as_mlir_type(
|
def _import_fake_tensor_as_mlir_type(
|
||||||
fake_tensor: torch._subclasses.FakeTensor) -> ir.Type:
|
fake_tensor: torch._subclasses.FakeTensor,
|
||||||
|
) -> ir.Type:
|
||||||
# TODO: Find story for how to get dynamically shaped tensors here.
|
# TODO: Find story for how to get dynamically shaped tensors here.
|
||||||
shape = ",".join(str(d) for d in fake_tensor.shape)
|
shape = ",".join(str(d) for d in fake_tensor.shape)
|
||||||
dtype = _convert_dtype_to_mlir_type(fake_tensor.dtype)
|
dtype = _convert_dtype_to_mlir_type(fake_tensor.dtype)
|
||||||
|
@ -178,7 +183,8 @@ def _extract_function_type_from_graph(g: torch.fx.Graph) -> ir.FunctionType:
|
||||||
if node.op == "output":
|
if node.op == "output":
|
||||||
# TODO(DNS): Test this or add verifier that it can't happen.
|
# TODO(DNS): Test this or add verifier that it can't happen.
|
||||||
result_types = torch.fx.map_arg(
|
result_types = torch.fx.map_arg(
|
||||||
node.args[0], lambda n: _mlir_types_for_node(n)[0])
|
node.args[0], lambda n: _mlir_types_for_node(n)[0]
|
||||||
|
)
|
||||||
# Note: We import directly to the backend contract -- multiple results
|
# Note: We import directly to the backend contract -- multiple results
|
||||||
# are modeled with func.func native multiple results rather than as a
|
# are modeled with func.func native multiple results rather than as a
|
||||||
# singleton value / tuple.
|
# singleton value / tuple.
|
||||||
|
@ -191,64 +197,40 @@ def _extract_function_type_from_graph(g: torch.fx.Graph) -> ir.FunctionType:
|
||||||
|
|
||||||
DTYPE_TO_INT = {
|
DTYPE_TO_INT = {
|
||||||
# TODO(DNS): Fill in from AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
|
# TODO(DNS): Fill in from AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
|
||||||
torch.uint8:
|
torch.uint8: 0,
|
||||||
0,
|
torch.int8: 1,
|
||||||
torch.int8:
|
torch.int16: 2,
|
||||||
1,
|
torch.int32: 3,
|
||||||
torch.int16:
|
torch.int64: 4,
|
||||||
2,
|
torch.float16: 5,
|
||||||
torch.int32:
|
torch.float32: 6,
|
||||||
3,
|
torch.float64: 7,
|
||||||
torch.int64:
|
|
||||||
4,
|
|
||||||
torch.float16:
|
|
||||||
5,
|
|
||||||
torch.float32:
|
|
||||||
6,
|
|
||||||
torch.float64:
|
|
||||||
7,
|
|
||||||
# torch.complex_half 8
|
# torch.complex_half 8
|
||||||
torch.complex64:
|
torch.complex64: 9,
|
||||||
9,
|
torch.complex128: 10,
|
||||||
torch.complex128:
|
torch.bool: 11,
|
||||||
10,
|
torch.qint8: 12,
|
||||||
torch.bool:
|
torch.quint8: 13,
|
||||||
11,
|
|
||||||
torch.qint8:
|
|
||||||
12,
|
|
||||||
torch.quint8:
|
|
||||||
13,
|
|
||||||
# torch.qint32 14
|
# torch.qint32 14
|
||||||
torch.bfloat16:
|
torch.bfloat16: 15,
|
||||||
15,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MEMORY_FORMAT_TO_INT = {
|
MEMORY_FORMAT_TO_INT = {
|
||||||
# https://github.com/pytorch/pytorch/c10/core/MemoryFormat.h#L28
|
# https://github.com/pytorch/pytorch/c10/core/MemoryFormat.h#L28
|
||||||
torch.contiguous_format:
|
torch.contiguous_format: 0,
|
||||||
0,
|
torch.preserve_format: 1,
|
||||||
torch.preserve_format:
|
torch.channels_last: 2,
|
||||||
1,
|
torch.channels_last_3d: 3,
|
||||||
torch.channels_last:
|
|
||||||
2,
|
|
||||||
torch.channels_last_3d:
|
|
||||||
3,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LAYOUT_TO_INT = {
|
LAYOUT_TO_INT = {
|
||||||
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_layouts.cpp
|
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_layouts.cpp
|
||||||
torch.strided:
|
torch.strided: 0,
|
||||||
0,
|
torch.sparse_coo: 1,
|
||||||
torch.sparse_coo:
|
torch.sparse_csr: 2,
|
||||||
1,
|
torch.sparse_csc: 3,
|
||||||
torch.sparse_csr:
|
torch.sparse_bsr: 4,
|
||||||
2,
|
torch.sparse_bsc: 5,
|
||||||
torch.sparse_csc:
|
|
||||||
3,
|
|
||||||
torch.sparse_bsr:
|
|
||||||
4,
|
|
||||||
torch.sparse_bsc:
|
|
||||||
5,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -264,7 +246,6 @@ def _mlir_location_for_node(node: torch.fx.Node) -> ir.Location:
|
||||||
|
|
||||||
|
|
||||||
class _FXGraphImporter:
|
class _FXGraphImporter:
|
||||||
|
|
||||||
def __init__(self, g: torch.fx.Graph, func_name: str):
|
def __init__(self, g: torch.fx.Graph, func_name: str):
|
||||||
self._g = g
|
self._g = g
|
||||||
self._func_name = func_name
|
self._func_name = func_name
|
||||||
|
@ -277,7 +258,8 @@ class _FXGraphImporter:
|
||||||
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
|
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
|
||||||
self._module = ir.Module.create(ir.Location.unknown())
|
self._module = ir.Module.create(ir.Location.unknown())
|
||||||
self._module.operation.attributes[
|
self._module.operation.attributes[
|
||||||
"torch.debug_module_name"] = ir.StringAttr.get(func_name)
|
"torch.debug_module_name"
|
||||||
|
] = ir.StringAttr.get(func_name)
|
||||||
function_type = _extract_function_type_from_graph(g)
|
function_type = _extract_function_type_from_graph(g)
|
||||||
func = func_dialect.FuncOp(
|
func = func_dialect.FuncOp(
|
||||||
func_name,
|
func_name,
|
||||||
|
@ -285,8 +267,7 @@ class _FXGraphImporter:
|
||||||
loc=ir.Location.unknown(), # TODO: Can we do better?
|
loc=ir.Location.unknown(), # TODO: Can we do better?
|
||||||
ip=ir.InsertionPoint(self._module.body),
|
ip=ir.InsertionPoint(self._module.body),
|
||||||
)
|
)
|
||||||
self._body_block = ir.Block.create_at_start(func.body,
|
self._body_block = ir.Block.create_at_start(func.body, function_type.inputs)
|
||||||
function_type.inputs)
|
|
||||||
|
|
||||||
def import_graph(self) -> ir.Module:
|
def import_graph(self) -> ir.Module:
|
||||||
with ir.InsertionPoint(self._body_block):
|
with ir.InsertionPoint(self._body_block):
|
||||||
|
@ -294,14 +275,15 @@ class _FXGraphImporter:
|
||||||
for node in self._g.nodes:
|
for node in self._g.nodes:
|
||||||
with _mlir_location_for_node(node):
|
with _mlir_location_for_node(node):
|
||||||
if node.op == "placeholder":
|
if node.op == "placeholder":
|
||||||
self._env[(
|
self._env[(node, 0)] = self._body_block.arguments[
|
||||||
node, 0
|
num_placeholders_seen
|
||||||
)] = self._body_block.arguments[num_placeholders_seen]
|
]
|
||||||
num_placeholders_seen += 1
|
num_placeholders_seen += 1
|
||||||
if node.op == "call_function":
|
if node.op == "call_function":
|
||||||
if node.target is operator.getitem:
|
if node.target is operator.getitem:
|
||||||
self._env[(node, 0)] = self._env[(node.args[0],
|
self._env[(node, 0)] = self._env[
|
||||||
node.args[1])]
|
(node.args[0], node.args[1])
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
self._import_op_overload_call(node)
|
self._import_op_overload_call(node)
|
||||||
if node.op == "output":
|
if node.op == "output":
|
||||||
|
@ -309,9 +291,7 @@ class _FXGraphImporter:
|
||||||
# a tuple of return values (without the single-element special
|
# a tuple of return values (without the single-element special
|
||||||
# case)
|
# case)
|
||||||
# DNS: Test or verify no literals as results.
|
# DNS: Test or verify no literals as results.
|
||||||
operands = [
|
operands = [self._import_argument(arg) for arg in node.args[0]]
|
||||||
self._import_argument(arg) for arg in node.args[0]
|
|
||||||
]
|
|
||||||
func_dialect.ReturnOp(operands)
|
func_dialect.ReturnOp(operands)
|
||||||
return self._module
|
return self._module
|
||||||
|
|
||||||
|
@ -328,7 +308,8 @@ class _FXGraphImporter:
|
||||||
|
|
||||||
# DNS: Unregistered ops
|
# DNS: Unregistered ops
|
||||||
assert ir.Context.current.is_registered_operation(
|
assert ir.Context.current.is_registered_operation(
|
||||||
mlir_op_name), f"Unregistered operation: {mlir_op_name}"
|
mlir_op_name
|
||||||
|
), f"Unregistered operation: {mlir_op_name}"
|
||||||
|
|
||||||
# Construct the Operation.
|
# Construct the Operation.
|
||||||
result_types = _mlir_types_for_node(node)
|
result_types = _mlir_types_for_node(node)
|
||||||
|
@ -352,9 +333,9 @@ class _FXGraphImporter:
|
||||||
for i, value in enumerate(operation.results):
|
for i, value in enumerate(operation.results):
|
||||||
self._env[(node, i)] = value
|
self._env[(node, i)] = value
|
||||||
|
|
||||||
def _import_argument(self,
|
def _import_argument(
|
||||||
arg: torch.fx.node.Argument,
|
self, arg: torch.fx.node.Argument, expected_type_for_literal=None
|
||||||
expected_type_for_literal=None) -> ir.Value:
|
) -> ir.Value:
|
||||||
"""Import an FX `Argument`, which is analogous to an MLIR `Value`.
|
"""Import an FX `Argument`, which is analogous to an MLIR `Value`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -371,22 +352,21 @@ class _FXGraphImporter:
|
||||||
assert expected_type_for_literal is not None
|
assert expected_type_for_literal is not None
|
||||||
return self._import_literal(arg, expected_type_for_literal)
|
return self._import_literal(arg, expected_type_for_literal)
|
||||||
|
|
||||||
def _import_literal(self, arg: torch.fx.node.Argument,
|
def _import_literal(self, arg: torch.fx.node.Argument, expected_type) -> ir.Value:
|
||||||
expected_type) -> ir.Value:
|
|
||||||
if arg is None:
|
if arg is None:
|
||||||
return torch_dialect.ConstantNoneOp().result
|
return torch_dialect.ConstantNoneOp().result
|
||||||
if isinstance(expected_type, torch.OptionalType):
|
if isinstance(expected_type, torch.OptionalType):
|
||||||
return self._import_argument(arg, expected_type.getElementType())
|
return self._import_argument(arg, expected_type.getElementType())
|
||||||
if isinstance(arg, bool):
|
if isinstance(arg, bool):
|
||||||
return torch_dialect.ConstantBoolOp(
|
return torch_dialect.ConstantBoolOp(
|
||||||
ir.IntegerAttr.get(ir.IntegerType.get_signless(1), arg)).result
|
ir.IntegerAttr.get(ir.IntegerType.get_signless(1), arg)
|
||||||
|
).result
|
||||||
if isinstance(arg, int):
|
if isinstance(arg, int):
|
||||||
return torch_dialect.ConstantIntOp(
|
return torch_dialect.ConstantIntOp(
|
||||||
ir.IntegerAttr.get(ir.IntegerType.get_signless(64),
|
ir.IntegerAttr.get(ir.IntegerType.get_signless(64), arg)
|
||||||
arg)).result
|
).result
|
||||||
if isinstance(arg, float):
|
if isinstance(arg, float):
|
||||||
return torch_dialect.ConstantFloatOp(
|
return torch_dialect.ConstantFloatOp(ir.FloatAttr.get_f64(arg)).result
|
||||||
ir.FloatAttr.get_f64(arg)).result
|
|
||||||
if isinstance(arg, str):
|
if isinstance(arg, str):
|
||||||
return torch_dialect.ConstantStrOp(ir.StringAttr.get(arg)).result
|
return torch_dialect.ConstantStrOp(ir.StringAttr.get(arg)).result
|
||||||
if isinstance(arg, torch.dtype):
|
if isinstance(arg, torch.dtype):
|
||||||
|
@ -394,12 +374,10 @@ class _FXGraphImporter:
|
||||||
return self._import_argument(DTYPE_TO_INT[arg], expected_type)
|
return self._import_argument(DTYPE_TO_INT[arg], expected_type)
|
||||||
if isinstance(arg, torch.device):
|
if isinstance(arg, torch.device):
|
||||||
# TODO(DNS): Device index? arg.index
|
# TODO(DNS): Device index? arg.index
|
||||||
return torch_dialect.ConstantDeviceOp(ir.StringAttr.get(
|
return torch_dialect.ConstantDeviceOp(ir.StringAttr.get(arg.type)).result
|
||||||
arg.type)).result
|
|
||||||
if isinstance(arg, torch.memory_format):
|
if isinstance(arg, torch.memory_format):
|
||||||
assert isinstance(expected_type, torch.IntType)
|
assert isinstance(expected_type, torch.IntType)
|
||||||
return self._import_argument(MEMORY_FORMAT_TO_INT[arg],
|
return self._import_argument(MEMORY_FORMAT_TO_INT[arg], expected_type)
|
||||||
expected_type)
|
|
||||||
if isinstance(arg, torch.layout):
|
if isinstance(arg, torch.layout):
|
||||||
assert isinstance(expected_type, torch.IntType)
|
assert isinstance(expected_type, torch.IntType)
|
||||||
return self._import_argument(LAYOUT_TO_INT[arg], expected_type)
|
return self._import_argument(LAYOUT_TO_INT[arg], expected_type)
|
||||||
|
@ -409,14 +387,14 @@ class _FXGraphImporter:
|
||||||
if isinstance(element_type, torch.TensorType):
|
if isinstance(element_type, torch.TensorType):
|
||||||
assert all(
|
assert all(
|
||||||
torch.fx.node.map_aggregate(
|
torch.fx.node.map_aggregate(
|
||||||
arg, lambda a: _is_valid_meta_val(a.meta.get("val"))))
|
arg, lambda a: _is_valid_meta_val(a.meta.get("val"))
|
||||||
|
)
|
||||||
|
)
|
||||||
els = [self._env[e, 0] for e in arg]
|
els = [self._env[e, 0] for e in arg]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
element_type = _torch_type_to_mlir_type(element_type)
|
element_type = _torch_type_to_mlir_type(element_type)
|
||||||
els = [
|
els = [self._import_argument(e, element_type) for e in arg]
|
||||||
self._import_argument(e, element_type) for e in arg
|
|
||||||
]
|
|
||||||
|
|
||||||
# import pydevd_pycharm
|
# import pydevd_pycharm
|
||||||
# pydevd_pycharm.settrace('localhost', port=8888, stdoutToServer=True, stderrToServer=True)
|
# pydevd_pycharm.settrace('localhost', port=8888, stdoutToServer=True, stderrToServer=True)
|
||||||
|
|
|
@ -3,6 +3,5 @@ import torch
|
||||||
|
|
||||||
# Register _torch_mlir_custom_op_example.identity as a side-effect of importing.
|
# Register _torch_mlir_custom_op_example.identity as a side-effect of importing.
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
lib = os.path.join(*[current_dir, 'libtorch_mlir_custom_op_example.so'])
|
lib = os.path.join(*[current_dir, "libtorch_mlir_custom_op_example.so"])
|
||||||
torch.ops.load_library(lib)
|
torch.ops.load_library(lib)
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
from packaging import version
|
from packaging import version
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def torch_version_for_comparison():
|
def torch_version_for_comparison():
|
||||||
# Ignore +cpu, +cu117m, etc. in comparisons
|
# Ignore +cpu, +cu117m, etc. in comparisons
|
||||||
return version.parse(torch.__version__.split("+", 1)[0])
|
return version.parse(torch.__version__.split("+", 1)[0])
|
||||||
|
|
|
@ -4,20 +4,20 @@
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
path = sys.argv[1] # dummy script path
|
path = sys.argv[1] # dummy script path
|
||||||
file_name = sys.argv[2] # dummy script
|
file_name = sys.argv[2] # dummy script
|
||||||
|
|
||||||
contents = '''
|
contents = """
|
||||||
# This file was automatically generated due to LTC being disabled in build.
|
# This file was automatically generated due to LTC being disabled in build.
|
||||||
|
|
||||||
class LazyTensorCoreTestConfig:
|
class LazyTensorCoreTestConfig:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
assert False, "LTC is not enabled. Check the value of `TORCH_MLIR_ENABLE_LTC`"
|
assert False, "LTC is not enabled. Check the value of `TORCH_MLIR_ENABLE_LTC`"
|
||||||
'''
|
"""
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
os.makedirs(path)
|
os.makedirs(path)
|
||||||
|
|
||||||
with open(os.path.join(path, file_name + '.py'), 'w') as file:
|
with open(os.path.join(path, file_name + ".py"), "w") as file:
|
||||||
file.write(contents)
|
file.write(contents)
|
||||||
|
|
|
@ -13,6 +13,7 @@ from torch._dynamo.backends.common import aot_autograd
|
||||||
import functorch
|
import functorch
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/issues/89064
|
# https://github.com/pytorch/pytorch/issues/89064
|
||||||
warnings.filterwarnings("ignore", module="torch.jit._check")
|
warnings.filterwarnings("ignore", module="torch.jit._check")
|
||||||
|
|
||||||
|
@ -91,8 +92,7 @@ def _adjust_calling_convention(gm: torch.fx.GraphModule) -> bool:
|
||||||
did_convert_list_to_tuple = False
|
did_convert_list_to_tuple = False
|
||||||
for node in gm.graph.nodes:
|
for node in gm.graph.nodes:
|
||||||
if node.op == "output":
|
if node.op == "output":
|
||||||
assert len(node.args) == 1, \
|
assert len(node.args) == 1, "Output node must have a single argument"
|
||||||
"Output node must have a single argument"
|
|
||||||
node_arg = node.args[0]
|
node_arg = node.args[0]
|
||||||
if isinstance(node_arg, tuple):
|
if isinstance(node_arg, tuple):
|
||||||
if len(node_arg) == 1:
|
if len(node_arg) == 1:
|
||||||
|
@ -106,7 +106,7 @@ def _adjust_calling_convention(gm: torch.fx.GraphModule) -> bool:
|
||||||
did_convert_list_to_tuple = True
|
did_convert_list_to_tuple = True
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
node.args= (tuple(node_arg),)
|
node.args = (tuple(node_arg),)
|
||||||
did_convert_list_to_tuple = True
|
did_convert_list_to_tuple = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -129,10 +129,12 @@ def make_simple_dynamo_backend(user_backend):
|
||||||
Returns:
|
Returns:
|
||||||
A function with the signature used by TorchDynamo backends.
|
A function with the signature used by TorchDynamo backends.
|
||||||
"""
|
"""
|
||||||
def wrapper_backend(gm: torch.fx.GraphModule,
|
|
||||||
example_inputs: List[torch.Tensor]):
|
def wrapper_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
||||||
did_unwrap_single_element, did_convert_list_to_tuple = \
|
(
|
||||||
_adjust_calling_convention(gm)
|
did_unwrap_single_element,
|
||||||
|
did_convert_list_to_tuple,
|
||||||
|
) = _adjust_calling_convention(gm)
|
||||||
strip_overloads(gm)
|
strip_overloads(gm)
|
||||||
user_callable = user_backend(gm, example_inputs)
|
user_callable = user_backend(gm, example_inputs)
|
||||||
|
|
||||||
|
@ -147,6 +149,9 @@ def make_simple_dynamo_backend(user_backend):
|
||||||
if did_convert_list_to_tuple:
|
if did_convert_list_to_tuple:
|
||||||
result = list(result)
|
result = list(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return dynamo_callable
|
return dynamo_callable
|
||||||
return aot_autograd(fw_compiler=wrapper_backend,
|
|
||||||
decompositions=_get_decomposition_table)
|
return aot_autograd(
|
||||||
|
fw_compiler=wrapper_backend, decompositions=_get_decomposition_table
|
||||||
|
)
|
||||||
|
|
|
@ -15,24 +15,31 @@ from torch_mlir.passmanager import PassManager
|
||||||
|
|
||||||
from .registry import Registry
|
from .registry import Registry
|
||||||
|
|
||||||
|
|
||||||
def all_integer_dtypes() -> List[int]:
|
def all_integer_dtypes() -> List[int]:
|
||||||
return [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
|
return [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
|
||||||
|
|
||||||
|
|
||||||
def is_integer_dtype(dtype: int) -> bool:
|
def is_integer_dtype(dtype: int) -> bool:
|
||||||
return dtype in all_integer_dtypes()
|
return dtype in all_integer_dtypes()
|
||||||
|
|
||||||
|
|
||||||
def all_complex_dtypes() -> List[int]:
|
def all_complex_dtypes() -> List[int]:
|
||||||
return [torch.complex64, torch.complex128]
|
return [torch.complex64, torch.complex128]
|
||||||
|
|
||||||
|
|
||||||
def is_complex_dtype(dtype: int) -> bool:
|
def is_complex_dtype(dtype: int) -> bool:
|
||||||
return dtype in all_complex_dtypes()
|
return dtype in all_complex_dtypes()
|
||||||
|
|
||||||
|
|
||||||
def all_float_dtypes() -> List[int]:
|
def all_float_dtypes() -> List[int]:
|
||||||
return [torch.float16, torch.bfloat16, torch.float32, torch.float64]
|
return [torch.float16, torch.bfloat16, torch.float32, torch.float64]
|
||||||
|
|
||||||
|
|
||||||
def is_float_dtype(dtype: int) -> bool:
|
def is_float_dtype(dtype: int) -> bool:
|
||||||
return dtype in all_float_dtypes()
|
return dtype in all_float_dtypes()
|
||||||
|
|
||||||
|
|
||||||
def get_priority_of_dtype(dtype: int) -> int:
|
def get_priority_of_dtype(dtype: int) -> int:
|
||||||
# If a loop is used to iterate over a list of sorted dtypes, TorchScript
|
# If a loop is used to iterate over a list of sorted dtypes, TorchScript
|
||||||
# produces a loop with INT64_MAX max trip count, which causes problems
|
# produces a loop with INT64_MAX max trip count, which causes problems
|
||||||
|
@ -64,6 +71,7 @@ def get_priority_of_dtype(dtype: int) -> int:
|
||||||
return 11
|
return 11
|
||||||
assert False, "Cannot determine priority of dtype"
|
assert False, "Cannot determine priority of dtype"
|
||||||
|
|
||||||
|
|
||||||
def get_dtype_of_scalar(scalar: Union[int, float, complex]) -> int:
|
def get_dtype_of_scalar(scalar: Union[int, float, complex]) -> int:
|
||||||
# This is hacky. `NumToTensor` is the only PyTorch op for scalars
|
# This is hacky. `NumToTensor` is the only PyTorch op for scalars
|
||||||
# that when `jit.script`ed converts a float scalar to a tensor
|
# that when `jit.script`ed converts a float scalar to a tensor
|
||||||
|
@ -83,6 +91,7 @@ def get_dtype_of_scalar(scalar: Union[int, float, complex]) -> int:
|
||||||
# op.
|
# op.
|
||||||
return torch.ops.prim.NumToTensor(scalar).dtype
|
return torch.ops.prim.NumToTensor(scalar).dtype
|
||||||
|
|
||||||
|
|
||||||
# When we import into torch-mlir, only the calls to
|
# When we import into torch-mlir, only the calls to
|
||||||
# `__torch_mlir_internal_promote_dtypes` are used to generate the
|
# `__torch_mlir_internal_promote_dtypes` are used to generate the
|
||||||
# `torch.promote_dtypes` ops. Therefore, to avoid generating extra
|
# `torch.promote_dtypes` ops. Therefore, to avoid generating extra
|
||||||
|
@ -97,23 +106,29 @@ def _get_scalar_with_dtype(dtype: torch.dtype) -> Union[int, float]:
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unhandled dtype: {dtype}")
|
raise ValueError(f"Unhandled dtype: {dtype}")
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
@torch.jit.ignore
|
||||||
def _promote_scalar_tensor(scalar_dtype: torch.dtype, tensor_rank: int,
|
def _promote_scalar_tensor(
|
||||||
tensor_dtype: torch.dtype) -> torch.dtype:
|
scalar_dtype: torch.dtype, tensor_rank: int, tensor_dtype: torch.dtype
|
||||||
|
) -> torch.dtype:
|
||||||
scalar = _get_scalar_with_dtype(scalar_dtype)
|
scalar = _get_scalar_with_dtype(scalar_dtype)
|
||||||
tensor = torch.rand([1] * tensor_rank).to(tensor_dtype)
|
tensor = torch.rand([1] * tensor_rank).to(tensor_dtype)
|
||||||
return torch.result_type(scalar, tensor)
|
return torch.result_type(scalar, tensor)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
@torch.jit.ignore
|
||||||
def _promote_tensor_tensor(lhs_rank: int, lhs_dtype: torch.dtype,
|
def _promote_tensor_tensor(
|
||||||
rhs_rank: int, rhs_dtype: torch.dtype) -> torch.dtype:
|
lhs_rank: int, lhs_dtype: torch.dtype, rhs_rank: int, rhs_dtype: torch.dtype
|
||||||
|
) -> torch.dtype:
|
||||||
lhs_tensor = torch.rand([1] * lhs_rank).to(lhs_dtype)
|
lhs_tensor = torch.rand([1] * lhs_rank).to(lhs_dtype)
|
||||||
rhs_tensor = torch.rand([1] * rhs_rank).to(rhs_dtype)
|
rhs_tensor = torch.rand([1] * rhs_rank).to(rhs_dtype)
|
||||||
return torch.result_type(lhs_tensor, rhs_tensor)
|
return torch.result_type(lhs_tensor, rhs_tensor)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
@torch.jit.ignore
|
||||||
def _promote_scalar_scalar(lhs_dtype: torch.dtype,
|
def _promote_scalar_scalar(
|
||||||
rhs_dtype: torch.dtype) -> torch.dtype:
|
lhs_dtype: torch.dtype, rhs_dtype: torch.dtype
|
||||||
|
) -> torch.dtype:
|
||||||
# When `torch.result_type` is used on two scalars, the result
|
# When `torch.result_type` is used on two scalars, the result
|
||||||
# dtype is the dtype one would expect for an op with signature
|
# dtype is the dtype one would expect for an op with signature
|
||||||
# (Scalar, Scalar) -> (Tensor). However, once a module gets
|
# (Scalar, Scalar) -> (Tensor). However, once a module gets
|
||||||
|
@ -122,15 +137,17 @@ def _promote_scalar_scalar(lhs_dtype: torch.dtype,
|
||||||
# dtype, we use the tensor-tensor promotion rules.
|
# dtype, we use the tensor-tensor promotion rules.
|
||||||
return _promote_tensor_tensor(0, lhs_dtype, 0, rhs_dtype)
|
return _promote_tensor_tensor(0, lhs_dtype, 0, rhs_dtype)
|
||||||
|
|
||||||
def promote_dtypes(ranks: List[Optional[int]],
|
|
||||||
dtypes: List[torch.dtype]) -> torch.dtype:
|
def promote_dtypes(
|
||||||
"""Apply PyTorch dtype promotion rules and return the result type.
|
ranks: List[Optional[int]], dtypes: List[torch.dtype]
|
||||||
"""
|
) -> torch.dtype:
|
||||||
|
"""Apply PyTorch dtype promotion rules and return the result type."""
|
||||||
return __torch_mlir_internal_promote_dtypes(ranks, dtypes)
|
return __torch_mlir_internal_promote_dtypes(ranks, dtypes)
|
||||||
|
|
||||||
def __torch_mlir_internal_promote_dtypes(ranks: List[Optional[int]],
|
|
||||||
dtypes: List[torch.dtype]
|
def __torch_mlir_internal_promote_dtypes(
|
||||||
) -> torch.dtype:
|
ranks: List[Optional[int]], dtypes: List[torch.dtype]
|
||||||
|
) -> torch.dtype:
|
||||||
"""Apply PyTorch dtype promotion rules and return the result type.
|
"""Apply PyTorch dtype promotion rules and return the result type.
|
||||||
|
|
||||||
This function serves two purposes:
|
This function serves two purposes:
|
||||||
|
@ -145,18 +162,18 @@ def __torch_mlir_internal_promote_dtypes(ranks: List[Optional[int]],
|
||||||
if lhs_optional_rank is None and rhs_optional_rank is None:
|
if lhs_optional_rank is None and rhs_optional_rank is None:
|
||||||
lhs_dtype = _promote_scalar_scalar(lhs_dtype, rhs_dtype)
|
lhs_dtype = _promote_scalar_scalar(lhs_dtype, rhs_dtype)
|
||||||
elif lhs_optional_rank is None and rhs_optional_rank is not None:
|
elif lhs_optional_rank is None and rhs_optional_rank is not None:
|
||||||
lhs_dtype = _promote_scalar_tensor(
|
lhs_dtype = _promote_scalar_tensor(lhs_dtype, rhs_optional_rank, rhs_dtype)
|
||||||
lhs_dtype, rhs_optional_rank, rhs_dtype)
|
|
||||||
lhs_optional_rank = rhs_optional_rank
|
lhs_optional_rank = rhs_optional_rank
|
||||||
elif lhs_optional_rank is not None and rhs_optional_rank is None:
|
elif lhs_optional_rank is not None and rhs_optional_rank is None:
|
||||||
lhs_dtype = _promote_scalar_tensor(
|
lhs_dtype = _promote_scalar_tensor(rhs_dtype, lhs_optional_rank, lhs_dtype)
|
||||||
rhs_dtype, lhs_optional_rank, lhs_dtype)
|
|
||||||
elif lhs_optional_rank is not None and rhs_optional_rank is not None:
|
elif lhs_optional_rank is not None and rhs_optional_rank is not None:
|
||||||
lhs_dtype = _promote_tensor_tensor(
|
lhs_dtype = _promote_tensor_tensor(
|
||||||
lhs_optional_rank, lhs_dtype, rhs_optional_rank, rhs_dtype)
|
lhs_optional_rank, lhs_dtype, rhs_optional_rank, rhs_dtype
|
||||||
|
)
|
||||||
lhs_optional_rank = max(lhs_optional_rank, rhs_optional_rank)
|
lhs_optional_rank = max(lhs_optional_rank, rhs_optional_rank)
|
||||||
return lhs_dtype
|
return lhs_dtype
|
||||||
|
|
||||||
|
|
||||||
def not_present_in_registry(f):
|
def not_present_in_registry(f):
|
||||||
"""Decorator for abstract interpretation functions not present in the registry.
|
"""Decorator for abstract interpretation functions not present in the registry.
|
||||||
|
|
||||||
|
@ -175,6 +192,7 @@ def not_present_in_registry(f):
|
||||||
f._not_present_in_registry = None
|
f._not_present_in_registry = None
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
def _verify_signature_matches_registry(f, registry: Registry):
|
def _verify_signature_matches_registry(f, registry: Registry):
|
||||||
source = inspect.getsource(f)
|
source = inspect.getsource(f)
|
||||||
signature = None
|
signature = None
|
||||||
|
@ -183,7 +201,9 @@ def _verify_signature_matches_registry(f, registry: Registry):
|
||||||
signature = line
|
signature = line
|
||||||
break
|
break
|
||||||
assert signature is not None, f"Could not find signature for {f.__name__}"
|
assert signature is not None, f"Could not find signature for {f.__name__}"
|
||||||
assert "〡" in signature, f"Malformed signature {signature}. Signature missing the character `〡`"
|
assert (
|
||||||
|
"〡" in signature
|
||||||
|
), f"Malformed signature {signature}. Signature missing the character `〡`"
|
||||||
function_name, function_kind = f.__name__.split("〡")
|
function_name, function_kind = f.__name__.split("〡")
|
||||||
atoms = function_name.split("〇")
|
atoms = function_name.split("〇")
|
||||||
if len(atoms) == 2:
|
if len(atoms) == 2:
|
||||||
|
@ -203,7 +223,10 @@ def _verify_signature_matches_registry(f, registry: Registry):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid Op signature function kind: '{function_kind}'")
|
raise ValueError(f"Invalid Op signature function kind: '{function_kind}'")
|
||||||
if signature != expected_signature:
|
if signature != expected_signature:
|
||||||
raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}")
|
raise ValueError(
|
||||||
|
f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_library(functions: Dict[str, Any]) -> str:
|
def generate_library(functions: Dict[str, Any]) -> str:
|
||||||
"""Convert all op functions in `functions` into MLIR."""
|
"""Convert all op functions in `functions` into MLIR."""
|
||||||
|
@ -245,11 +268,13 @@ def generate_library(functions: Dict[str, Any]) -> str:
|
||||||
# the format: `__torch__.{namespace_1}.{namespace_2}...{op_name}`
|
# the format: `__torch__.{namespace_1}.{namespace_2}...{op_name}`
|
||||||
# The extra namespaces are not part of the abstract interpretation
|
# The extra namespaces are not part of the abstract interpretation
|
||||||
# function name, so here we simply drop the extra namespaces.
|
# function name, so here we simply drop the extra namespaces.
|
||||||
namespace = fr"(?:{name}\.)"
|
namespace = rf"(?:{name}\.)"
|
||||||
|
|
||||||
asm = re.sub(fr'@"__torch__\.{namespace}*({name}){circle}({name}){line}({name})"',
|
asm = re.sub(
|
||||||
fr'@"__torch_mlir_\3_fn.\1{circle}\2"',
|
rf'@"__torch__\.{namespace}*({name}){circle}({name}){line}({name})"',
|
||||||
asm)
|
rf'@"__torch_mlir_\3_fn.\1{circle}\2"',
|
||||||
|
asm,
|
||||||
|
)
|
||||||
|
|
||||||
# Put the `〇` back to a regular `.`.
|
# Put the `〇` back to a regular `.`.
|
||||||
asm = asm.replace(codecs.decode(circle, "unicode_escape"), ".")
|
asm = asm.replace(codecs.decode(circle, "unicode_escape"), ".")
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
@ -15,13 +14,17 @@ import difflib
|
||||||
from .utils import TextEmitter
|
from .utils import TextEmitter
|
||||||
|
|
||||||
# Note that this utility exists only in the c-extension.
|
# Note that this utility exists only in the c-extension.
|
||||||
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error
|
from torch_mlir._mlir_libs._jit_ir_importer import (
|
||||||
|
get_registered_ops,
|
||||||
|
) # pytype: disable=import-error
|
||||||
|
|
||||||
|
|
||||||
def _rename_python_keyword_parameter_name(parameter_name: str) -> str:
|
def _rename_python_keyword_parameter_name(parameter_name: str) -> str:
|
||||||
if parameter_name == "from":
|
if parameter_name == "from":
|
||||||
parameter_name = "from_" # Avoid using a Python keyword.
|
parameter_name = "from_" # Avoid using a Python keyword.
|
||||||
return parameter_name
|
return parameter_name
|
||||||
|
|
||||||
|
|
||||||
def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
|
def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
default = ""
|
default = ""
|
||||||
if "default_debug" in arg:
|
if "default_debug" in arg:
|
||||||
|
@ -40,8 +43,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
if default_list == "[]":
|
if default_list == "[]":
|
||||||
default_debug = "()"
|
default_debug = "()"
|
||||||
else:
|
else:
|
||||||
default_debug = default_list.replace(
|
default_debug = default_list.replace("[", "(").replace("]", ",)")
|
||||||
"[", "(").replace("]", ",)")
|
|
||||||
elif arg["pytype"] == "str":
|
elif arg["pytype"] == "str":
|
||||||
default_debug = repr(arg["default_debug"]).replace("'", '"')
|
default_debug = repr(arg["default_debug"]).replace("'", '"')
|
||||||
else:
|
else:
|
||||||
|
@ -49,6 +51,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
default = f" = {default_debug}"
|
default = f" = {default_debug}"
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
def _pytype_to_fn_pytype_common(pytype: str) -> str:
|
def _pytype_to_fn_pytype_common(pytype: str) -> str:
|
||||||
if "number" in pytype:
|
if "number" in pytype:
|
||||||
return pytype.replace("number", "Union[int, float, complex]")
|
return pytype.replace("number", "Union[int, float, complex]")
|
||||||
|
@ -65,6 +68,7 @@ def _pytype_to_fn_pytype_common(pytype: str) -> str:
|
||||||
return "Any"
|
return "Any"
|
||||||
return pytype
|
return pytype
|
||||||
|
|
||||||
|
|
||||||
def _pytype_to_shape_fn_pytype(pytype: str) -> str:
|
def _pytype_to_shape_fn_pytype(pytype: str) -> str:
|
||||||
"""Convert a JitOperator pytype to the type relevant in shape functions.
|
"""Convert a JitOperator pytype to the type relevant in shape functions.
|
||||||
|
|
||||||
|
@ -86,6 +90,7 @@ def _pytype_to_shape_fn_pytype(pytype: str) -> str:
|
||||||
return pytype.replace("Tensor", "List[int]")
|
return pytype.replace("Tensor", "List[int]")
|
||||||
return _pytype_to_fn_pytype_common(pytype)
|
return _pytype_to_fn_pytype_common(pytype)
|
||||||
|
|
||||||
|
|
||||||
def _pytype_to_dtype_fn_pytype(pytype: str) -> str:
|
def _pytype_to_dtype_fn_pytype(pytype: str) -> str:
|
||||||
"""Convert a JitOperator pytype to the type relevant in dtype functions.
|
"""Convert a JitOperator pytype to the type relevant in dtype functions.
|
||||||
|
|
||||||
|
@ -97,13 +102,15 @@ def _pytype_to_dtype_fn_pytype(pytype: str) -> str:
|
||||||
return pytype.replace("Tensor", "Tuple[int, int]")
|
return pytype.replace("Tensor", "Tuple[int, int]")
|
||||||
return _pytype_to_fn_pytype_common(pytype)
|
return _pytype_to_fn_pytype_common(pytype)
|
||||||
|
|
||||||
|
|
||||||
def _pytype_to_decomposition_fn_pytype(pytype: str) -> str:
|
def _pytype_to_decomposition_fn_pytype(pytype: str) -> str:
|
||||||
"""Convert a JitOperator pytype to the type relevant in decomposition functions.
|
"""Convert a JitOperator pytype to the type relevant in decomposition functions."""
|
||||||
"""
|
|
||||||
return _pytype_to_fn_pytype_common(pytype)
|
return _pytype_to_fn_pytype_common(pytype)
|
||||||
|
|
||||||
|
|
||||||
class JitOperator:
|
class JitOperator:
|
||||||
"""Information about a single registered `torch::jit::Operator`"""
|
"""Information about a single registered `torch::jit::Operator`"""
|
||||||
|
|
||||||
def __init__(self, op_info: "OP_INFO_DICT"):
|
def __init__(self, op_info: "OP_INFO_DICT"):
|
||||||
"""Create a JitOperator from the raw OP_INFO_DICT extracted from
|
"""Create a JitOperator from the raw OP_INFO_DICT extracted from
|
||||||
the PyTorch JIT operator registry.
|
the PyTorch JIT operator registry.
|
||||||
|
@ -170,6 +177,7 @@ class JitOperator:
|
||||||
are useful in the repr for cross referencing, and it's useful to have
|
are useful in the repr for cross referencing, and it's useful to have
|
||||||
them in a single point of truth.
|
them in a single point of truth.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def uppercase_first_letter(s):
|
def uppercase_first_letter(s):
|
||||||
if not s:
|
if not s:
|
||||||
return s
|
return s
|
||||||
|
@ -184,15 +192,19 @@ class JitOperator:
|
||||||
for op_name_atom in op_name_atoms:
|
for op_name_atom in op_name_atoms:
|
||||||
for s in op_name_atom.split("_"):
|
for s in op_name_atom.split("_"):
|
||||||
op_class_name_atoms.append(s if s else "_")
|
op_class_name_atoms.append(s if s else "_")
|
||||||
cpp_class_name = "".join(
|
cpp_class_name = (
|
||||||
uppercase_first_letter(s) for s in op_class_name_atoms) + "Op"
|
"".join(uppercase_first_letter(s) for s in op_class_name_atoms) + "Op"
|
||||||
|
)
|
||||||
# Disallow leading underscores in C++ to avoid illegal names.
|
# Disallow leading underscores in C++ to avoid illegal names.
|
||||||
cpp_class_name = cpp_class_name.lstrip("_")
|
cpp_class_name = cpp_class_name.lstrip("_")
|
||||||
return op_name, cpp_class_name
|
return op_name, cpp_class_name
|
||||||
|
|
||||||
def _get_function_signature(self, function_kind: str,
|
def _get_function_signature(
|
||||||
|
self,
|
||||||
|
function_kind: str,
|
||||||
parameter_decl_builder: Callable[["SIG_ATTR_TYPE"], str],
|
parameter_decl_builder: Callable[["SIG_ATTR_TYPE"], str],
|
||||||
ret_decl_builder: Callable[["SIG_ATTR_TYPE"], str]) -> str:
|
ret_decl_builder: Callable[["SIG_ATTR_TYPE"], str],
|
||||||
|
) -> str:
|
||||||
mlir_op_name, _ = self.get_mlir_names()
|
mlir_op_name, _ = self.get_mlir_names()
|
||||||
# Replace `.` with a valid Python identifier character.
|
# Replace `.` with a valid Python identifier character.
|
||||||
# `〇` vaguely looks like `.`.
|
# `〇` vaguely looks like `.`.
|
||||||
|
@ -219,6 +231,7 @@ class JitOperator:
|
||||||
ops have extra default arguments and stuff that are tedious to write out
|
ops have extra default arguments and stuff that are tedious to write out
|
||||||
right.
|
right.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
pytype = _pytype_to_shape_fn_pytype(arg["pytype"])
|
pytype = _pytype_to_shape_fn_pytype(arg["pytype"])
|
||||||
default = _get_default_value(arg)
|
default = _get_default_value(arg)
|
||||||
|
@ -229,7 +242,8 @@ class JitOperator:
|
||||||
return _pytype_to_shape_fn_pytype(arg["pytype"])
|
return _pytype_to_shape_fn_pytype(arg["pytype"])
|
||||||
|
|
||||||
return self._get_function_signature(
|
return self._get_function_signature(
|
||||||
"shape", parameter_decl_builder, ret_decl_builder)
|
"shape", parameter_decl_builder, ret_decl_builder
|
||||||
|
)
|
||||||
|
|
||||||
def get_dtype_function_signature(self):
|
def get_dtype_function_signature(self):
|
||||||
"""Gets the Python function signature for this op's dtype function.
|
"""Gets the Python function signature for this op's dtype function.
|
||||||
|
@ -239,6 +253,7 @@ class JitOperator:
|
||||||
ops have extra default arguments and stuff that are tedious to write out
|
ops have extra default arguments and stuff that are tedious to write out
|
||||||
right.
|
right.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
pytype = _pytype_to_dtype_fn_pytype(arg["pytype"])
|
pytype = _pytype_to_dtype_fn_pytype(arg["pytype"])
|
||||||
default = _get_default_value(arg)
|
default = _get_default_value(arg)
|
||||||
|
@ -257,7 +272,8 @@ class JitOperator:
|
||||||
return _pytype_to_dtype_fn_pytype(arg["pytype"])
|
return _pytype_to_dtype_fn_pytype(arg["pytype"])
|
||||||
|
|
||||||
return self._get_function_signature(
|
return self._get_function_signature(
|
||||||
"dtype", parameter_decl_builder, ret_decl_builder)
|
"dtype", parameter_decl_builder, ret_decl_builder
|
||||||
|
)
|
||||||
|
|
||||||
def get_decomposition_function_signature(self):
|
def get_decomposition_function_signature(self):
|
||||||
"""Gets the Python function signature for this op's decomposition function.
|
"""Gets the Python function signature for this op's decomposition function.
|
||||||
|
@ -267,6 +283,7 @@ class JitOperator:
|
||||||
ops have extra default arguments and stuff that are tedious to write out
|
ops have extra default arguments and stuff that are tedious to write out
|
||||||
right.
|
right.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
pytype = _pytype_to_decomposition_fn_pytype(arg["pytype"])
|
pytype = _pytype_to_decomposition_fn_pytype(arg["pytype"])
|
||||||
default = _get_default_value(arg)
|
default = _get_default_value(arg)
|
||||||
|
@ -277,7 +294,8 @@ class JitOperator:
|
||||||
return _pytype_to_decomposition_fn_pytype(arg["pytype"])
|
return _pytype_to_decomposition_fn_pytype(arg["pytype"])
|
||||||
|
|
||||||
return self._get_function_signature(
|
return self._get_function_signature(
|
||||||
"decomposition", parameter_decl_builder, ret_decl_builder)
|
"decomposition", parameter_decl_builder, ret_decl_builder
|
||||||
|
)
|
||||||
|
|
||||||
def get_has_value_semantics_function_signature(self):
|
def get_has_value_semantics_function_signature(self):
|
||||||
"""Gets the Python function signature for this op's has_value_semantics function.
|
"""Gets the Python function signature for this op's has_value_semantics function.
|
||||||
|
@ -287,6 +305,7 @@ class JitOperator:
|
||||||
ops have extra default arguments and stuff that are tedious to write out
|
ops have extra default arguments and stuff that are tedious to write out
|
||||||
right.
|
right.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@ -294,7 +313,8 @@ class JitOperator:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return self._get_function_signature(
|
return self._get_function_signature(
|
||||||
"has_value_semantics", parameter_decl_builder, ret_decl_builder)
|
"has_value_semantics", parameter_decl_builder, ret_decl_builder
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
f = io.StringIO()
|
f = io.StringIO()
|
||||||
|
@ -318,7 +338,9 @@ class JitOperator:
|
||||||
p(f"is_mutable = {self.is_mutable}")
|
p(f"is_mutable = {self.is_mutable}")
|
||||||
if any(ret["type"] == "Tensor" for ret in self.returns):
|
if any(ret["type"] == "Tensor" for ret in self.returns):
|
||||||
p(f"shape_function_signature = {self.get_shape_function_signature()}")
|
p(f"shape_function_signature = {self.get_shape_function_signature()}")
|
||||||
p(f"decomposition_function_signature = {self.get_decomposition_function_signature()}")
|
p(
|
||||||
|
f"decomposition_function_signature = {self.get_decomposition_function_signature()}"
|
||||||
|
)
|
||||||
if any(ret["type"] in ["Tensor", "Scalar"] for ret in self.returns):
|
if any(ret["type"] in ["Tensor", "Scalar"] for ret in self.returns):
|
||||||
p(f"dtype_function_signature = {self.get_dtype_function_signature()}")
|
p(f"dtype_function_signature = {self.get_dtype_function_signature()}")
|
||||||
|
|
||||||
|
@ -354,7 +376,9 @@ class JitOperator:
|
||||||
# Note that this is different from MLIR's NoSideEffect which is much
|
# Note that this is different from MLIR's NoSideEffect which is much
|
||||||
# stronger (for example, it cannot be applied to ops that might emit errors
|
# stronger (for example, it cannot be applied to ops that might emit errors
|
||||||
# when operand shapes mismatch).
|
# when operand shapes mismatch).
|
||||||
if any("alias_info" in x for x in itertools.chain(self.arguments, self.returns)):
|
if any(
|
||||||
|
"alias_info" in x for x in itertools.chain(self.arguments, self.returns)
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
# It seems the FunctionSchema of "prim::unchecked_cast : (t) -> (t)" has
|
# It seems the FunctionSchema of "prim::unchecked_cast : (t) -> (t)" has
|
||||||
# incorrect alias information. The result can alias with other tensors
|
# incorrect alias information. The result can alias with other tensors
|
||||||
|
@ -363,8 +387,10 @@ class JitOperator:
|
||||||
return False
|
return False
|
||||||
# The `is` operator compares object identity, so it does not have
|
# The `is` operator compares object identity, so it does not have
|
||||||
# value semantics.
|
# value semantics.
|
||||||
if self.unique_key in ("aten::__is__ : (t1, t2) -> (bool)",
|
if self.unique_key in (
|
||||||
"aten::__isnot__ : (t1, t2) -> (bool)"):
|
"aten::__is__ : (t1, t2) -> (bool)",
|
||||||
|
"aten::__isnot__ : (t1, t2) -> (bool)",
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -390,6 +416,7 @@ class JitOperator:
|
||||||
|
|
||||||
class Registry:
|
class Registry:
|
||||||
"""An indexed collection of JitOperators"""
|
"""An indexed collection of JitOperators"""
|
||||||
|
|
||||||
def __init__(self, operators: List[JitOperator]):
|
def __init__(self, operators: List[JitOperator]):
|
||||||
self.by_unique_key = {}
|
self.by_unique_key = {}
|
||||||
self.by_triple = {}
|
self.by_triple = {}
|
||||||
|
@ -434,4 +461,3 @@ SIGLIST_TYPE = List[SIG_ATTR_TYPE]
|
||||||
# - Tuple[str] (e.g. {'name': ('aten::size', 'int')} )
|
# - Tuple[str] (e.g. {'name': ('aten::size', 'int')} )
|
||||||
# - SIGLIST_TYPE (e.g. {'arguments': [...], 'returns': [...]} )
|
# - SIGLIST_TYPE (e.g. {'arguments': [...], 'returns': [...]} )
|
||||||
OP_INFO_DICT = Dict[str, Union[bool, Tuple[str], SIGLIST_TYPE]]
|
OP_INFO_DICT = Dict[str, Union[bool, Tuple[str], SIGLIST_TYPE]]
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,7 @@ from torch import Tensor
|
||||||
# The typical iteration flow is to add invocations to the list and then re-run
|
# The typical iteration flow is to add invocations to the list and then re-run
|
||||||
# `build_tools/update_abstract_interp_lib.sh` to re-run the tests.
|
# `build_tools/update_abstract_interp_lib.sh` to re-run the tests.
|
||||||
|
|
||||||
|
|
||||||
class TensorOfShape:
|
class TensorOfShape:
|
||||||
"""Symbolic placeholder for a tensor argument to an operation.
|
"""Symbolic placeholder for a tensor argument to an operation.
|
||||||
|
|
||||||
|
@ -60,30 +61,40 @@ class TensorOfShape:
|
||||||
This class also tracks a dtype of the tensor, since some ops require a
|
This class also tracks a dtype of the tensor, since some ops require a
|
||||||
specific dtype.
|
specific dtype.
|
||||||
"""
|
"""
|
||||||
def __init__(self, *shape: int, dtype: torch.dtype = torch.float32,
|
|
||||||
device: Optional[torch.device] = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
*shape: int,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
):
|
||||||
self.shape = list(shape)
|
self.shape = list(shape)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = "meta" if device is None else device
|
self.device = "meta" if device is None else device
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
args_str = ", ".join(repr(x) for x in self.shape)
|
args_str = ", ".join(repr(x) for x in self.shape)
|
||||||
return f"TensorOfShape({args_str}, dtype={self.dtype}, device={self.device})"
|
return f"TensorOfShape({args_str}, dtype={self.dtype}, device={self.device})"
|
||||||
|
|
||||||
|
|
||||||
def LongTensorOfShape(*args, **kwargs):
|
def LongTensorOfShape(*args, **kwargs):
|
||||||
"""Helper for indicating a TensorOfShape with integer type."""
|
"""Helper for indicating a TensorOfShape with integer type."""
|
||||||
return TensorOfShape(*args, **kwargs, dtype=torch.long)
|
return TensorOfShape(*args, **kwargs, dtype=torch.long)
|
||||||
|
|
||||||
|
|
||||||
def NonZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None):
|
def NonZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None):
|
||||||
"""Helper for indicating a non-zero dim tensor with custom type."""
|
"""Helper for indicating a non-zero dim tensor with custom type."""
|
||||||
return TensorOfShape(1, dtype=dtype, device=device)
|
return TensorOfShape(1, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
|
||||||
def ZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None):
|
def ZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None):
|
||||||
"""Helper for indicating a zero dim tensor with custom type."""
|
"""Helper for indicating a zero dim tensor with custom type."""
|
||||||
return TensorOfShape(dtype=dtype, device=device)
|
return TensorOfShape(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
|
||||||
def _recursively_transform_tensor_args(
|
def _recursively_transform_tensor_args(
|
||||||
o: Any,
|
o: Any, tensor_transformer: Callable[[TensorOfShape], Any]
|
||||||
tensor_transformer: Callable[[TensorOfShape], Any]) -> Any:
|
) -> Any:
|
||||||
"""Replace `TensorOfShape` with the result of `tensor_transformer`"""
|
"""Replace `TensorOfShape` with the result of `tensor_transformer`"""
|
||||||
if o is None or isinstance(o, (float, int, str)):
|
if o is None or isinstance(o, (float, int, str)):
|
||||||
return o
|
return o
|
||||||
|
@ -92,9 +103,12 @@ def _recursively_transform_tensor_args(
|
||||||
if isinstance(o, list):
|
if isinstance(o, list):
|
||||||
return [_recursively_transform_tensor_args(x, tensor_transformer) for x in o]
|
return [_recursively_transform_tensor_args(x, tensor_transformer) for x in o]
|
||||||
if isinstance(o, tuple):
|
if isinstance(o, tuple):
|
||||||
return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o)
|
return tuple(
|
||||||
|
_recursively_transform_tensor_args(x, tensor_transformer) for x in o
|
||||||
|
)
|
||||||
raise Exception(f"Unhandled type {type(o)}")
|
raise Exception(f"Unhandled type {type(o)}")
|
||||||
|
|
||||||
|
|
||||||
class Invocation:
|
class Invocation:
|
||||||
"""Representation of a single op invocation (i.e. list of args to the op).
|
"""Representation of a single op invocation (i.e. list of args to the op).
|
||||||
|
|
||||||
|
@ -111,6 +125,7 @@ class Invocation:
|
||||||
exception for greater precision when interpreting errors raised during
|
exception for greater precision when interpreting errors raised during
|
||||||
testing.
|
testing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any):
|
def __init__(self, *args: Any, **kwargs: Any):
|
||||||
self.args = list(args)
|
self.args = list(args)
|
||||||
# We assume kwargs don't contain tensors, so they don't need any
|
# We assume kwargs don't contain tensors, so they don't need any
|
||||||
|
@ -134,14 +149,12 @@ class Invocation:
|
||||||
# are ok since they make it a bit easier to write some shape
|
# are ok since they make it a bit easier to write some shape
|
||||||
# functions.
|
# functions.
|
||||||
tensor_transformer = lambda o: list(o.shape)
|
tensor_transformer = lambda o: list(o.shape)
|
||||||
return _recursively_transform_tensor_args(
|
return _recursively_transform_tensor_args(self.args, tensor_transformer)
|
||||||
self.args, tensor_transformer)
|
|
||||||
|
|
||||||
def to_dtype_function_args(self):
|
def to_dtype_function_args(self):
|
||||||
"""Gets positional arguments appropriate for a dtype function."""
|
"""Gets positional arguments appropriate for a dtype function."""
|
||||||
tensor_transformer = lambda o: (len(o.shape), o.dtype)
|
tensor_transformer = lambda o: (len(o.shape), o.dtype)
|
||||||
return _recursively_transform_tensor_args(
|
return _recursively_transform_tensor_args(self.args, tensor_transformer)
|
||||||
self.args, tensor_transformer)
|
|
||||||
|
|
||||||
def to_real_op_args(self):
|
def to_real_op_args(self):
|
||||||
"""Gets positional arguments appropriate for the real op."""
|
"""Gets positional arguments appropriate for the real op."""
|
||||||
|
@ -155,6 +168,7 @@ class Invocation:
|
||||||
kwargs_str = ", " + ", ".join(f"{k}={v}" for k, v in self.kwargs.items())
|
kwargs_str = ", " + ", ".join(f"{k}={v}" for k, v in self.kwargs.items())
|
||||||
return f"Invocation({args_str}{kwargs_str})"
|
return f"Invocation({args_str}{kwargs_str})"
|
||||||
|
|
||||||
|
|
||||||
class ErrorInvocation(Invocation):
|
class ErrorInvocation(Invocation):
|
||||||
"""An Invocation that raises an exception.
|
"""An Invocation that raises an exception.
|
||||||
|
|
||||||
|
@ -165,9 +179,11 @@ class ErrorInvocation(Invocation):
|
||||||
spurioiusly make the two appear to "agree" that an exception needs to be
|
spurioiusly make the two appear to "agree" that an exception needs to be
|
||||||
raised).
|
raised).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def is_expected_to_raise_exception(self) -> bool:
|
def is_expected_to_raise_exception(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _normalize_multiple_results_to_list(t: Any):
|
def _normalize_multiple_results_to_list(t: Any):
|
||||||
"""Returns a flat list of results.
|
"""Returns a flat list of results.
|
||||||
|
|
||||||
|
@ -182,9 +198,13 @@ def _normalize_multiple_results_to_list(t: Any):
|
||||||
return [t]
|
return [t]
|
||||||
raise ValueError(f"Unexpected type {type(t)}")
|
raise ValueError(f"Unexpected type {type(t)}")
|
||||||
|
|
||||||
|
|
||||||
def _report(f, invocation: Invocation, error_message: str):
|
def _report(f, invocation: Invocation, error_message: str):
|
||||||
fn_type = f.__name__.split("〡")[-1]
|
fn_type = f.__name__.split("〡")[-1]
|
||||||
raise ValueError(f"For {fn_type} function {f.__name__!r} with invocation {invocation}: {error_message}")
|
raise ValueError(
|
||||||
|
f"For {fn_type} function {f.__name__!r} with invocation {invocation}: {error_message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_fn_and_golden_results(f, invocation: List[Invocation]):
|
def _get_fn_and_golden_results(f, invocation: List[Invocation]):
|
||||||
"""Run the invocation on the library function and torch op.
|
"""Run the invocation on the library function and torch op.
|
||||||
|
@ -201,36 +221,64 @@ def _get_fn_and_golden_results(f, invocation: List[Invocation]):
|
||||||
op = getattr(getattr(getattr(torch.ops, ns), unqual), overload)
|
op = getattr(getattr(getattr(torch.ops, ns), unqual), overload)
|
||||||
fn_error, op_error, fn_results, golden_results = None, None, None, None
|
fn_error, op_error, fn_results, golden_results = None, None, None, None
|
||||||
try:
|
try:
|
||||||
fn_results = _normalize_multiple_results_to_list(f(
|
fn_results = _normalize_multiple_results_to_list(
|
||||||
|
f(
|
||||||
*(getattr(invocation, f"to_{fn_type}_function_args")()),
|
*(getattr(invocation, f"to_{fn_type}_function_args")()),
|
||||||
**invocation.kwargs))
|
**invocation.kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
fn_error = f"{e}"
|
fn_error = f"{e}"
|
||||||
try:
|
try:
|
||||||
golden_results = _normalize_multiple_results_to_list(op(
|
golden_results = _normalize_multiple_results_to_list(
|
||||||
*invocation.to_real_op_args(),
|
op(*invocation.to_real_op_args(), **invocation.kwargs)
|
||||||
**invocation.kwargs))
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
op_error = f"{e}"
|
op_error = f"{e}"
|
||||||
|
|
||||||
# Check for error behavior.
|
# Check for error behavior.
|
||||||
if invocation.is_expected_to_raise_exception():
|
if invocation.is_expected_to_raise_exception():
|
||||||
if fn_error is None and op_error is None:
|
if fn_error is None and op_error is None:
|
||||||
_report(f, invocation, f"Expected to raise an exception, but neither {fn_type} function or op raised an exception")
|
_report(
|
||||||
|
f,
|
||||||
|
invocation,
|
||||||
|
f"Expected to raise an exception, but neither {fn_type} function or op raised an exception",
|
||||||
|
)
|
||||||
if fn_error is None:
|
if fn_error is None:
|
||||||
_report(f, invocation, f"Op raised error {op_error!r}, but shape/dtype function did not.")
|
_report(
|
||||||
|
f,
|
||||||
|
invocation,
|
||||||
|
f"Op raised error {op_error!r}, but shape/dtype function did not.",
|
||||||
|
)
|
||||||
if op_error is None:
|
if op_error is None:
|
||||||
_report(f, invocation, f"{fn_type} function raised error {fn_error!r}, but op did not.")
|
_report(
|
||||||
|
f,
|
||||||
|
invocation,
|
||||||
|
f"{fn_type} function raised error {fn_error!r}, but op did not.",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if fn_error is not None and op_error is not None:
|
if fn_error is not None and op_error is not None:
|
||||||
_report(f, invocation, f"Both {fn_type} function and op raised errors, but were not expected to. {fn_type} function raised error {fn_error!r} and op raised error {op_error!r}.")
|
_report(
|
||||||
|
f,
|
||||||
|
invocation,
|
||||||
|
f"Both {fn_type} function and op raised errors, but were not expected to. {fn_type} function raised error {fn_error!r} and op raised error {op_error!r}.",
|
||||||
|
)
|
||||||
if fn_error is not None:
|
if fn_error is not None:
|
||||||
_report(f, invocation, f"{fn_type} function raised error {fn_error!r} but op did not raise any error.")
|
_report(
|
||||||
|
f,
|
||||||
|
invocation,
|
||||||
|
f"{fn_type} function raised error {fn_error!r} but op did not raise any error.",
|
||||||
|
)
|
||||||
if op_error is not None:
|
if op_error is not None:
|
||||||
_report(f, invocation, f"Op raised error {op_error!r} but {fn_type} function did not raise any error.")
|
_report(
|
||||||
|
f,
|
||||||
|
invocation,
|
||||||
|
f"Op raised error {op_error!r} but {fn_type} function did not raise any error.",
|
||||||
|
)
|
||||||
|
|
||||||
return fn_results, golden_results
|
return fn_results, golden_results
|
||||||
|
|
||||||
|
|
||||||
def check_shape_function(invocations: List[Invocation]):
|
def check_shape_function(invocations: List[Invocation]):
|
||||||
"""Decorator that automatically tests a shape function.
|
"""Decorator that automatically tests a shape function.
|
||||||
|
|
||||||
|
@ -238,6 +286,7 @@ def check_shape_function(invocations: List[Invocation]):
|
||||||
`〇` instead of `.`, is tested against the corresponding op in
|
`〇` instead of `.`, is tested against the corresponding op in
|
||||||
`torch.ops.*` function using the given invocations.
|
`torch.ops.*` function using the given invocations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(f):
|
def decorator(f):
|
||||||
for invocation in invocations:
|
for invocation in invocations:
|
||||||
result_shapes, golden_results = _get_fn_and_golden_results(f, invocation)
|
result_shapes, golden_results = _get_fn_and_golden_results(f, invocation)
|
||||||
|
@ -245,18 +294,34 @@ def check_shape_function(invocations: List[Invocation]):
|
||||||
continue
|
continue
|
||||||
# Check for matching results.
|
# Check for matching results.
|
||||||
if len(result_shapes) != len(golden_results):
|
if len(result_shapes) != len(golden_results):
|
||||||
_report(f, invocation, f"Expected {len(golden_results)} result shapes, got {len(result_shapes)}")
|
_report(
|
||||||
|
f,
|
||||||
|
invocation,
|
||||||
|
f"Expected {len(golden_results)} result shapes, got {len(result_shapes)}",
|
||||||
|
)
|
||||||
for result_shape, golden_result in zip(result_shapes, golden_results):
|
for result_shape, golden_result in zip(result_shapes, golden_results):
|
||||||
result_rank = len(result_shape)
|
result_rank = len(result_shape)
|
||||||
golden_rank = len(golden_result.shape)
|
golden_rank = len(golden_result.shape)
|
||||||
if result_rank != golden_rank:
|
if result_rank != golden_rank:
|
||||||
_report(f, invocation, f"Expected result rank {golden_rank}, got {result_rank}")
|
_report(
|
||||||
for dimension_size, golden_dimension_size in zip(result_shape, golden_result.shape):
|
f,
|
||||||
|
invocation,
|
||||||
|
f"Expected result rank {golden_rank}, got {result_rank}",
|
||||||
|
)
|
||||||
|
for dimension_size, golden_dimension_size in zip(
|
||||||
|
result_shape, golden_result.shape
|
||||||
|
):
|
||||||
if dimension_size != golden_dimension_size:
|
if dimension_size != golden_dimension_size:
|
||||||
_report(f, invocation, f"Expected result shape {golden_result.shape}, got {result_shape}")
|
_report(
|
||||||
|
f,
|
||||||
|
invocation,
|
||||||
|
f"Expected result shape {golden_result.shape}, got {result_shape}",
|
||||||
|
)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def _convert_dtype_to_int(dtype: torch.dtype) -> int:
|
def _convert_dtype_to_int(dtype: torch.dtype) -> int:
|
||||||
"""Convert a PyTorch `dtype` into its underlying `int` representation.
|
"""Convert a PyTorch `dtype` into its underlying `int` representation.
|
||||||
|
@ -266,6 +331,7 @@ def _convert_dtype_to_int(dtype: torch.dtype) -> int:
|
||||||
"""
|
"""
|
||||||
return dtype
|
return dtype
|
||||||
|
|
||||||
|
|
||||||
def check_dtype_function(invocations: List[Invocation]):
|
def check_dtype_function(invocations: List[Invocation]):
|
||||||
"""Decorator that automatically tests a dtype function.
|
"""Decorator that automatically tests a dtype function.
|
||||||
|
|
||||||
|
@ -273,6 +339,7 @@ def check_dtype_function(invocations: List[Invocation]):
|
||||||
`〇` instead of `.`, is tested against the corresponding op in
|
`〇` instead of `.`, is tested against the corresponding op in
|
||||||
`torch.ops.*` function using the given invocations.
|
`torch.ops.*` function using the given invocations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(f):
|
def decorator(f):
|
||||||
for invocation in invocations:
|
for invocation in invocations:
|
||||||
result_dtypes, golden_results = _get_fn_and_golden_results(f, invocation)
|
result_dtypes, golden_results = _get_fn_and_golden_results(f, invocation)
|
||||||
|
@ -280,7 +347,11 @@ def check_dtype_function(invocations: List[Invocation]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(result_dtypes) != len(golden_results):
|
if len(result_dtypes) != len(golden_results):
|
||||||
_report(f, invocation, f"Expected {len(golden_results)} result dtypes, got {len(result_dtypes)}")
|
_report(
|
||||||
|
f,
|
||||||
|
invocation,
|
||||||
|
f"Expected {len(golden_results)} result dtypes, got {len(result_dtypes)}",
|
||||||
|
)
|
||||||
for result_dtype, golden_result in zip(result_dtypes, golden_results):
|
for result_dtype, golden_result in zip(result_dtypes, golden_results):
|
||||||
if isinstance(golden_result, torch.Tensor):
|
if isinstance(golden_result, torch.Tensor):
|
||||||
golden_dtype = golden_result.dtype
|
golden_dtype = golden_result.dtype
|
||||||
|
@ -294,7 +365,14 @@ def check_dtype_function(invocations: List[Invocation]):
|
||||||
# support returning the default `int` value, the comparisons of
|
# support returning the default `int` value, the comparisons of
|
||||||
# the result and golden dtypes are done using their underlying
|
# the result and golden dtypes are done using their underlying
|
||||||
# `int` representation.
|
# `int` representation.
|
||||||
if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int(golden_dtype):
|
if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int(
|
||||||
_report(f, invocation, f"Expected result dtype {golden_dtype}, got {result_dtype}")
|
golden_dtype
|
||||||
|
):
|
||||||
|
_report(
|
||||||
|
f,
|
||||||
|
invocation,
|
||||||
|
f"Expected result dtype {golden_dtype}, got {result_dtype}",
|
||||||
|
)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
|
@ -71,18 +71,23 @@ def get_ods_type(type: str, non_value: bool, *, is_result: bool = False):
|
||||||
if type.startswith("Dict("):
|
if type.startswith("Dict("):
|
||||||
type = "Dict"
|
type = "Dict"
|
||||||
if non_value:
|
if non_value:
|
||||||
ods_type = TORCH_NON_VALUE_TYPE_TO_ODS_TYPE.get(type) or TORCH_TYPE_TO_ODS_TYPE.get(type)
|
ods_type = TORCH_NON_VALUE_TYPE_TO_ODS_TYPE.get(
|
||||||
|
type
|
||||||
|
) or TORCH_TYPE_TO_ODS_TYPE.get(type)
|
||||||
else:
|
else:
|
||||||
ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type)
|
ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type)
|
||||||
if ods_type is None:
|
if ods_type is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"{type!r} not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!")
|
f"{type!r} not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!"
|
||||||
|
)
|
||||||
return ods_type
|
return ods_type
|
||||||
|
|
||||||
|
|
||||||
def _name_thunk() -> None:
|
def _name_thunk() -> None:
|
||||||
# Strictly exists for _get_main_module_name to harvest its __module__.
|
# Strictly exists for _get_main_module_name to harvest its __module__.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _get_main_module_name() -> str:
|
def _get_main_module_name() -> str:
|
||||||
# If a Python module is loaded interactively or as part of a module
|
# If a Python module is loaded interactively or as part of a module
|
||||||
# directory, it uses a BuiltinImporter. If loaded from a file, it uses
|
# directory, it uses a BuiltinImporter. If loaded from a file, it uses
|
||||||
|
@ -93,6 +98,7 @@ def _get_main_module_name() -> str:
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return _name_thunk.__module__
|
return _name_thunk.__module__
|
||||||
|
|
||||||
|
|
||||||
ODS_BANNER = f"""//===-------------------------------------------------------*- tablegen -*-===//
|
ODS_BANNER = f"""//===-------------------------------------------------------*- tablegen -*-===//
|
||||||
//
|
//
|
||||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
@ -117,10 +123,15 @@ ODS_BANNER = f"""//===-------------------------------------------------------*-
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def raw_emit_op(operator: JitOperator,
|
def raw_emit_op(
|
||||||
|
operator: JitOperator,
|
||||||
emitter_td: TextEmitter,
|
emitter_td: TextEmitter,
|
||||||
*, traits: List[str],
|
*,
|
||||||
has_folder: bool, has_canonicalizer: bool, has_verifier: bool):
|
traits: List[str],
|
||||||
|
has_folder: bool,
|
||||||
|
has_canonicalizer: bool,
|
||||||
|
has_verifier: bool,
|
||||||
|
):
|
||||||
"""Emit the ODS for a JitOperator to a textual file.
|
"""Emit the ODS for a JitOperator to a textual file.
|
||||||
|
|
||||||
This is the lowest level of emission and is responsible for low-level
|
This is the lowest level of emission and is responsible for low-level
|
||||||
|
@ -138,8 +149,7 @@ def raw_emit_op(operator: JitOperator,
|
||||||
def generic_result_name(i):
|
def generic_result_name(i):
|
||||||
return "result" + (str(i) if multiple_results else "")
|
return "result" + (str(i) if multiple_results else "")
|
||||||
|
|
||||||
p_td(
|
p_td(f"def Torch_{cpp_class_name} : Torch_Op<{emitter_td.quote(op_name)}, [")
|
||||||
f"def Torch_{cpp_class_name} : Torch_Op<{emitter_td.quote(op_name)}, [")
|
|
||||||
with emitter_td.indent():
|
with emitter_td.indent():
|
||||||
with emitter_td.indent():
|
with emitter_td.indent():
|
||||||
p_td(",\n".join(traits))
|
p_td(",\n".join(traits))
|
||||||
|
@ -153,20 +163,28 @@ def raw_emit_op(operator: JitOperator,
|
||||||
if operator.is_vararg:
|
if operator.is_vararg:
|
||||||
p_td("Variadic<AnyTorchType>:$operands")
|
p_td("Variadic<AnyTorchType>:$operands")
|
||||||
else:
|
else:
|
||||||
p_td(",\n".join([
|
p_td(
|
||||||
|
",\n".join(
|
||||||
|
[
|
||||||
f"""{get_ods_type(arg["type"], is_non_value_op)}:${arg["name"]}"""
|
f"""{get_ods_type(arg["type"], is_non_value_op)}:${arg["name"]}"""
|
||||||
for arg in operator.arguments
|
for arg in operator.arguments
|
||||||
]))
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
p_td(");")
|
p_td(");")
|
||||||
p_td(f"let results = (outs")
|
p_td(f"let results = (outs")
|
||||||
with emitter_td.indent():
|
with emitter_td.indent():
|
||||||
if operator.is_varret:
|
if operator.is_varret:
|
||||||
p_td("Variadic<AnyTorchType>:$results")
|
p_td("Variadic<AnyTorchType>:$results")
|
||||||
else:
|
else:
|
||||||
p_td(",\n".join([
|
p_td(
|
||||||
|
",\n".join(
|
||||||
|
[
|
||||||
f"""{get_ods_type(ret["type"], is_non_value_op, is_result=True)}:${ret["name"] or generic_result_name(e)}"""
|
f"""{get_ods_type(ret["type"], is_non_value_op, is_result=True)}:${ret["name"] or generic_result_name(e)}"""
|
||||||
for e, ret in enumerate(operator.returns)
|
for e, ret in enumerate(operator.returns)
|
||||||
]))
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
p_td(");")
|
p_td(");")
|
||||||
|
|
||||||
if operator.is_vararg or operator.is_varret:
|
if operator.is_vararg or operator.is_varret:
|
||||||
|
@ -174,16 +192,19 @@ def raw_emit_op(operator: JitOperator,
|
||||||
assembly_operands = "`(` $operands `)`"
|
assembly_operands = "`(` $operands `)`"
|
||||||
assembly_operand_types = "qualified(type($operands))"
|
assembly_operand_types = "qualified(type($operands))"
|
||||||
else:
|
else:
|
||||||
assembly_operands = " `,` ".join("$" + arg["name"]
|
assembly_operands = " `,` ".join(
|
||||||
for arg in operator.arguments)
|
"$" + arg["name"] for arg in operator.arguments
|
||||||
|
)
|
||||||
assembly_operand_types = " `,` ".join(
|
assembly_operand_types = " `,` ".join(
|
||||||
f"""qualified(type(${arg["name"]}))""" for arg in operator.arguments)
|
f"""qualified(type(${arg["name"]}))""" for arg in operator.arguments
|
||||||
|
)
|
||||||
if operator.is_varret:
|
if operator.is_varret:
|
||||||
assembly_result_types = "qualified(type($results))"
|
assembly_result_types = "qualified(type($results))"
|
||||||
else:
|
else:
|
||||||
assembly_result_types = " `,` ".join(
|
assembly_result_types = " `,` ".join(
|
||||||
f"""qualified(type(${ret["name"] or generic_result_name(e)}))"""
|
f"""qualified(type(${ret["name"] or generic_result_name(e)}))"""
|
||||||
for e, ret in enumerate(operator.returns))
|
for e, ret in enumerate(operator.returns)
|
||||||
|
)
|
||||||
if assembly_operand_types and assembly_result_types:
|
if assembly_operand_types and assembly_result_types:
|
||||||
maybe_arrow = " `->` "
|
maybe_arrow = " `->` "
|
||||||
else:
|
else:
|
||||||
|
@ -192,7 +213,8 @@ def raw_emit_op(operator: JitOperator,
|
||||||
p_td(f"let assemblyFormat = {emitter_td.quote(assembly_format)};")
|
p_td(f"let assemblyFormat = {emitter_td.quote(assembly_format)};")
|
||||||
else:
|
else:
|
||||||
p_td(f"let hasCustomAssemblyFormat = 1;")
|
p_td(f"let hasCustomAssemblyFormat = 1;")
|
||||||
p_td(f"""let extraClassDefinition = [{{
|
p_td(
|
||||||
|
f"""let extraClassDefinition = [{{
|
||||||
ParseResult {cpp_class_name}::parse(OpAsmParser &parser, OperationState &result) {{
|
ParseResult {cpp_class_name}::parse(OpAsmParser &parser, OperationState &result) {{
|
||||||
return parseDefaultTorchOp(parser, result, {len(operator.arguments)}, {len(operator.returns)});
|
return parseDefaultTorchOp(parser, result, {len(operator.arguments)}, {len(operator.returns)});
|
||||||
}}
|
}}
|
||||||
|
@ -200,7 +222,8 @@ def raw_emit_op(operator: JitOperator,
|
||||||
printDefaultTorchOp(printer, *this, {len(operator.arguments)}, {len(operator.returns)});
|
printDefaultTorchOp(printer, *this, {len(operator.arguments)}, {len(operator.returns)});
|
||||||
}}
|
}}
|
||||||
}}];
|
}}];
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
if has_folder:
|
if has_folder:
|
||||||
p_td("let hasFolder = 1;")
|
p_td("let hasFolder = 1;")
|
||||||
if has_canonicalizer:
|
if has_canonicalizer:
|
||||||
|
@ -211,13 +234,15 @@ def raw_emit_op(operator: JitOperator,
|
||||||
p_td("\n")
|
p_td("\n")
|
||||||
|
|
||||||
|
|
||||||
def emit_op(operator: JitOperator,
|
def emit_op(
|
||||||
|
operator: JitOperator,
|
||||||
emitter_td: TextEmitter,
|
emitter_td: TextEmitter,
|
||||||
*,
|
*,
|
||||||
traits: Optional[List[str]] = None,
|
traits: Optional[List[str]] = None,
|
||||||
has_folder: bool = False,
|
has_folder: bool = False,
|
||||||
has_canonicalizer: bool = False,
|
has_canonicalizer: bool = False,
|
||||||
has_verifier: bool = False):
|
has_verifier: bool = False,
|
||||||
|
):
|
||||||
"""Main entry point for op emission.
|
"""Main entry point for op emission.
|
||||||
|
|
||||||
Besides emitting the op, it deduces / adds traits based on the operator
|
Besides emitting the op, it deduces / adds traits based on the operator
|
||||||
|
@ -233,12 +258,14 @@ def emit_op(operator: JitOperator,
|
||||||
if operator.is_readonly():
|
if operator.is_readonly():
|
||||||
traits += ["ReadOnly"]
|
traits += ["ReadOnly"]
|
||||||
|
|
||||||
raw_emit_op(operator,
|
raw_emit_op(
|
||||||
|
operator,
|
||||||
emitter_td,
|
emitter_td,
|
||||||
traits=traits,
|
traits=traits,
|
||||||
has_folder=has_folder,
|
has_folder=has_folder,
|
||||||
has_canonicalizer=has_canonicalizer,
|
has_canonicalizer=has_canonicalizer,
|
||||||
has_verifier=has_verifier)
|
has_verifier=has_verifier,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
|
@ -253,9 +280,15 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
ns, unqual, overload = operator.triple
|
ns, unqual, overload = operator.triple
|
||||||
# Underscore variant of functional ops should have "functional" part removed.
|
# Underscore variant of functional ops should have "functional" part removed.
|
||||||
is_functional_op = overload == "functional"
|
is_functional_op = overload == "functional"
|
||||||
emit_op(registry.get_by_triple((ns, unqual + "_", overload if not is_functional_op else "")),
|
emit_op(
|
||||||
|
registry.get_by_triple(
|
||||||
|
(ns, unqual + "_", overload if not is_functional_op else "")
|
||||||
|
),
|
||||||
emitter_td,
|
emitter_td,
|
||||||
traits=["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else [])
|
traits=["IsTrailingUnderscoreInplaceVariant"]
|
||||||
|
if not is_functional_op
|
||||||
|
else [],
|
||||||
|
)
|
||||||
|
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
# `aten::` namespace.
|
# `aten::` namespace.
|
||||||
|
@ -332,45 +365,105 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::square : (Tensor) -> (Tensor)",
|
"aten::square : (Tensor) -> (Tensor)",
|
||||||
"aten::zero : (Tensor) -> (Tensor)",
|
"aten::zero : (Tensor) -> (Tensor)",
|
||||||
"aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)",
|
"aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||||
"aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)"
|
"aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
]:
|
]:
|
||||||
emit_with_mutating_variants(key)
|
emit_with_mutating_variants(key)
|
||||||
# Shape manipulations:
|
# Shape manipulations:
|
||||||
emit_with_mutating_variants("aten::unsqueeze : (Tensor, int) -> (Tensor)", has_folder=True)
|
emit_with_mutating_variants(
|
||||||
|
"aten::unsqueeze : (Tensor, int) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
|
|
||||||
# Elementwise tensor compute ops that don't have the standard mutating
|
# Elementwise tensor compute ops that don't have the standard mutating
|
||||||
# variants.
|
# variants.
|
||||||
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
|
emit_with_mutating_variants(
|
||||||
emit_with_mutating_variants("aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)", has_canonicalizer=True)
|
"aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)",
|
||||||
emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
has_canonicalizer=True,
|
||||||
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
)
|
||||||
emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
emit_with_mutating_variants(
|
||||||
emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
"aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)",
|
||||||
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
has_canonicalizer=True,
|
||||||
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
)
|
||||||
emit_with_mutating_variants("aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True)
|
emit_with_mutating_variants(
|
||||||
emit_with_mutating_variants("aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
"aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
emit_with_mutating_variants("aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
has_canonicalizer=True,
|
||||||
emit_with_mutating_variants("aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
has_folder=True,
|
||||||
emit_with_mutating_variants("aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
)
|
||||||
emit_with_mutating_variants("aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
emit_with_mutating_variants(
|
||||||
emit_with_mutating_variants("aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
"aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
|
||||||
|
has_canonicalizer=True,
|
||||||
|
has_folder=True,
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
|
||||||
|
has_canonicalizer=True,
|
||||||
|
has_folder=True,
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||||
|
has_canonicalizer=True,
|
||||||
|
has_folder=True,
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||||
|
has_canonicalizer=True,
|
||||||
|
has_folder=True,
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||||
|
has_canonicalizer=True,
|
||||||
|
has_folder=True,
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
emit_with_mutating_variants("aten::log : (Tensor) -> (Tensor)", has_folder=True)
|
emit_with_mutating_variants("aten::log : (Tensor) -> (Tensor)", has_folder=True)
|
||||||
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
|
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
|
||||||
emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
|
emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
|
||||||
emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
|
emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
|
||||||
emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
|
emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
|
||||||
emit_with_mutating_variants("aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True)
|
emit_with_mutating_variants(
|
||||||
emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
|
"aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
|
||||||
|
has_canonicalizer=True,
|
||||||
|
)
|
||||||
|
|
||||||
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
emit_with_mutating_variants(
|
||||||
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
"aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)"
|
||||||
emit("aten::fake_quantize_per_tensor_affine : (Tensor, float, int, int, int) -> (Tensor)")
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::fake_quantize_per_tensor_affine : (Tensor, float, int, int, int) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::mish : (Tensor) -> (Tensor)")
|
emit("aten::mish : (Tensor) -> (Tensor)")
|
||||||
emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
|
emit(
|
||||||
|
"aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||||
|
has_canonicalizer=True,
|
||||||
|
)
|
||||||
emit("aten::gelu : (Tensor, str) -> (Tensor)")
|
emit("aten::gelu : (Tensor, str) -> (Tensor)")
|
||||||
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||||
|
@ -392,7 +485,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::split_with_sizes_copy : (Tensor, int[], int) -> (Tensor[])")
|
emit("aten::split_with_sizes_copy : (Tensor, int[], int) -> (Tensor[])")
|
||||||
|
|
||||||
# Random number generation
|
# Random number generation
|
||||||
emit_with_mutating_variants("aten::uniform : (Tensor, float, float, Generator?) -> (Tensor)")
|
emit_with_mutating_variants(
|
||||||
|
"aten::uniform : (Tensor, float, float, Generator?) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
emit("aten::rand : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::rand : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
|
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
|
||||||
|
@ -400,11 +495,17 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)")
|
emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)")
|
||||||
emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)")
|
emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)")
|
||||||
emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)")
|
emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)")
|
||||||
emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit(
|
||||||
|
"aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
|
emit_with_mutating_variants(
|
||||||
|
"aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)")
|
emit(
|
||||||
|
"aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
emit("aten::random : (Tensor, Generator?) -> (Tensor)")
|
emit("aten::random : (Tensor, Generator?) -> (Tensor)")
|
||||||
emit("aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)")
|
emit("aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)")
|
||||||
|
@ -412,10 +513,14 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
|
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
|
||||||
emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
|
emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
|
||||||
emit_with_mutating_variants(
|
emit_with_mutating_variants(
|
||||||
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
|
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)"
|
||||||
|
)
|
||||||
emit_with_mutating_variants(
|
emit_with_mutating_variants(
|
||||||
"aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)")
|
"aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)"
|
||||||
emit("aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)")
|
)
|
||||||
|
emit(
|
||||||
|
"aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)"
|
||||||
|
)
|
||||||
|
|
||||||
# Non-elementwise tensor compute ops
|
# Non-elementwise tensor compute ops
|
||||||
emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
|
emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
|
||||||
|
@ -433,16 +538,32 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit(
|
emit(
|
||||||
"aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
|
"aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
|
||||||
)
|
)
|
||||||
emit("aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
|
emit(
|
||||||
emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
|
"aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)"
|
||||||
emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
|
)
|
||||||
|
emit(
|
||||||
|
"aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::conv_tbc : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
emit("aten::conv_tbc : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
||||||
emit("aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)")
|
emit(
|
||||||
emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
"aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)"
|
||||||
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
|
)
|
||||||
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
|
emit(
|
||||||
|
"aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"),
|
emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"),
|
||||||
emit("aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
|
emit(
|
||||||
|
"aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)"
|
||||||
|
)
|
||||||
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
|
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
|
||||||
emit(
|
emit(
|
||||||
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
|
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
|
||||||
|
@ -456,43 +577,33 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit(
|
emit(
|
||||||
"aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
|
"aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
|
||||||
)
|
)
|
||||||
emit(
|
emit("aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)")
|
||||||
'aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)'
|
|
||||||
)
|
|
||||||
emit(
|
emit(
|
||||||
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
|
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
|
||||||
)
|
)
|
||||||
emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
|
emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
|
||||||
emit(
|
emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)")
|
||||||
"aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)"
|
|
||||||
)
|
|
||||||
emit(
|
emit(
|
||||||
"aten::normal_functional : (Tensor, float, float, Generator?) -> (Tensor)",
|
"aten::normal_functional : (Tensor, float, float, Generator?) -> (Tensor)",
|
||||||
)
|
)
|
||||||
emit(
|
emit(
|
||||||
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
||||||
)
|
)
|
||||||
emit(
|
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||||
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
|
|
||||||
)
|
|
||||||
emit(
|
emit(
|
||||||
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
||||||
)
|
)
|
||||||
emit(
|
emit(
|
||||||
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
||||||
)
|
)
|
||||||
emit(
|
emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||||
"aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
|
|
||||||
)
|
|
||||||
emit(
|
emit(
|
||||||
"aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
"aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
||||||
)
|
)
|
||||||
emit(
|
emit(
|
||||||
"aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
"aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
||||||
)
|
)
|
||||||
emit(
|
emit("aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)")
|
||||||
"aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)"
|
|
||||||
)
|
|
||||||
emit(
|
emit(
|
||||||
"aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
|
"aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
|
||||||
)
|
)
|
||||||
|
@ -505,18 +616,18 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit(
|
emit(
|
||||||
"aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
|
"aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
|
||||||
)
|
)
|
||||||
emit(
|
emit("aten::softmax.int : (Tensor, int, int?) -> (Tensor)")
|
||||||
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
|
emit("aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)")
|
||||||
|
emit("aten::_log_softmax : (Tensor, int, bool) -> (Tensor)")
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)"
|
||||||
)
|
)
|
||||||
emit(
|
emit_with_mutating_variants(
|
||||||
"aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)"
|
"aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)"
|
||||||
)
|
)
|
||||||
emit(
|
emit_with_mutating_variants(
|
||||||
"aten::_log_softmax : (Tensor, int, bool) -> (Tensor)"
|
"aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)"
|
||||||
)
|
)
|
||||||
emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
|
||||||
emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
|
|
||||||
emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)")
|
|
||||||
emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
|
emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||||
|
@ -549,13 +660,23 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::var : (Tensor, bool) -> (Tensor)")
|
emit("aten::var : (Tensor, bool) -> (Tensor)")
|
||||||
emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
|
emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
|
||||||
emit("aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)")
|
emit("aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)")
|
||||||
emit("aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)")
|
emit(
|
||||||
|
"aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)"
|
||||||
|
)
|
||||||
emit("aten::var_mean : (Tensor, bool) -> (Tensor, Tensor)")
|
emit("aten::var_mean : (Tensor, bool) -> (Tensor, Tensor)")
|
||||||
emit("aten::var_mean.dim : (Tensor, int[]?, bool, bool) -> (Tensor, Tensor)")
|
emit("aten::var_mean.dim : (Tensor, int[]?, bool, bool) -> (Tensor, Tensor)")
|
||||||
emit("aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
emit(
|
||||||
emit("aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
|
"aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)"
|
||||||
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
)
|
||||||
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
|
emit(
|
||||||
|
"aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
|
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
|
||||||
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
|
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
|
||||||
emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
|
emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
|
||||||
|
@ -563,17 +684,25 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
|
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
|
||||||
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
|
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
|
||||||
emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
||||||
emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)")
|
emit(
|
||||||
emit("aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)")
|
"aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::nonzero : (Tensor) -> (Tensor)")
|
emit("aten::nonzero : (Tensor) -> (Tensor)")
|
||||||
emit("aten::nonzero_numpy : (Tensor) -> (Tensor[])")
|
emit("aten::nonzero_numpy : (Tensor) -> (Tensor[])")
|
||||||
emit("aten::nonzero_static : (Tensor, int, int) -> (Tensor)")
|
emit("aten::nonzero_static : (Tensor, int, int) -> (Tensor)")
|
||||||
emit("aten::binary_cross_entropy : (Tensor, Tensor, Tensor?, int) -> (Tensor)")
|
emit("aten::binary_cross_entropy : (Tensor, Tensor, Tensor?, int) -> (Tensor)")
|
||||||
emit("aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)")
|
emit(
|
||||||
|
"aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)")
|
emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)")
|
||||||
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
|
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)")
|
emit(
|
||||||
|
"aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)")
|
emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)")
|
||||||
|
|
||||||
# Misc tensor ops.
|
# Misc tensor ops.
|
||||||
|
@ -590,9 +719,13 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
|
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
|
||||||
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
|
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
|
||||||
emit("aten::is_floating_point : (Tensor) -> (bool)", has_folder=True)
|
emit("aten::is_floating_point : (Tensor) -> (bool)", has_folder=True)
|
||||||
emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
|
emit(
|
||||||
|
"aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
|
emit(
|
||||||
|
"aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
|
@ -611,8 +744,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::any : (Tensor) -> (Tensor)")
|
emit("aten::any : (Tensor) -> (Tensor)")
|
||||||
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
|
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
|
||||||
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
emit(
|
||||||
emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
"aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)")
|
emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)")
|
||||||
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
|
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
|
||||||
emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
|
emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
|
||||||
|
@ -624,19 +761,29 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
|
emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
|
||||||
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
|
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
|
||||||
emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)")
|
emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)")
|
||||||
emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)")
|
emit(
|
||||||
|
"aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::detach : (Tensor) -> (Tensor)", has_folder=True)
|
emit("aten::detach : (Tensor) -> (Tensor)", has_folder=True)
|
||||||
emit("aten::device.with_index : (str, int) -> (Device)", has_canonicalizer=True)
|
emit("aten::device.with_index : (str, int) -> (Device)", has_canonicalizer=True)
|
||||||
emit("aten::cuda : (Tensor) -> (Tensor)", has_canonicalizer=True)
|
emit("aten::cuda : (Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||||
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
||||||
emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)")
|
emit(
|
||||||
emit("aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)")
|
"aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)"
|
||||||
|
)
|
||||||
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit(
|
||||||
|
"aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit(
|
||||||
|
"aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
|
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
|
||||||
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
|
||||||
|
@ -644,7 +791,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
|
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
|
||||||
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
|
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
|
||||||
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True)
|
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True)
|
||||||
emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
|
emit_with_mutating_variants(
|
||||||
|
"aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
|
emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
|
||||||
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True)
|
emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True)
|
||||||
|
@ -670,9 +819,18 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
|
emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||||
emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
|
emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
|
||||||
emit("aten::amin : (Tensor, int[], bool) -> (Tensor)")
|
emit("aten::amin : (Tensor, int[], bool) -> (Tensor)")
|
||||||
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
|
emit(
|
||||||
emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True, has_canonicalizer = True)
|
"aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True
|
||||||
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True)
|
)
|
||||||
|
emit(
|
||||||
|
"aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)",
|
||||||
|
has_folder=True,
|
||||||
|
has_canonicalizer=True,
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)",
|
||||||
|
has_canonicalizer=True,
|
||||||
|
)
|
||||||
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
||||||
emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
|
emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
|
||||||
emit("aten::_cast_Float : (Tensor, bool) -> (Tensor)", has_canonicalizer=True)
|
emit("aten::_cast_Float : (Tensor, bool) -> (Tensor)", has_canonicalizer=True)
|
||||||
|
@ -681,33 +839,64 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||||
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
|
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True)
|
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True)
|
||||||
emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
emit(
|
||||||
emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)", has_folder=True)
|
"aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||||
emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True)
|
has_canonicalizer=True,
|
||||||
|
has_folder=True,
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)",
|
||||||
|
has_folder=True,
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)")
|
emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)")
|
||||||
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)", has_folder=True)
|
emit(
|
||||||
|
"aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)",
|
||||||
|
has_folder=True,
|
||||||
|
)
|
||||||
emit("aten::len.Tensor : (Tensor) -> (int)")
|
emit("aten::len.Tensor : (Tensor) -> (int)")
|
||||||
emit("aten::cpu : (Tensor) -> (Tensor)")
|
emit("aten::cpu : (Tensor) -> (Tensor)")
|
||||||
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
|
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
|
||||||
emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
emit_with_mutating_variants(
|
||||||
emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)")
|
"aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
|
emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
|
||||||
emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
|
emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
|
||||||
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True)
|
emit(
|
||||||
|
"aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True
|
||||||
|
)
|
||||||
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
|
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
|
||||||
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
|
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
|
||||||
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
||||||
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
|
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
|
||||||
emit("aten::t : (Tensor) -> (Tensor)")
|
emit("aten::t : (Tensor) -> (Tensor)")
|
||||||
emit("aten::numpy_T : (Tensor) -> (Tensor)")
|
emit("aten::numpy_T : (Tensor) -> (Tensor)")
|
||||||
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
|
emit(
|
||||||
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
"aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)",
|
||||||
emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
has_folder=True,
|
||||||
emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
|
)
|
||||||
|
emit(
|
||||||
|
"aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit_with_mutating_variants(
|
||||||
|
"aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
|
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
|
||||||
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)")
|
emit(
|
||||||
emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)")
|
"aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True)
|
emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True)
|
||||||
|
|
||||||
# Functionalization ops
|
# Functionalization ops
|
||||||
|
@ -737,7 +926,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
||||||
emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
|
emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
|
||||||
emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
|
emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
|
||||||
emit("aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)")
|
emit(
|
||||||
|
"aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)")
|
emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)")
|
||||||
|
|
||||||
# Dict ops.
|
# Dict ops.
|
||||||
|
@ -750,7 +941,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::Delete.Dict_str : (Dict(str, t), str) -> ()")
|
emit("aten::Delete.Dict_str : (Dict(str, t), str) -> ()")
|
||||||
|
|
||||||
# List ops.
|
# List ops.
|
||||||
emit("aten::cat : (Tensor[], int) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
emit(
|
||||||
|
"aten::cat : (Tensor[], int) -> (Tensor)",
|
||||||
|
has_canonicalizer=True,
|
||||||
|
has_folder=True,
|
||||||
|
)
|
||||||
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||||
emit("aten::append.t : (t[], t) -> (t[])")
|
emit("aten::append.t : (t[], t) -> (t[])")
|
||||||
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
|
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
|
||||||
|
@ -823,9 +1018,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
||||||
emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True)
|
emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True)
|
||||||
emit("aten::__not__ : (bool) -> (bool)", has_folder=True)
|
emit("aten::__not__ : (bool) -> (bool)", has_folder=True)
|
||||||
emit("aten::len.t : (t[]) -> (int)",
|
emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True)
|
||||||
has_folder=True,
|
|
||||||
has_canonicalizer=True)
|
|
||||||
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
|
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
|
||||||
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
||||||
emit("aten::mul : (Scalar, Scalar) -> (Scalar)", has_folder=True)
|
emit("aten::mul : (Scalar, Scalar) -> (Scalar)", has_folder=True)
|
||||||
|
@ -849,12 +1042,22 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::hardtanh_backward : (Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
|
emit("aten::hardtanh_backward : (Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
|
||||||
emit("aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)")
|
emit("aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)")
|
||||||
emit("aten::_log_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
|
emit("aten::_log_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||||
emit("aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)")
|
emit(
|
||||||
emit("aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)")
|
"aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)"
|
||||||
emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)")
|
)
|
||||||
emit("aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)")
|
emit(
|
||||||
|
"aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)"
|
||||||
|
)
|
||||||
|
emit(
|
||||||
|
"aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)"
|
||||||
|
)
|
||||||
emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
|
emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
|
||||||
emit("aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)")
|
emit(
|
||||||
|
"aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
|
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
|
||||||
|
|
||||||
# quantized ops
|
# quantized ops
|
||||||
|
@ -863,7 +1066,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::dequantize.self : (Tensor) -> (Tensor)")
|
emit("aten::dequantize.self : (Tensor) -> (Tensor)")
|
||||||
emit("aten::dequantize.tensor : (Tensor) -> (Tensor)")
|
emit("aten::dequantize.tensor : (Tensor) -> (Tensor)")
|
||||||
emit("aten::int_repr : (Tensor) -> (Tensor)")
|
emit("aten::int_repr : (Tensor) -> (Tensor)")
|
||||||
emit("aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
emit(
|
||||||
|
"aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")
|
emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")
|
||||||
|
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
|
@ -881,10 +1086,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("prim::max.self_int : (int[]) -> (int)")
|
emit("prim::max.self_int : (int[]) -> (int)")
|
||||||
emit("prim::max.int : (int, int) -> (int)", has_folder=True)
|
emit("prim::max.int : (int, int) -> (int)", has_folder=True)
|
||||||
emit("prim::RaiseException : (str, str?) -> ()")
|
emit("prim::RaiseException : (str, str?) -> ()")
|
||||||
emit("prim::Uninitialized : () -> (Any)",
|
emit("prim::Uninitialized : () -> (Any)", has_canonicalizer=True, traits=["Pure"])
|
||||||
has_canonicalizer=True, traits=["Pure"])
|
emit(
|
||||||
emit("prim::unchecked_cast : (t) -> (t)", has_folder=True,
|
"prim::unchecked_cast : (t) -> (t)",
|
||||||
traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
|
has_folder=True,
|
||||||
|
traits=["DeclareOpInterfaceMethods<CastOpInterface>"],
|
||||||
|
)
|
||||||
emit("prim::Print : (...) -> ()")
|
emit("prim::Print : (...) -> ()")
|
||||||
emit("prim::tolist : (...) -> (...)")
|
emit("prim::tolist : (...) -> (...)")
|
||||||
emit("prim::abs.Scalar : (Scalar) -> (Scalar)")
|
emit("prim::abs.Scalar : (Scalar) -> (Scalar)")
|
||||||
|
@ -908,13 +1115,15 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
|
|
||||||
emit(
|
emit(
|
||||||
"quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)",
|
"quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)",
|
||||||
traits=["HasValueSemantics"])
|
traits=["HasValueSemantics"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def dump_registered_ops(outfile: TextIO, registry: Registry):
|
def dump_registered_ops(outfile: TextIO, registry: Registry):
|
||||||
for _, v in sorted(registry.by_unique_key.items()):
|
for _, v in sorted(registry.by_unique_key.items()):
|
||||||
outfile.write(repr(v))
|
outfile.write(repr(v))
|
||||||
|
|
||||||
|
|
||||||
def _maybe_import_op_extensions(args: argparse.Namespace):
|
def _maybe_import_op_extensions(args: argparse.Namespace):
|
||||||
extension_string = str.strip(args.pytorch_op_extensions)
|
extension_string = str.strip(args.pytorch_op_extensions)
|
||||||
if len(extension_string) > 0:
|
if len(extension_string) > 0:
|
||||||
|
@ -924,6 +1133,7 @@ def _maybe_import_op_extensions(args: argparse.Namespace):
|
||||||
# importing these modules, so we don't need the return value.
|
# importing these modules, so we don't need the return value.
|
||||||
importlib.import_module(name)
|
importlib.import_module(name)
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
_maybe_import_op_extensions(args)
|
_maybe_import_op_extensions(args)
|
||||||
registry = Registry.load()
|
registry = Registry.load()
|
||||||
|
@ -942,15 +1152,18 @@ def _create_argparse() -> argparse.ArgumentParser:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--torch_ir_include_dir",
|
"--torch_ir_include_dir",
|
||||||
required=True,
|
required=True,
|
||||||
help="Directory in include/ containing the Torch dialect")
|
help="Directory in include/ containing the Torch dialect",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--debug_registry_dump",
|
"--debug_registry_dump",
|
||||||
help="File to dump the the PyTorch JIT operator registry into")
|
help="File to dump the the PyTorch JIT operator registry into",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pytorch_op_extensions",
|
"--pytorch_op_extensions",
|
||||||
type=str,
|
type=str,
|
||||||
default="",
|
default="",
|
||||||
help="An optional, comma-separated list of Python modules which register additional PyTorch operators upon being imported. These modules can be used to build a torch-mlir which supports PyTorch extensions.")
|
help="An optional, comma-separated list of Python modules which register additional PyTorch operators upon being imported. These modules can be used to build a torch-mlir which supports PyTorch extensions.",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,8 +8,10 @@ from typing import TextIO
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
|
||||||
class TextEmitter:
|
class TextEmitter:
|
||||||
"""Helper for emitting text files"""
|
"""Helper for emitting text files"""
|
||||||
|
|
||||||
_INDENT = " "
|
_INDENT = " "
|
||||||
|
|
||||||
def __init__(self, out: TextIO):
|
def __init__(self, out: TextIO):
|
||||||
|
|
|
@ -24,34 +24,38 @@ from torch_mlir.jit_ir_importer import ClassAnnotator
|
||||||
|
|
||||||
# Utilities for extracting decorated information into ClassAnnotator.
|
# Utilities for extracting decorated information into ClassAnnotator.
|
||||||
|
|
||||||
|
|
||||||
def _recursively_extract_annotations(
|
def _recursively_extract_annotations(
|
||||||
module: torch.nn.Module, scripted: torch.jit.ScriptModule,
|
module: torch.nn.Module,
|
||||||
class_annotator: ClassAnnotator):
|
scripted: torch.jit.ScriptModule,
|
||||||
|
class_annotator: ClassAnnotator,
|
||||||
|
):
|
||||||
assert module.__class__.__name__ == scripted.original_name or (
|
assert module.__class__.__name__ == scripted.original_name or (
|
||||||
isinstance(module, torch.jit.RecursiveScriptModule) and module is
|
isinstance(module, torch.jit.RecursiveScriptModule) and module is scripted
|
||||||
scripted), "script module does not come from specified module"
|
), "script module does not come from specified module"
|
||||||
|
|
||||||
# Extract information on methods.
|
# Extract information on methods.
|
||||||
for method_name, scripted_method in scripted.__dict__.items():
|
for method_name, scripted_method in scripted.__dict__.items():
|
||||||
if not isinstance(scripted_method, torch.ScriptMethod):
|
if not isinstance(scripted_method, torch.ScriptMethod):
|
||||||
continue
|
continue
|
||||||
method = getattr(module, method_name)
|
method = getattr(module, method_name)
|
||||||
if hasattr(method, '_torch_mlir_export'):
|
if hasattr(method, "_torch_mlir_export"):
|
||||||
class_annotator.exportPath(scripted._c._type(), [method_name])
|
class_annotator.exportPath(scripted._c._type(), [method_name])
|
||||||
if hasattr(method, '_torch_mlir_arg_annotations'):
|
if hasattr(method, "_torch_mlir_arg_annotations"):
|
||||||
class_annotator.annotateArgs(
|
class_annotator.annotateArgs(
|
||||||
scripted._c._type(), [method_name],
|
scripted._c._type(), [method_name], method._torch_mlir_arg_annotations
|
||||||
method._torch_mlir_arg_annotations)
|
)
|
||||||
# Recurse.
|
# Recurse.
|
||||||
for name, child in module.named_children():
|
for name, child in module.named_children():
|
||||||
scripted_child = getattr(scripted, name)
|
scripted_child = getattr(scripted, name)
|
||||||
_recursively_extract_annotations(child, scripted_child,
|
_recursively_extract_annotations(child, scripted_child, class_annotator)
|
||||||
class_annotator)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_annotations(program: torch.nn.Module,
|
def extract_annotations(
|
||||||
|
program: torch.nn.Module,
|
||||||
scripted: torch.jit.ScriptModule,
|
scripted: torch.jit.ScriptModule,
|
||||||
class_annotator: ClassAnnotator):
|
class_annotator: ClassAnnotator,
|
||||||
|
):
|
||||||
"""Populate the ClassAnnotator with annotations extracted from `program`."""
|
"""Populate the ClassAnnotator with annotations extracted from `program`."""
|
||||||
class_annotator.exportNone(scripted._c._type())
|
class_annotator.exportNone(scripted._c._type())
|
||||||
_recursively_extract_annotations(program, scripted, class_annotator)
|
_recursively_extract_annotations(program, scripted, class_annotator)
|
||||||
|
|
|
@ -20,7 +20,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
from torch_mlir.compiler_utils import (
|
from torch_mlir.compiler_utils import (
|
||||||
run_pipeline_with_repro_report,
|
run_pipeline_with_repro_report,
|
||||||
OutputType,
|
OutputType,
|
||||||
lower_mlir_module
|
lower_mlir_module,
|
||||||
)
|
)
|
||||||
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
|
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||||
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
|
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
|
||||||
|
@ -105,8 +105,7 @@ class ExampleArgs:
|
||||||
self, for chaining.
|
self, for chaining.
|
||||||
"""
|
"""
|
||||||
assert method_name not in self._example_args
|
assert method_name not in self._example_args
|
||||||
self._example_args[method_name] = ExampleArgs._canonicalize_args(
|
self._example_args[method_name] = ExampleArgs._canonicalize_args(example_args)
|
||||||
example_args)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -129,10 +128,12 @@ class ExampleArgs:
|
||||||
example_args = [example_args]
|
example_args = [example_args]
|
||||||
for arg in example_args:
|
for arg in example_args:
|
||||||
if not isinstance(arg, (TensorPlaceholder, torch.Tensor)):
|
if not isinstance(arg, (TensorPlaceholder, torch.Tensor)):
|
||||||
raise Exception(f"Only Tensor's, TensorPlaceholder's, or sequences of "
|
raise Exception(
|
||||||
|
f"Only Tensor's, TensorPlaceholder's, or sequences of "
|
||||||
f"Tensor's and TensorPlaceholder's are supported as "
|
f"Tensor's and TensorPlaceholder's are supported as "
|
||||||
f"example args for method inputs. "
|
f"example args for method inputs. "
|
||||||
f"Got '{arg}'.")
|
f"Got '{arg}'."
|
||||||
|
)
|
||||||
return tuple(example_args)
|
return tuple(example_args)
|
||||||
|
|
||||||
def _get_methods(self):
|
def _get_methods(self):
|
||||||
|
@ -171,7 +172,8 @@ class ExampleArgs:
|
||||||
# "hopefully the trace works for different inputs"
|
# "hopefully the trace works for different inputs"
|
||||||
# bucket of concerns.
|
# bucket of concerns.
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`")
|
"TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`"
|
||||||
|
)
|
||||||
# For any dynamic dimensions, replace them with "7"
|
# For any dynamic dimensions, replace them with "7"
|
||||||
# arbitrarily. If a user is using dynamic dimensions with
|
# arbitrarily. If a user is using dynamic dimensions with
|
||||||
# tracing, they are walking on thin ice already -- assume
|
# tracing, they are walking on thin ice already -- assume
|
||||||
|
@ -182,7 +184,8 @@ class ExampleArgs:
|
||||||
example_args_for_trace.append(torch.tensor(1))
|
example_args_for_trace.append(torch.tensor(1))
|
||||||
else:
|
else:
|
||||||
example_args_for_trace.append(
|
example_args_for_trace.append(
|
||||||
torch.ones(*shape, dtype=arg.dtype))
|
torch.ones(*shape, dtype=arg.dtype)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert isinstance(arg, torch.Tensor)
|
assert isinstance(arg, torch.Tensor)
|
||||||
example_args_for_trace.append(arg)
|
example_args_for_trace.append(arg)
|
||||||
|
@ -198,21 +201,33 @@ class ExampleArgs:
|
||||||
# ops in the backend contract, and move these lists somewhere deeper in the
|
# ops in the backend contract, and move these lists somewhere deeper in the
|
||||||
# compiler where each backend can "own" its set of legal ops.
|
# compiler where each backend can "own" its set of legal ops.
|
||||||
BACKEND_LEGAL_OPS = {
|
BACKEND_LEGAL_OPS = {
|
||||||
OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'],
|
OutputType.TOSA: [
|
||||||
OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d','aten.adaptive_avg_pool2d', 'aten.unflatten.int'],
|
"aten.flatten.using_ints",
|
||||||
|
"aten.native_layer_norm",
|
||||||
|
"aten.linear",
|
||||||
|
],
|
||||||
|
OutputType.LINALG_ON_TENSORS: [
|
||||||
|
"aten.flatten.using_ints",
|
||||||
|
"aten.adaptive_avg_pool1d",
|
||||||
|
"aten.adaptive_avg_pool2d",
|
||||||
|
"aten.unflatten.int",
|
||||||
|
],
|
||||||
OutputType.STABLEHLO: [],
|
OutputType.STABLEHLO: [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra_library.mlir"):
|
def _canon_extra_library(
|
||||||
|
extra_library, extra_library_file_name="custom_op_extra_library.mlir"
|
||||||
|
):
|
||||||
if len(extra_library) != 0:
|
if len(extra_library) != 0:
|
||||||
extra_library_dict = {}
|
extra_library_dict = {}
|
||||||
for library_func in extra_library:
|
for library_func in extra_library:
|
||||||
extra_library_dict[library_func.__name__] = library_func
|
extra_library_dict[library_func.__name__] = library_func
|
||||||
mlir_library = generate_library(extra_library_dict)
|
mlir_library = generate_library(extra_library_dict)
|
||||||
|
|
||||||
extra_library_file = \
|
extra_library_file = os.path.join(
|
||||||
os.path.join(tempfile.gettempdir(), extra_library_file_name)
|
tempfile.gettempdir(), extra_library_file_name
|
||||||
|
)
|
||||||
with open(extra_library_file, "w") as f:
|
with open(extra_library_file, "w") as f:
|
||||||
f.write(mlir_library)
|
f.write(mlir_library)
|
||||||
return extra_library_file
|
return extra_library_file
|
||||||
|
@ -220,7 +235,8 @@ def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def compile(model: torch.nn.Module,
|
def compile(
|
||||||
|
model: torch.nn.Module,
|
||||||
example_args: _example_args,
|
example_args: _example_args,
|
||||||
output_type: Union[str, "OutputType"] = OutputType.TORCH,
|
output_type: Union[str, "OutputType"] = OutputType.TORCH,
|
||||||
use_tracing: bool = False,
|
use_tracing: bool = False,
|
||||||
|
@ -229,7 +245,8 @@ def compile(model: torch.nn.Module,
|
||||||
extra_library: Iterable[Callable] = [],
|
extra_library: Iterable[Callable] = [],
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
use_make_fx: bool = False,
|
use_make_fx: bool = False,
|
||||||
enable_ir_printing: bool = False):
|
enable_ir_printing: bool = False,
|
||||||
|
):
|
||||||
"""Convert a PyTorch model to MLIR.
|
"""Convert a PyTorch model to MLIR.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -283,18 +300,18 @@ def compile(model: torch.nn.Module,
|
||||||
# See `BACKEND_LEGAL_OPS` for more details.
|
# See `BACKEND_LEGAL_OPS` for more details.
|
||||||
if backend_legal_ops is not None:
|
if backend_legal_ops is not None:
|
||||||
if output_type != OutputType.TORCH:
|
if output_type != OutputType.TORCH:
|
||||||
raise Exception("`backend_legal_ops` is only valid with the "
|
raise Exception(
|
||||||
"`torch` output type")
|
"`backend_legal_ops` is only valid with the " "`torch` output type"
|
||||||
|
)
|
||||||
backend_legal_ops = list(sorted(set(backend_legal_ops)))
|
backend_legal_ops = list(sorted(set(backend_legal_ops)))
|
||||||
else:
|
else:
|
||||||
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
||||||
|
|
||||||
if use_make_fx:
|
if use_make_fx:
|
||||||
args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)["forward"]
|
args = example_args._get_for_tracing(
|
||||||
model = make_fx(
|
use_tracing=True, ignore_traced_shapes=True
|
||||||
model,
|
)["forward"]
|
||||||
decomposition_table=_get_decomposition_table())(*args)
|
model = make_fx(model, decomposition_table=_get_decomposition_table())(*args)
|
||||||
|
|
||||||
|
|
||||||
# For FX-based models, automatically strip overloads.
|
# For FX-based models, automatically strip overloads.
|
||||||
if isinstance(model, torch.fx.GraphModule):
|
if isinstance(model, torch.fx.GraphModule):
|
||||||
|
@ -317,12 +334,12 @@ def compile(model: torch.nn.Module,
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Model does not have exported method '{method_name}', "
|
f"Model does not have exported method '{method_name}', "
|
||||||
f"requested in `example_args`. Consider adding "
|
f"requested in `example_args`. Consider adding "
|
||||||
f"`@torch.jit.export` to the method definition.")
|
f"`@torch.jit.export` to the method definition."
|
||||||
|
)
|
||||||
scripted = model
|
scripted = model
|
||||||
elif use_tracing:
|
elif use_tracing:
|
||||||
scripted = torch.jit.trace_module(
|
scripted = torch.jit.trace_module(
|
||||||
model,
|
model, example_args._get_for_tracing(use_tracing, ignore_traced_shapes)
|
||||||
example_args._get_for_tracing(use_tracing, ignore_traced_shapes)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Make sure that all the methods that the user requested get scripted.
|
# Make sure that all the methods that the user requested get scripted.
|
||||||
|
@ -338,8 +355,7 @@ def compile(model: torch.nn.Module,
|
||||||
annotation = [None] # `None` is always the annotation for "self".
|
annotation = [None] # `None` is always the annotation for "self".
|
||||||
for arg in example_args:
|
for arg in example_args:
|
||||||
annotation.append((arg.shape, arg.dtype, True))
|
annotation.append((arg.shape, arg.dtype, True))
|
||||||
class_annotator.annotateArgs(
|
class_annotator.annotateArgs(scripted._c._type(), [method_name], annotation)
|
||||||
scripted._c._type(), [method_name], annotation)
|
|
||||||
|
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
import_options = ImportOptions()
|
import_options = ImportOptions()
|
||||||
|
@ -350,20 +366,27 @@ def compile(model: torch.nn.Module,
|
||||||
# Import the TorchScript module to MLIR
|
# Import the TorchScript module to MLIR
|
||||||
mb.import_module(scripted._c, class_annotator, import_options)
|
mb.import_module(scripted._c, class_annotator, import_options)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"""
|
raise Exception(
|
||||||
|
f"""
|
||||||
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
||||||
### Importer C++ Exception:
|
### Importer C++ Exception:
|
||||||
{e}
|
{e}
|
||||||
### Importer Diagnostics:
|
### Importer Diagnostics:
|
||||||
{sys.stderr.getvalue()}
|
{sys.stderr.getvalue()}
|
||||||
""") from None
|
"""
|
||||||
|
) from None
|
||||||
finally:
|
finally:
|
||||||
sys.stderr = original_stderr
|
sys.stderr = original_stderr
|
||||||
if output_type == OutputType.RAW:
|
if output_type == OutputType.RAW:
|
||||||
return mb.module
|
return mb.module
|
||||||
|
|
||||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
|
option_string = (
|
||||||
" extra-library=" + extra_library_file_name + "}"
|
"{backend-legal-ops="
|
||||||
|
+ ",".join(backend_legal_ops)
|
||||||
|
+ " extra-library="
|
||||||
|
+ extra_library_file_name
|
||||||
|
+ "}"
|
||||||
|
)
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
mb.module,
|
mb.module,
|
||||||
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
|
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
|
||||||
|
|
|
@ -22,8 +22,8 @@ import torch
|
||||||
# Attribute names used for annotations.
|
# Attribute names used for annotations.
|
||||||
# These should be kept in sync with their use in
|
# These should be kept in sync with their use in
|
||||||
# `torch_mlir/torchscript_annotations.py`.
|
# `torch_mlir/torchscript_annotations.py`.
|
||||||
TORCH_MLIR_EXPORT_ATTR_NAME = '_torch_mlir_export'
|
TORCH_MLIR_EXPORT_ATTR_NAME = "_torch_mlir_export"
|
||||||
TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME = '_torch_mlir_arg_annotations'
|
TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME = "_torch_mlir_arg_annotations"
|
||||||
|
|
||||||
|
|
||||||
def export(fn):
|
def export(fn):
|
||||||
|
|
|
@ -55,14 +55,20 @@ def jit(
|
||||||
output_type = OutputType.get(output_type)
|
output_type = OutputType.get(output_type)
|
||||||
if backend_legal_ops is not None:
|
if backend_legal_ops is not None:
|
||||||
if output_type != OutputType.TORCH:
|
if output_type != OutputType.TORCH:
|
||||||
raise Exception("`backend_legal_ops` is only valid with the "
|
raise Exception(
|
||||||
"`torch` output type")
|
"`backend_legal_ops` is only valid with the " "`torch` output type"
|
||||||
|
)
|
||||||
backend_legal_ops = list(sorted(set(backend_legal_ops)))
|
backend_legal_ops = list(sorted(set(backend_legal_ops)))
|
||||||
else:
|
else:
|
||||||
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
||||||
|
|
||||||
option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) +
|
option_string = (
|
||||||
" extra-library=" + extra_library_file_name + "}")
|
"{backend-legal-ops="
|
||||||
|
+ ",".join(backend_legal_ops)
|
||||||
|
+ " extra-library="
|
||||||
|
+ extra_library_file_name
|
||||||
|
+ "}"
|
||||||
|
)
|
||||||
|
|
||||||
mlir_module = fx.export_and_import(prog, func_name=func_name)
|
mlir_module = fx.export_and_import(prog, func_name=func_name)
|
||||||
assert mlir_module is not None
|
assert mlir_module is not None
|
||||||
|
@ -95,9 +101,11 @@ class FxImporterTestConfig(TestConfig):
|
||||||
result: Trace = []
|
result: Trace = []
|
||||||
for item in trace:
|
for item in trace:
|
||||||
prog = torch.export.export(artifact, tuple(item.inputs))
|
prog = torch.export.export(artifact, tuple(item.inputs))
|
||||||
module = jit(prog,
|
module = jit(
|
||||||
|
prog,
|
||||||
func_name=artifact.__class__.__name__,
|
func_name=artifact.__class__.__name__,
|
||||||
output_type=self._output_type)
|
output_type=self._output_type,
|
||||||
|
)
|
||||||
module = self._backend.compile(module)
|
module = self._backend.compile(module)
|
||||||
backend_module = self._backend.load(module)
|
backend_module = self._backend.load(module)
|
||||||
params = {
|
params = {
|
||||||
|
@ -107,10 +115,10 @@ class FxImporterTestConfig(TestConfig):
|
||||||
params_flat, params_spec = pytree.tree_flatten(params)
|
params_flat, params_spec = pytree.tree_flatten(params)
|
||||||
params_flat = list(params_flat)
|
params_flat = list(params_flat)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
numpy_inputs = recursively_convert_to_numpy(params_flat +
|
numpy_inputs = recursively_convert_to_numpy(params_flat + item.inputs)
|
||||||
item.inputs)
|
outputs = getattr(backend_module, artifact.__class__.__name__)(
|
||||||
outputs = getattr(backend_module,
|
*numpy_inputs
|
||||||
artifact.__class__.__name__)(*numpy_inputs)
|
)
|
||||||
output = refine_result_type(outputs)
|
output = refine_result_type(outputs)
|
||||||
if isinstance(output, (tuple, list)):
|
if isinstance(output, (tuple, list)):
|
||||||
user_output = []
|
user_output = []
|
||||||
|
@ -120,7 +128,6 @@ class FxImporterTestConfig(TestConfig):
|
||||||
user_output.append(val)
|
user_output.append(val)
|
||||||
output = tuple(user_output)
|
output = tuple(user_output)
|
||||||
result.append(
|
result.append(
|
||||||
TraceItem(symbol=item.symbol,
|
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||||
inputs=item.inputs,
|
)
|
||||||
output=output))
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -23,20 +23,19 @@ class LazyTensorCoreTestConfig(TestConfig):
|
||||||
lazy_backend._initialize()
|
lazy_backend._initialize()
|
||||||
|
|
||||||
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
|
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
|
||||||
return program.to('lazy')
|
return program.to("lazy")
|
||||||
|
|
||||||
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||||
result: Trace = []
|
result: Trace = []
|
||||||
|
|
||||||
for item in trace:
|
for item in trace:
|
||||||
# We need to move all the inputs to the lazy device before running in LTC.
|
# We need to move all the inputs to the lazy device before running in LTC.
|
||||||
lazy_inputs = tree_map(to_device('lazy'), item.inputs)
|
lazy_inputs = tree_map(to_device("lazy"), item.inputs)
|
||||||
output = getattr(artifact, item.symbol)(*lazy_inputs)
|
output = getattr(artifact, item.symbol)(*lazy_inputs)
|
||||||
cpu_outputs = tree_map(to_device('cpu'), output)
|
cpu_outputs = tree_map(to_device("cpu"), output)
|
||||||
|
|
||||||
result.append(
|
result.append(
|
||||||
TraceItem(symbol=item.symbol,
|
TraceItem(symbol=item.symbol, inputs=item.inputs, output=cpu_outputs)
|
||||||
inputs=item.inputs,
|
)
|
||||||
output=cpu_outputs))
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -24,6 +24,7 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
|
||||||
This class handles all the common lowering that torch-mlir does before
|
This class handles all the common lowering that torch-mlir does before
|
||||||
reaching the linalg-on-tensors abstraction level.
|
reaching the linalg-on-tensors abstraction level.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, backend: LinalgOnTensorsBackend):
|
def __init__(self, backend: LinalgOnTensorsBackend):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
|
@ -31,12 +32,11 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
|
||||||
def compile(self, program: torch.nn.Module) -> Any:
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
example_args = convert_annotations_to_placeholders(program.forward)
|
example_args = convert_annotations_to_placeholders(program.forward)
|
||||||
module = torchscript.compile(
|
module = torchscript.compile(
|
||||||
program, example_args, output_type="linalg-on-tensors")
|
program, example_args, output_type="linalg-on-tensors"
|
||||||
|
)
|
||||||
|
|
||||||
return self.backend.compile(module)
|
return self.backend.compile(module)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run(self, artifact: Any, trace: Trace) -> Trace:
|
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||||
backend_module = self.backend.load(artifact)
|
backend_module = self.backend.load(artifact)
|
||||||
result: Trace = []
|
result: Trace = []
|
||||||
|
@ -45,7 +45,6 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
|
||||||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||||
output = recursively_convert_from_numpy(outputs)
|
output = recursively_convert_from_numpy(outputs)
|
||||||
result.append(
|
result.append(
|
||||||
TraceItem(symbol=item.symbol,
|
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||||
inputs=item.inputs,
|
)
|
||||||
output=output))
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -10,6 +10,7 @@ from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
|
||||||
class NativeTorchTestConfig(TestConfig):
|
class NativeTorchTestConfig(TestConfig):
|
||||||
"""TestConfig that just runs the torch.nn.Module without compiling"""
|
"""TestConfig that just runs the torch.nn.Module without compiling"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -23,7 +24,6 @@ class NativeTorchTestConfig(TestConfig):
|
||||||
for item in trace:
|
for item in trace:
|
||||||
output = getattr(artifact, item.symbol)(*item.inputs)
|
output = getattr(artifact, item.symbol)(*item.inputs)
|
||||||
result.append(
|
result.append(
|
||||||
TraceItem(symbol=item.symbol,
|
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||||
inputs=item.inputs,
|
)
|
||||||
output=output))
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -47,7 +47,7 @@ def convert_onnx(model, inputs):
|
||||||
input_names = []
|
input_names = []
|
||||||
dynamic_tensors = {}
|
dynamic_tensors = {}
|
||||||
for (index, arg) in enumerate(inputs):
|
for (index, arg) in enumerate(inputs):
|
||||||
shape = map(lambda d : d if d >= 0 else 1, arg.shape)
|
shape = map(lambda d: d if d >= 0 else 1, arg.shape)
|
||||||
shape = tuple(shape)
|
shape = tuple(shape)
|
||||||
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
|
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
|
||||||
|
|
||||||
|
@ -56,24 +56,27 @@ def convert_onnx(model, inputs):
|
||||||
|
|
||||||
dynamic_dims = {}
|
dynamic_dims = {}
|
||||||
for (dimindex, dim) in enumerate(arg.shape):
|
for (dimindex, dim) in enumerate(arg.shape):
|
||||||
if (dim < 0):
|
if dim < 0:
|
||||||
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)
|
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)
|
||||||
|
|
||||||
if (dynamic_dims):
|
if dynamic_dims:
|
||||||
dynamic_tensors[input_name] = dynamic_dims
|
dynamic_tensors[input_name] = dynamic_dims
|
||||||
|
|
||||||
|
examples = tuple(examples)
|
||||||
examples=tuple(examples)
|
torch.onnx.export(
|
||||||
torch.onnx.export(model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors)
|
model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors
|
||||||
|
)
|
||||||
buffer = buffer.getvalue()
|
buffer = buffer.getvalue()
|
||||||
return import_onnx(buffer)
|
return import_onnx(buffer)
|
||||||
|
|
||||||
|
|
||||||
class OnnxBackendTestConfig(TestConfig):
|
class OnnxBackendTestConfig(TestConfig):
|
||||||
"""Base class for TestConfig's that are implemented with ONNX.
|
"""Base class for TestConfig's that are implemented with ONNX.
|
||||||
|
|
||||||
This class handles all the common lowering that torch-mlir does before
|
This class handles all the common lowering that torch-mlir does before
|
||||||
reaching the ONNX abstraction level.
|
reaching the ONNX abstraction level.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, backend: OnnxBackend, use_make_fx: bool = False):
|
def __init__(self, backend: OnnxBackend, use_make_fx: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
|
@ -85,8 +88,6 @@ class OnnxBackendTestConfig(TestConfig):
|
||||||
compiled_module = self.backend.compile(onnx_module)
|
compiled_module = self.backend.compile(onnx_module)
|
||||||
return compiled_module
|
return compiled_module
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run(self, artifact: Any, trace: Trace) -> Trace:
|
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||||
backend_module = self.backend.load(artifact)
|
backend_module = self.backend.load(artifact)
|
||||||
result: Trace = []
|
result: Trace = []
|
||||||
|
@ -95,7 +96,6 @@ class OnnxBackendTestConfig(TestConfig):
|
||||||
outputs = getattr(backend_module, "main_graph")(*numpy_inputs)
|
outputs = getattr(backend_module, "main_graph")(*numpy_inputs)
|
||||||
output = recursively_convert_from_numpy(outputs)
|
output = recursively_convert_from_numpy(outputs)
|
||||||
result.append(
|
result.append(
|
||||||
TraceItem(symbol=item.symbol,
|
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||||
inputs=item.inputs,
|
)
|
||||||
output=output))
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -48,13 +48,13 @@ def refine_result_type(_result):
|
||||||
def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
|
def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
|
||||||
for node in fx_graph.graph.nodes:
|
for node in fx_graph.graph.nodes:
|
||||||
if node.op == "output":
|
if node.op == "output":
|
||||||
assert len(
|
assert len(node.args) == 1, "Output node must have a single argument"
|
||||||
node.args) == 1, "Output node must have a single argument"
|
|
||||||
node_arg = node.args[0]
|
node_arg = node.args[0]
|
||||||
if node_arg != ():
|
if node_arg != ():
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
# Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor to
|
# Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor to
|
||||||
# torch.aten.add.Scalar/torch.aten.mul.Scalar in case of Scalar argument
|
# torch.aten.add.Scalar/torch.aten.mul.Scalar in case of Scalar argument
|
||||||
# Cannot be done on earlier stage, e.g. in _FXGraphImporter as it
|
# Cannot be done on earlier stage, e.g. in _FXGraphImporter as it
|
||||||
|
@ -69,7 +69,7 @@ def scalarize_tensor_ops_on_scalars(gm: torch.fx.GraphModule):
|
||||||
for node in gm.graph.nodes:
|
for node in gm.graph.nodes:
|
||||||
# Checks if we're calling a function (i.e:
|
# Checks if we're calling a function (i.e:
|
||||||
# torch.add)
|
# torch.add)
|
||||||
if node.op == 'call_function':
|
if node.op == "call_function":
|
||||||
# The target attribute is the function
|
# The target attribute is the function
|
||||||
# that call_function calls.
|
# that call_function calls.
|
||||||
# call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {})
|
# call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {})
|
||||||
|
@ -108,20 +108,24 @@ def jit(
|
||||||
output_type = OutputType.get(output_type)
|
output_type = OutputType.get(output_type)
|
||||||
if backend_legal_ops is not None:
|
if backend_legal_ops is not None:
|
||||||
if output_type != OutputType.TORCH:
|
if output_type != OutputType.TORCH:
|
||||||
raise Exception("`backend_legal_ops` is only valid with the "
|
raise Exception(
|
||||||
"`torch` output type")
|
"`backend_legal_ops` is only valid with the " "`torch` output type"
|
||||||
|
)
|
||||||
backend_legal_ops = list(sorted(set(backend_legal_ops)))
|
backend_legal_ops = list(sorted(set(backend_legal_ops)))
|
||||||
else:
|
else:
|
||||||
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
||||||
|
|
||||||
@make_boxed_compiler
|
@make_boxed_compiler
|
||||||
def my_aot_autograd_backend(gm: torch.fx.GraphModule,
|
def my_aot_autograd_backend(
|
||||||
_example_inputs: List[torch.Tensor]):
|
gm: torch.fx.GraphModule, _example_inputs: List[torch.Tensor]
|
||||||
|
):
|
||||||
# Torch-MLIR does not support returning an empty tuple. The reason is
|
# Torch-MLIR does not support returning an empty tuple. The reason is
|
||||||
# that both returning an empty tuple and returning `None` results in MLIR
|
# that both returning an empty tuple and returning `None` results in MLIR
|
||||||
# functions that have as a return type `()`. In other words, there is no
|
# functions that have as a return type `()`. In other words, there is no
|
||||||
# way of differentiating between the two.
|
# way of differentiating between the two.
|
||||||
assert not _returns_empty_tuple(gm), "encountered graph that does not return anything"
|
assert not _returns_empty_tuple(
|
||||||
|
gm
|
||||||
|
), "encountered graph that does not return anything"
|
||||||
|
|
||||||
scalarize_tensor_ops_on_scalars(gm)
|
scalarize_tensor_ops_on_scalars(gm)
|
||||||
|
|
||||||
|
@ -130,18 +134,24 @@ def jit(
|
||||||
mlir_module = import_fx_graph_as_func(gm.graph, model_name)
|
mlir_module = import_fx_graph_as_func(gm.graph, model_name)
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend,
|
my_backend = aot_autograd(
|
||||||
decompositions=_get_decomposition_table)
|
fw_compiler=my_aot_autograd_backend, decompositions=_get_decomposition_table
|
||||||
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
set_model_name(model.__class__.__name__)
|
set_model_name(model.__class__.__name__)
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
dynamo_f = dynamo.optimize(my_backend, nopython=True)(
|
dynamo_f = dynamo.optimize(my_backend, nopython=True)(
|
||||||
lambda method, *inputs: method(*inputs))
|
lambda method, *inputs: method(*inputs)
|
||||||
dynamo_f(lambda *inputs: model(*[x.clone() for x in inputs]),
|
)
|
||||||
*example_args)
|
dynamo_f(lambda *inputs: model(*[x.clone() for x in inputs]), *example_args)
|
||||||
option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) +
|
option_string = (
|
||||||
" extra-library=" + extra_library_file_name + "}")
|
"{backend-legal-ops="
|
||||||
|
+ ",".join(backend_legal_ops)
|
||||||
|
+ " extra-library="
|
||||||
|
+ extra_library_file_name
|
||||||
|
+ "}"
|
||||||
|
)
|
||||||
assert mlir_module is not None
|
assert mlir_module is not None
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
mlir_module,
|
mlir_module,
|
||||||
|
@ -166,9 +176,7 @@ class TorchDynamoTestConfig(TestConfig):
|
||||||
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||||
result: Trace = []
|
result: Trace = []
|
||||||
for item in trace:
|
for item in trace:
|
||||||
module = jit(artifact,
|
module = jit(artifact, item.inputs, output_type="linalg-on-tensors")
|
||||||
item.inputs,
|
|
||||||
output_type="linalg-on-tensors")
|
|
||||||
module = self.backend.compile(module)
|
module = self.backend.compile(module)
|
||||||
backend_module = self.backend.load(module)
|
backend_module = self.backend.load(module)
|
||||||
params = {
|
params = {
|
||||||
|
@ -178,13 +186,12 @@ class TorchDynamoTestConfig(TestConfig):
|
||||||
params_flat, params_spec = pytree.tree_flatten(params)
|
params_flat, params_spec = pytree.tree_flatten(params)
|
||||||
params_flat = list(params_flat)
|
params_flat = list(params_flat)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
numpy_inputs = recursively_convert_to_numpy(params_flat +
|
numpy_inputs = recursively_convert_to_numpy(params_flat + item.inputs)
|
||||||
item.inputs)
|
outputs = getattr(backend_module, artifact.__class__.__name__)(
|
||||||
outputs = getattr(backend_module,
|
*numpy_inputs
|
||||||
artifact.__class__.__name__)(*numpy_inputs)
|
)
|
||||||
output = refine_result_type(outputs)
|
output = refine_result_type(outputs)
|
||||||
result.append(
|
result.append(
|
||||||
TraceItem(symbol=item.symbol,
|
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||||
inputs=item.inputs,
|
)
|
||||||
output=output))
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -13,6 +13,7 @@ from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
|
||||||
class TorchScriptTestConfig(TestConfig):
|
class TorchScriptTestConfig(TestConfig):
|
||||||
"""TestConfig that runs the torch.nn.Module through TorchScript"""
|
"""TestConfig that runs the torch.nn.Module through TorchScript"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -26,11 +27,10 @@ class TorchScriptTestConfig(TestConfig):
|
||||||
result: Trace = []
|
result: Trace = []
|
||||||
for item in trace:
|
for item in trace:
|
||||||
attr = artifact
|
attr = artifact
|
||||||
for part in item.symbol.split('.'):
|
for part in item.symbol.split("."):
|
||||||
attr = getattr(attr, part)
|
attr = getattr(attr, part)
|
||||||
output = attr(*item.inputs)
|
output = attr(*item.inputs)
|
||||||
result.append(
|
result.append(
|
||||||
TraceItem(symbol=item.symbol,
|
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||||
inputs=item.inputs,
|
)
|
||||||
output=output))
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -23,6 +23,7 @@ class TosaBackendTestConfig(TestConfig):
|
||||||
This class handles all the common lowering that torch-mlir does before
|
This class handles all the common lowering that torch-mlir does before
|
||||||
reaching the TOSA abstraction level.
|
reaching the TOSA abstraction level.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, backend: TosaBackend, use_make_fx: bool = False):
|
def __init__(self, backend: TosaBackend, use_make_fx: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
|
@ -31,12 +32,11 @@ class TosaBackendTestConfig(TestConfig):
|
||||||
def compile(self, program: torch.nn.Module) -> Any:
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
example_args = convert_annotations_to_placeholders(program.forward)
|
example_args = convert_annotations_to_placeholders(program.forward)
|
||||||
module = torchscript.compile(
|
module = torchscript.compile(
|
||||||
program, example_args, output_type="tosa", use_make_fx=self.use_make_fx)
|
program, example_args, output_type="tosa", use_make_fx=self.use_make_fx
|
||||||
|
)
|
||||||
|
|
||||||
return self.backend.compile(module)
|
return self.backend.compile(module)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run(self, artifact: Any, trace: Trace) -> Trace:
|
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||||
backend_module = self.backend.load(artifact)
|
backend_module = self.backend.load(artifact)
|
||||||
result: Trace = []
|
result: Trace = []
|
||||||
|
@ -45,7 +45,6 @@ class TosaBackendTestConfig(TestConfig):
|
||||||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||||
output = recursively_convert_from_numpy(outputs)
|
output = recursively_convert_from_numpy(outputs)
|
||||||
result.append(
|
result.append(
|
||||||
TraceItem(symbol=item.symbol,
|
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||||
inputs=item.inputs,
|
)
|
||||||
output=output))
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -27,6 +27,7 @@ def recursively_convert_to_numpy(o: Any):
|
||||||
return o
|
return o
|
||||||
raise Exception(f"Unexpected Python function input: {o}")
|
raise Exception(f"Unexpected Python function input: {o}")
|
||||||
|
|
||||||
|
|
||||||
def recursively_convert_from_numpy(o: Any):
|
def recursively_convert_from_numpy(o: Any):
|
||||||
if isinstance(o, np.ndarray):
|
if isinstance(o, np.ndarray):
|
||||||
return torch.from_numpy(o)
|
return torch.from_numpy(o)
|
||||||
|
|
|
@ -24,8 +24,7 @@ def _make_single_op_gm(node) -> torch.fx.GraphModule:
|
||||||
return torch.fx.GraphModule(torch.nn.Module(), g)
|
return torch.fx.GraphModule(torch.nn.Module(), g)
|
||||||
|
|
||||||
|
|
||||||
def _identity_backend(gm: torch.fx.GraphModule,
|
def _identity_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
||||||
example_inputs: List[torch.Tensor]):
|
|
||||||
"""A backend that just runs the given GraphModule as-is."""
|
"""A backend that just runs the given GraphModule as-is."""
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
@ -51,6 +50,7 @@ def _make_last_use_map(g: torch.fx.Graph) -> Dict[torch.fx.Node, List[torch.fx.N
|
||||||
# Lifetime just ended, so this is the last use.
|
# Lifetime just ended, so this is the last use.
|
||||||
seen.add(use)
|
seen.add(use)
|
||||||
last_use_map[user].append(use)
|
last_use_map[user].append(use)
|
||||||
|
|
||||||
for node in reversed(g.nodes):
|
for node in reversed(g.nodes):
|
||||||
assert not node.kwargs, "kwargs not supported yet"
|
assert not node.kwargs, "kwargs not supported yet"
|
||||||
torch.fx.map_arg(node.args, lambda n: process_use(node, n))
|
torch.fx.map_arg(node.args, lambda n: process_use(node, n))
|
||||||
|
@ -84,9 +84,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend):
|
||||||
Returns:
|
Returns:
|
||||||
A backend that compares the wrapped backend to `golden_backend`.
|
A backend that compares the wrapped backend to `golden_backend`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(user_backend):
|
def wrapper(user_backend):
|
||||||
def backend(gm: torch.fx.GraphModule,
|
def backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
||||||
example_inputs: List[torch.Tensor]):
|
|
||||||
# We can ignore the example_inputs since we recompile in lockstep
|
# We can ignore the example_inputs since we recompile in lockstep
|
||||||
# anyway. TorchDynamo should already have appropriate guards in
|
# anyway. TorchDynamo should already have appropriate guards in
|
||||||
# place so that this doesn't change the compilation result.
|
# place so that this doesn't change the compilation result.
|
||||||
|
@ -96,7 +96,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend):
|
||||||
|
|
||||||
def compiled(*args):
|
def compiled(*args):
|
||||||
env = {}
|
env = {}
|
||||||
for placeholder, arg in zip([n for n in g.nodes if n.op == "placeholder"], args):
|
for placeholder, arg in zip(
|
||||||
|
[n for n in g.nodes if n.op == "placeholder"], args
|
||||||
|
):
|
||||||
env[placeholder] = arg
|
env[placeholder] = arg
|
||||||
# Evaluate the graph one node at a time, comparing the user and
|
# Evaluate the graph one node at a time, comparing the user and
|
||||||
# golden backends. This code currently does not support
|
# golden backends. This code currently does not support
|
||||||
|
@ -111,7 +113,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend):
|
||||||
continue
|
continue
|
||||||
if node.op == "output":
|
if node.op == "output":
|
||||||
return torch.fx.map_arg(node.args[0], lambda n: env[n])
|
return torch.fx.map_arg(node.args[0], lambda n: env[n])
|
||||||
assert node.op == "call_function", f"call_module/call_method not supported for {node} -- perhaps call make_simple_dynamo_backend first"
|
assert (
|
||||||
|
node.op == "call_function"
|
||||||
|
), f"call_module/call_method not supported for {node} -- perhaps call make_simple_dynamo_backend first"
|
||||||
assert not node.kwargs, "kwargs not supported yet"
|
assert not node.kwargs, "kwargs not supported yet"
|
||||||
actual_args = torch.fx.map_arg(node.args, lambda n: env[n])
|
actual_args = torch.fx.map_arg(node.args, lambda n: env[n])
|
||||||
if node not in backend_artifacts:
|
if node not in backend_artifacts:
|
||||||
|
@ -128,7 +132,8 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend):
|
||||||
assert torch.allclose(user_result, golden_result), (
|
assert torch.allclose(user_result, golden_result), (
|
||||||
f"User result {user_result} is not close to "
|
f"User result {user_result} is not close to "
|
||||||
f"golden result {golden_result} for "
|
f"golden result {golden_result} for "
|
||||||
f"node {node} at {node.stack_trace}")
|
f"node {node} at {node.stack_trace}"
|
||||||
|
)
|
||||||
# Clean up any tensors that are no longer needed.
|
# Clean up any tensors that are no longer needed.
|
||||||
# TODO: Find a way to test this.
|
# TODO: Find a way to test this.
|
||||||
# This was tested manually by printing the number of entries
|
# This was tested manually by printing the number of entries
|
||||||
|
@ -136,6 +141,9 @@ def make_lockstep_debug_backend(golden_backend=_identity_backend):
|
||||||
for dead_node in last_use_map[node]:
|
for dead_node in last_use_map[node]:
|
||||||
env.pop(dead_node)
|
env.pop(dead_node)
|
||||||
assert False, "not reached -- missing 'output' node"
|
assert False, "not reached -- missing 'output' node"
|
||||||
|
|
||||||
return compiled
|
return compiled
|
||||||
|
|
||||||
return backend
|
return backend
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
|
@ -30,6 +30,7 @@ import traceback
|
||||||
|
|
||||||
import multiprocess as mp
|
import multiprocess as mp
|
||||||
from multiprocess import set_start_method
|
from multiprocess import set_start_method
|
||||||
|
|
||||||
try:
|
try:
|
||||||
set_start_method("spawn")
|
set_start_method("spawn")
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
@ -38,9 +39,13 @@ except RuntimeError:
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
TorchScriptValue = Union[int, float, List['TorchScriptValue'],
|
TorchScriptValue = Union[
|
||||||
Dict['TorchScriptValue',
|
int,
|
||||||
'TorchScriptValue'], torch.Tensor]
|
float,
|
||||||
|
List["TorchScriptValue"],
|
||||||
|
Dict["TorchScriptValue", "TorchScriptValue"],
|
||||||
|
torch.Tensor,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class TraceItem(NamedTuple):
|
class TraceItem(NamedTuple):
|
||||||
|
@ -91,9 +96,11 @@ def clone_torch_script_value(v: TorchScriptValue):
|
||||||
# TODO: Figure out the root cause of the failure and fix properly.
|
# TODO: Figure out the root cause of the failure and fix properly.
|
||||||
def clone_trace(trace: Trace) -> Trace:
|
def clone_trace(trace: Trace) -> Trace:
|
||||||
return [
|
return [
|
||||||
TraceItem(symbol=item.symbol,
|
TraceItem(
|
||||||
|
symbol=item.symbol,
|
||||||
inputs=clone_torch_script_value(item.inputs),
|
inputs=clone_torch_script_value(item.inputs),
|
||||||
output=clone_torch_script_value(item.output))
|
output=clone_torch_script_value(item.output),
|
||||||
|
)
|
||||||
for item in trace
|
for item in trace
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -101,7 +108,8 @@ def clone_trace(trace: Trace) -> Trace:
|
||||||
# A type shared between the result of `TestConfig.compile` and the input
|
# A type shared between the result of `TestConfig.compile` and the input
|
||||||
# to `TestConfig.run`. Each backend will likely have a different definition of
|
# to `TestConfig.run`. Each backend will likely have a different definition of
|
||||||
# this type.
|
# this type.
|
||||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
CompiledArtifact = TypeVar("CompiledArtifact")
|
||||||
|
|
||||||
|
|
||||||
class TestConfig(abc.ABC):
|
class TestConfig(abc.ABC):
|
||||||
"""The interface implemented by backends to run tests.
|
"""The interface implemented by backends to run tests.
|
||||||
|
@ -136,6 +144,7 @@ class TestConfig(abc.ABC):
|
||||||
backend (compiler backend and runtime target) will have an arbitrarily
|
backend (compiler backend and runtime target) will have an arbitrarily
|
||||||
wild and wonderful set of possible configurations that we cannot predict.
|
wild and wonderful set of possible configurations that we cannot predict.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# This is not a frontend-lowered module, to allow various testing at the PyTorch level.
|
# This is not a frontend-lowered module, to allow various testing at the PyTorch level.
|
||||||
# We can have a helper class LinalgOnTensorsBackendTestConfig which does that.
|
# We can have a helper class LinalgOnTensorsBackendTestConfig which does that.
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
@ -202,8 +211,8 @@ class TestUtils:
|
||||||
|
|
||||||
|
|
||||||
class Test(NamedTuple):
|
class Test(NamedTuple):
|
||||||
"""A description of a test as produced by the test frontend.
|
"""A description of a test as produced by the test frontend."""
|
||||||
"""
|
|
||||||
# Stable name for error reporting.
|
# Stable name for error reporting.
|
||||||
#
|
#
|
||||||
# This name's stability is also useful for backend, which want to
|
# This name's stability is also useful for backend, which want to
|
||||||
|
@ -268,14 +277,20 @@ class _Tracer:
|
||||||
inputs = [clone_torch_script_value(arg) for arg in args]
|
inputs = [clone_torch_script_value(arg) for arg in args]
|
||||||
output = self.__wrapped__(*args, **kwargs)
|
output = self.__wrapped__(*args, **kwargs)
|
||||||
self.__trace__.append(
|
self.__trace__.append(
|
||||||
TraceItem(symbol=".".join(self.__property_base_path__),
|
TraceItem(
|
||||||
|
symbol=".".join(self.__property_base_path__),
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
output=output))
|
output=output,
|
||||||
|
)
|
||||||
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return _Tracer(getattr(self.__wrapped__, name),
|
return _Tracer(
|
||||||
self.__property_base_path__ + [name], self.__trace__)
|
getattr(self.__wrapped__, name),
|
||||||
|
self.__property_base_path__ + [name],
|
||||||
|
self.__trace__,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_golden_trace(test: Test) -> Trace:
|
def generate_golden_trace(test: Test) -> Trace:
|
||||||
|
@ -297,40 +312,49 @@ def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any:
|
||||||
print(f"Compiling {test.unique_name}...", file=sys.stderr)
|
print(f"Compiling {test.unique_name}...", file=sys.stderr)
|
||||||
compiled = config.compile(test.program_factory())
|
compiled = config.compile(test.program_factory())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return TestResult(unique_name=test.unique_name,
|
return TestResult(
|
||||||
|
unique_name=test.unique_name,
|
||||||
compilation_error="".join(
|
compilation_error="".join(
|
||||||
traceback.format_exception(
|
traceback.format_exception(type(e), e, e.__traceback__)
|
||||||
type(e), e, e.__traceback__)),
|
),
|
||||||
runtime_error=None,
|
runtime_error=None,
|
||||||
trace=None,
|
trace=None,
|
||||||
golden_trace=None)
|
golden_trace=None,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Running {test.unique_name}...", file=sys.stderr)
|
print(f"Running {test.unique_name}...", file=sys.stderr)
|
||||||
trace = config.run(compiled, golden_trace)
|
trace = config.run(compiled, golden_trace)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return TestResult(unique_name=test.unique_name,
|
return TestResult(
|
||||||
|
unique_name=test.unique_name,
|
||||||
compilation_error=None,
|
compilation_error=None,
|
||||||
runtime_error="".join(
|
runtime_error="".join(
|
||||||
traceback.format_exception(
|
traceback.format_exception(type(e), e, e.__traceback__)
|
||||||
type(e), e, e.__traceback__)),
|
),
|
||||||
trace=None,
|
trace=None,
|
||||||
golden_trace=None)
|
golden_trace=None,
|
||||||
return TestResult(unique_name=test.unique_name,
|
)
|
||||||
|
return TestResult(
|
||||||
|
unique_name=test.unique_name,
|
||||||
compilation_error=None,
|
compilation_error=None,
|
||||||
runtime_error=None,
|
runtime_error=None,
|
||||||
trace=clone_trace(trace),
|
trace=clone_trace(trace),
|
||||||
golden_trace=clone_trace(golden_trace))
|
golden_trace=clone_trace(golden_trace),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=False) -> List[TestResult]:
|
def run_tests(
|
||||||
|
tests: List[Test], config: TestConfig, sequential=False, verbose=False
|
||||||
|
) -> List[TestResult]:
|
||||||
"""Invoke the given `Test`'s with the provided `TestConfig`."""
|
"""Invoke the given `Test`'s with the provided `TestConfig`."""
|
||||||
num_processes = min(int(mp.cpu_count() * 0.8) + 1, len(tests))
|
num_processes = min(int(mp.cpu_count() * 0.8) + 1, len(tests))
|
||||||
try:
|
try:
|
||||||
env_concurrency = int(os.getenv("TORCH_MLIR_TEST_CONCURRENCY", "0"))
|
env_concurrency = int(os.getenv("TORCH_MLIR_TEST_CONCURRENCY", "0"))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValueError("Bad value for TORCH_MLIR_TEST_CONCURRENCY env var: "
|
raise ValueError(
|
||||||
"Expected integer.") from e
|
"Bad value for TORCH_MLIR_TEST_CONCURRENCY env var: " "Expected integer."
|
||||||
|
) from e
|
||||||
if env_concurrency > 0:
|
if env_concurrency > 0:
|
||||||
num_processes = min(num_processes, env_concurrency)
|
num_processes = min(num_processes, env_concurrency)
|
||||||
|
|
||||||
|
@ -374,10 +398,11 @@ def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=F
|
||||||
TestResult(
|
TestResult(
|
||||||
unique_name=aborted_test_name,
|
unique_name=aborted_test_name,
|
||||||
compilation_error=None,
|
compilation_error=None,
|
||||||
runtime_error=
|
runtime_error="Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n",
|
||||||
"Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n",
|
|
||||||
trace=None,
|
trace=None,
|
||||||
golden_trace=None) for aborted_test_name in aborted_tests
|
golden_trace=None,
|
||||||
|
)
|
||||||
|
for aborted_test_name in aborted_tests
|
||||||
]
|
]
|
||||||
results.extend(aborted_tests_results)
|
results.extend(aborted_tests_results)
|
||||||
results.sort(key=lambda result: result.unique_name)
|
results.sort(key=lambda result: result.unique_name)
|
||||||
|
|
|
@ -13,12 +13,12 @@ from torch_mlir.ir import Module
|
||||||
# A type shared between the result of `LinalgOnTensorsBackend.compile` and the
|
# A type shared between the result of `LinalgOnTensorsBackend.compile` and the
|
||||||
# input to `LinalgOnTensorsBackend.load`. Each backend will likely have a
|
# input to `LinalgOnTensorsBackend.load`. Each backend will likely have a
|
||||||
# different definition of this type.
|
# different definition of this type.
|
||||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
CompiledArtifact = TypeVar("CompiledArtifact")
|
||||||
|
|
||||||
# A wrapper around a backend-specific loaded program representation
|
# A wrapper around a backend-specific loaded program representation
|
||||||
# that uniformly translates the `x.method(...)` interface expected of
|
# that uniformly translates the `x.method(...)` interface expected of
|
||||||
# Torch modules into appropriate lower-level operations.
|
# Torch modules into appropriate lower-level operations.
|
||||||
Invoker = TypeVar('Invoker')
|
Invoker = TypeVar("Invoker")
|
||||||
|
|
||||||
|
|
||||||
class LinalgOnTensorsBackend(abc.ABC):
|
class LinalgOnTensorsBackend(abc.ABC):
|
||||||
|
@ -27,6 +27,7 @@ class LinalgOnTensorsBackend(abc.ABC):
|
||||||
Backends are recommended to raise meaningful exceptions in case of error,
|
Backends are recommended to raise meaningful exceptions in case of error,
|
||||||
ideally with easy reproduction instructions.
|
ideally with easy reproduction instructions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def compile(self, module: Module) -> CompiledArtifact:
|
def compile(self, module: Module) -> CompiledArtifact:
|
||||||
"""Compile the provided MLIR module into a compiled artifact.
|
"""Compile the provided MLIR module into a compiled artifact.
|
||||||
|
|
|
@ -22,10 +22,20 @@ __all__ = [
|
||||||
|
|
||||||
def assert_arg_type_is_supported(ty):
|
def assert_arg_type_is_supported(ty):
|
||||||
SUPPORTED = [
|
SUPPORTED = [
|
||||||
np.float16, np.float32, np.float64, np.uint8, np.int8, np.int32,
|
np.float16,
|
||||||
np.int64, np.bool_, np.complex64, np.complex128
|
np.float32,
|
||||||
|
np.float64,
|
||||||
|
np.uint8,
|
||||||
|
np.int8,
|
||||||
|
np.int32,
|
||||||
|
np.int64,
|
||||||
|
np.bool_,
|
||||||
|
np.complex64,
|
||||||
|
np.complex128,
|
||||||
]
|
]
|
||||||
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}"
|
assert (
|
||||||
|
ty in SUPPORTED
|
||||||
|
), f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}"
|
||||||
|
|
||||||
|
|
||||||
memref_type_to_np_dtype = {
|
memref_type_to_np_dtype = {
|
||||||
|
@ -37,14 +47,14 @@ memref_type_to_np_dtype = {
|
||||||
"mri32": np.int32,
|
"mri32": np.int32,
|
||||||
"mri64": np.int64,
|
"mri64": np.int64,
|
||||||
"mrc32": np.complex64,
|
"mrc32": np.complex64,
|
||||||
"mrc64": np.complex128
|
"mrc64": np.complex128,
|
||||||
}
|
}
|
||||||
elemental_type_to_ctype = {
|
elemental_type_to_ctype = {
|
||||||
"i1": ctypes.c_bool,
|
"i1": ctypes.c_bool,
|
||||||
"i8": ctypes.c_byte,
|
"i8": ctypes.c_byte,
|
||||||
"i64": ctypes.c_int,
|
"i64": ctypes.c_int,
|
||||||
"f32": ctypes.c_float,
|
"f32": ctypes.c_float,
|
||||||
"f64": ctypes.c_double
|
"f64": ctypes.c_double,
|
||||||
}
|
}
|
||||||
|
|
||||||
CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_"
|
CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_"
|
||||||
|
@ -56,7 +66,7 @@ def get_return_funcs(module):
|
||||||
with module.context:
|
with module.context:
|
||||||
for func in module.body:
|
for func in module.body:
|
||||||
# Returns strings of the form `"refbackend.."` so `"` is deleted.
|
# Returns strings of the form `"refbackend.."` so `"` is deleted.
|
||||||
func_name = str(func.attributes["sym_name"]).replace('"', '')
|
func_name = str(func.attributes["sym_name"]).replace('"', "")
|
||||||
if func_name[:return_prefix_len] == CONSUME_RETURN_FUNC_PREFIX:
|
if func_name[:return_prefix_len] == CONSUME_RETURN_FUNC_PREFIX:
|
||||||
return_funcs.append(func_name)
|
return_funcs.append(func_name)
|
||||||
|
|
||||||
|
@ -79,7 +89,6 @@ def get_ctype_func(func_name):
|
||||||
|
|
||||||
|
|
||||||
class RefBackendInvoker:
|
class RefBackendInvoker:
|
||||||
|
|
||||||
def __init__(self, module):
|
def __init__(self, module):
|
||||||
self.ee = ExecutionEngine(module)
|
self.ee = ExecutionEngine(module)
|
||||||
self.result = None
|
self.result = None
|
||||||
|
@ -90,27 +99,29 @@ class RefBackendInvoker:
|
||||||
ctype_wrapper, ret_types = get_ctype_func(ret_func)
|
ctype_wrapper, ret_types = get_ctype_func(ret_func)
|
||||||
|
|
||||||
def consume_return_funcs(*args):
|
def consume_return_funcs(*args):
|
||||||
self.result = tuple([
|
self.result = tuple(
|
||||||
arg if type in elemental_type_to_ctype
|
[
|
||||||
|
arg
|
||||||
|
if type in elemental_type_to_ctype
|
||||||
else unranked_memref_to_numpy(
|
else unranked_memref_to_numpy(
|
||||||
arg, memref_type_to_np_dtype[type])
|
arg, memref_type_to_np_dtype[type]
|
||||||
|
)
|
||||||
for arg, type in zip(args, ret_types)
|
for arg, type in zip(args, ret_types)
|
||||||
])
|
]
|
||||||
|
)
|
||||||
if len(self.result) == 1:
|
if len(self.result) == 1:
|
||||||
self.result = self.result[0]
|
self.result = self.result[0]
|
||||||
|
|
||||||
self.ee.register_runtime(ret_func,
|
self.ee.register_runtime(ret_func, ctype_wrapper(consume_return_funcs))
|
||||||
ctype_wrapper(consume_return_funcs))
|
|
||||||
|
|
||||||
def __getattr__(self, function_name: str):
|
def __getattr__(self, function_name: str):
|
||||||
|
|
||||||
def invoke(*args):
|
def invoke(*args):
|
||||||
ffi_args = []
|
ffi_args = []
|
||||||
for arg in args:
|
for arg in args:
|
||||||
assert_arg_type_is_supported(arg.dtype)
|
assert_arg_type_is_supported(arg.dtype)
|
||||||
ffi_args.append(
|
ffi_args.append(
|
||||||
ctypes.pointer(
|
ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(arg)))
|
||||||
ctypes.pointer(get_unranked_memref_descriptor(arg))))
|
)
|
||||||
|
|
||||||
self.ee.invoke(function_name, *ffi_args)
|
self.ee.invoke(function_name, *ffi_args)
|
||||||
result = self.result
|
result = self.result
|
||||||
|
@ -121,7 +132,10 @@ class RefBackendInvoker:
|
||||||
return invoke
|
return invoke
|
||||||
|
|
||||||
|
|
||||||
LOWERING_PIPELINE = "builtin.module(" + ",".join([
|
LOWERING_PIPELINE = (
|
||||||
|
"builtin.module("
|
||||||
|
+ ",".join(
|
||||||
|
[
|
||||||
"func.func(refback-generalize-tensor-pad)",
|
"func.func(refback-generalize-tensor-pad)",
|
||||||
"func.func(refback-generalize-tensor-concat)",
|
"func.func(refback-generalize-tensor-concat)",
|
||||||
# Apply some optimizations. It would be great if MLIR had more useful
|
# Apply some optimizations. It would be great if MLIR had more useful
|
||||||
|
@ -181,7 +195,10 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
|
||||||
"convert-cf-to-llvm",
|
"convert-cf-to-llvm",
|
||||||
"convert-complex-to-llvm",
|
"convert-complex-to-llvm",
|
||||||
"reconcile-unrealized-casts",
|
"reconcile-unrealized-casts",
|
||||||
]) + ")"
|
]
|
||||||
|
)
|
||||||
|
+ ")"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
|
class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
|
||||||
|
@ -204,7 +221,8 @@ class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
|
||||||
passed to `load`.
|
passed to `load`.
|
||||||
"""
|
"""
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
imported_module, LOWERING_PIPELINE,
|
imported_module,
|
||||||
|
LOWERING_PIPELINE,
|
||||||
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend",
|
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend",
|
||||||
enable_ir_printing=False,
|
enable_ir_printing=False,
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,12 +13,12 @@ from torch_mlir.ir import Module
|
||||||
# A type shared between the result of `OnnxBackend.compile` and the
|
# A type shared between the result of `OnnxBackend.compile` and the
|
||||||
# input to `OnnxBackend.load`. Each backend will likely have a
|
# input to `OnnxBackend.load`. Each backend will likely have a
|
||||||
# different definition of this type.
|
# different definition of this type.
|
||||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
CompiledArtifact = TypeVar("CompiledArtifact")
|
||||||
|
|
||||||
# A wrapper around a backend-specific loaded program representation
|
# A wrapper around a backend-specific loaded program representation
|
||||||
# that uniformly translates the `x.method(...)` interface expected of
|
# that uniformly translates the `x.method(...)` interface expected of
|
||||||
# Torch modules into appropriate lower-level operations.
|
# Torch modules into appropriate lower-level operations.
|
||||||
Invoker = TypeVar('Invoker')
|
Invoker = TypeVar("Invoker")
|
||||||
|
|
||||||
|
|
||||||
class OnnxBackend(abc.ABC):
|
class OnnxBackend(abc.ABC):
|
||||||
|
@ -27,6 +27,7 @@ class OnnxBackend(abc.ABC):
|
||||||
Backends are recommended to raise meaningful exceptions in case of error,
|
Backends are recommended to raise meaningful exceptions in case of error,
|
||||||
ideally with easy reproduction instructions.
|
ideally with easy reproduction instructions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def compile(self, module: Module) -> CompiledArtifact:
|
def compile(self, module: Module) -> CompiledArtifact:
|
||||||
"""Compile the provided MLIR module into a compiled artifact.
|
"""Compile the provided MLIR module into a compiled artifact.
|
||||||
|
|
|
@ -12,7 +12,9 @@ from torch_mlir.compiler_utils import (
|
||||||
from torch_mlir.ir import *
|
from torch_mlir.ir import *
|
||||||
from torch_mlir.passmanager import *
|
from torch_mlir.passmanager import *
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
||||||
|
RefBackendLinalgOnTensorsBackend,
|
||||||
|
)
|
||||||
|
|
||||||
from .abc import OnnxBackend
|
from .abc import OnnxBackend
|
||||||
|
|
||||||
|
@ -22,9 +24,11 @@ __all__ = [
|
||||||
|
|
||||||
# The pipeline of func.func passes that lower the ONNX backend contract to the
|
# The pipeline of func.func passes that lower the ONNX backend contract to the
|
||||||
# Linalg-on-Tensors backend contract accepted by RefBackend.
|
# Linalg-on-Tensors backend contract accepted by RefBackend.
|
||||||
ONNX_TO_TORCH_FUNC_PIPELINE = ",".join([
|
ONNX_TO_TORCH_FUNC_PIPELINE = ",".join(
|
||||||
|
[
|
||||||
"convert-torch-onnx-to-torch",
|
"convert-torch-onnx-to-torch",
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LinalgOnTensorsOnnxBackend(OnnxBackend):
|
class LinalgOnTensorsOnnxBackend(OnnxBackend):
|
||||||
|
@ -50,9 +54,14 @@ class LinalgOnTensorsOnnxBackend(OnnxBackend):
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
imported_module,
|
imported_module,
|
||||||
f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
|
f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
|
||||||
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract")
|
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract",
|
||||||
|
)
|
||||||
|
|
||||||
backend_legal_ops = ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int']
|
backend_legal_ops = [
|
||||||
|
"aten.flatten.using_ints",
|
||||||
|
"aten.adaptive_avg_pool1d",
|
||||||
|
"aten.unflatten.int",
|
||||||
|
]
|
||||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
|
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
imported_module,
|
imported_module,
|
||||||
|
@ -60,7 +69,9 @@ class LinalgOnTensorsOnnxBackend(OnnxBackend):
|
||||||
"Lowering TorchFX IR -> Torch Backend IR",
|
"Lowering TorchFX IR -> Torch Backend IR",
|
||||||
)
|
)
|
||||||
|
|
||||||
imported_module = lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
|
imported_module = lower_mlir_module(
|
||||||
|
False, OutputType.LINALG_ON_TENSORS, imported_module
|
||||||
|
)
|
||||||
compiled_module = self.refbackend.compile(imported_module)
|
compiled_module = self.refbackend.compile(imported_module)
|
||||||
return compiled_module
|
return compiled_module
|
||||||
|
|
||||||
|
|
|
@ -23,18 +23,23 @@ def register_test_case(module_factory: Callable[[], torch.nn.Module]):
|
||||||
test's `program_factory` is taken from `module_factory`, and the
|
test's `program_factory` is taken from `module_factory`, and the
|
||||||
`program_invoker` is the decorated function.
|
`program_invoker` is the decorated function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(f):
|
def decorator(f):
|
||||||
# Ensure that there are no duplicate names in the global test registry.
|
# Ensure that there are no duplicate names in the global test registry.
|
||||||
if f.__name__ in _SEEN_UNIQUE_NAMES:
|
if f.__name__ in _SEEN_UNIQUE_NAMES:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Duplicate test name: '{f.__name__}'. Please make sure that the function wrapped by `register_test_case` has a unique name.")
|
f"Duplicate test name: '{f.__name__}'. Please make sure that the function wrapped by `register_test_case` has a unique name."
|
||||||
|
)
|
||||||
_SEEN_UNIQUE_NAMES.add(f.__name__)
|
_SEEN_UNIQUE_NAMES.add(f.__name__)
|
||||||
|
|
||||||
# Store the test in the registry.
|
# Store the test in the registry.
|
||||||
GLOBAL_TEST_REGISTRY.append(
|
GLOBAL_TEST_REGISTRY.append(
|
||||||
Test(unique_name=f.__name__,
|
Test(
|
||||||
|
unique_name=f.__name__,
|
||||||
program_factory=module_factory,
|
program_factory=module_factory,
|
||||||
program_invoker=f))
|
program_invoker=f,
|
||||||
|
)
|
||||||
|
)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
|
@ -19,6 +19,7 @@ from .framework import TestResult, TraceItem
|
||||||
|
|
||||||
class TensorSummary:
|
class TensorSummary:
|
||||||
"""A summary of a tensor's contents."""
|
"""A summary of a tensor's contents."""
|
||||||
|
|
||||||
def __init__(self, tensor):
|
def __init__(self, tensor):
|
||||||
self.min = torch.min(tensor.type(torch.float64))
|
self.min = torch.min(tensor.type(torch.float64))
|
||||||
self.max = torch.max(tensor.type(torch.float64))
|
self.max = torch.max(tensor.type(torch.float64))
|
||||||
|
@ -27,7 +28,7 @@ class TensorSummary:
|
||||||
self.dtype = tensor.dtype
|
self.dtype = tensor.dtype
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'Tensor with shape={self.shape}, dtype={self.dtype}, min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}'
|
return f"Tensor with shape={self.shape}, dtype={self.dtype}, min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}"
|
||||||
|
|
||||||
|
|
||||||
class ErrorContext:
|
class ErrorContext:
|
||||||
|
@ -35,6 +36,7 @@ class ErrorContext:
|
||||||
|
|
||||||
This is useful for tracking errors across multiple levels of detail.
|
This is useful for tracking errors across multiple levels of detail.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, contexts: List[str]):
|
def __init__(self, contexts: List[str]):
|
||||||
self.contexts = contexts
|
self.contexts = contexts
|
||||||
|
|
||||||
|
@ -47,17 +49,16 @@ class ErrorContext:
|
||||||
return ErrorContext([])
|
return ErrorContext([])
|
||||||
|
|
||||||
def chain(self, additional_context: str):
|
def chain(self, additional_context: str):
|
||||||
"""Chain an additional context onto the current error context.
|
"""Chain an additional context onto the current error context."""
|
||||||
"""
|
|
||||||
return ErrorContext(self.contexts + [additional_context])
|
return ErrorContext(self.contexts + [additional_context])
|
||||||
|
|
||||||
def format_error(self, s: str):
|
def format_error(self, s: str):
|
||||||
return '@ ' + '\n@ '.join(self.contexts) + '\n' + 'ERROR: ' + s
|
return "@ " + "\n@ ".join(self.contexts) + "\n" + "ERROR: " + s
|
||||||
|
|
||||||
|
|
||||||
class ValueReport:
|
class ValueReport:
|
||||||
"""A report for a single value processed by the program.
|
"""A report for a single value processed by the program."""
|
||||||
"""
|
|
||||||
def __init__(self, value, golden_value, context: ErrorContext):
|
def __init__(self, value, golden_value, context: ErrorContext):
|
||||||
self.value = value
|
self.value = value
|
||||||
self.golden_value = golden_value
|
self.golden_value = golden_value
|
||||||
|
@ -70,7 +71,7 @@ class ValueReport:
|
||||||
return len(self.failure_reasons) != 0
|
return len(self.failure_reasons) != 0
|
||||||
|
|
||||||
def error_str(self):
|
def error_str(self):
|
||||||
return '\n'.join(self.failure_reasons)
|
return "\n".join(self.failure_reasons)
|
||||||
|
|
||||||
def _evaluate_outcome(self):
|
def _evaluate_outcome(self):
|
||||||
value, golden = self.value, self.golden_value
|
value, golden = self.value, self.golden_value
|
||||||
|
@ -80,37 +81,37 @@ class ValueReport:
|
||||||
golden = golden[0]
|
golden = golden[0]
|
||||||
if isinstance(golden, float):
|
if isinstance(golden, float):
|
||||||
if not isinstance(value, float):
|
if not isinstance(value, float):
|
||||||
return self._record_mismatch_type_failure('float', value)
|
return self._record_mismatch_type_failure("float", value)
|
||||||
if abs(value - golden) / golden > 1e-4:
|
if abs(value - golden) / golden > 1e-4:
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'value ({value!r}) is not close to golden value ({golden!r})'
|
f"value ({value!r}) is not close to golden value ({golden!r})"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if isinstance(golden, int):
|
if isinstance(golden, int):
|
||||||
if not isinstance(value, int):
|
if not isinstance(value, int):
|
||||||
return self._record_mismatch_type_failure('int', value)
|
return self._record_mismatch_type_failure("int", value)
|
||||||
if value != golden:
|
if value != golden:
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'value ({value!r}) is not equal to golden value ({golden!r})'
|
f"value ({value!r}) is not equal to golden value ({golden!r})"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if isinstance(golden, str):
|
if isinstance(golden, str):
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
return self._record_mismatch_type_failure('str', value)
|
return self._record_mismatch_type_failure("str", value)
|
||||||
if value != golden:
|
if value != golden:
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'value ({value!r}) is not equal to golden value ({golden!r})'
|
f"value ({value!r}) is not equal to golden value ({golden!r})"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if isinstance(golden, tuple):
|
if isinstance(golden, tuple):
|
||||||
if not isinstance(value, tuple):
|
if not isinstance(value, tuple):
|
||||||
return self._record_mismatch_type_failure('tuple', value)
|
return self._record_mismatch_type_failure("tuple", value)
|
||||||
if len(value) != len(golden):
|
if len(value) != len(golden):
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'value ({len(value)!r}) is not equal to golden value ({len(golden)!r})'
|
f"value ({len(value)!r}) is not equal to golden value ({len(golden)!r})"
|
||||||
)
|
)
|
||||||
reports = [
|
reports = [
|
||||||
ValueReport(v, g, self.context.chain(f'tuple element {i}'))
|
ValueReport(v, g, self.context.chain(f"tuple element {i}"))
|
||||||
for i, (v, g) in enumerate(zip(value, golden))
|
for i, (v, g) in enumerate(zip(value, golden))
|
||||||
]
|
]
|
||||||
for report in reports:
|
for report in reports:
|
||||||
|
@ -119,13 +120,13 @@ class ValueReport:
|
||||||
return
|
return
|
||||||
if isinstance(golden, list):
|
if isinstance(golden, list):
|
||||||
if not isinstance(value, list):
|
if not isinstance(value, list):
|
||||||
return self._record_mismatch_type_failure('list', value)
|
return self._record_mismatch_type_failure("list", value)
|
||||||
if len(value) != len(golden):
|
if len(value) != len(golden):
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'value ({len(value)!r}) is not equal to golden value ({len(golden)!r})'
|
f"value ({len(value)!r}) is not equal to golden value ({len(golden)!r})"
|
||||||
)
|
)
|
||||||
reports = [
|
reports = [
|
||||||
ValueReport(v, g, self.context.chain(f'list element {i}'))
|
ValueReport(v, g, self.context.chain(f"list element {i}"))
|
||||||
for i, (v, g) in enumerate(zip(value, golden))
|
for i, (v, g) in enumerate(zip(value, golden))
|
||||||
]
|
]
|
||||||
for report in reports:
|
for report in reports:
|
||||||
|
@ -134,16 +135,19 @@ class ValueReport:
|
||||||
return
|
return
|
||||||
if isinstance(golden, dict):
|
if isinstance(golden, dict):
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
return self._record_mismatch_type_failure('dict', value)
|
return self._record_mismatch_type_failure("dict", value)
|
||||||
gkeys = list(sorted(golden.keys()))
|
gkeys = list(sorted(golden.keys()))
|
||||||
vkeys = list(sorted(value.keys()))
|
vkeys = list(sorted(value.keys()))
|
||||||
if gkeys != vkeys:
|
if gkeys != vkeys:
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'dict keys ({vkeys!r}) are not equal to golden keys ({gkeys!r})'
|
f"dict keys ({vkeys!r}) are not equal to golden keys ({gkeys!r})"
|
||||||
)
|
)
|
||||||
reports = [
|
reports = [
|
||||||
ValueReport(value[k], golden[k],
|
ValueReport(
|
||||||
self.context.chain(f'dict element at key {k!r}'))
|
value[k],
|
||||||
|
golden[k],
|
||||||
|
self.context.chain(f"dict element at key {k!r}"),
|
||||||
|
)
|
||||||
for k in gkeys
|
for k in gkeys
|
||||||
]
|
]
|
||||||
for report in reports:
|
for report in reports:
|
||||||
|
@ -152,40 +156,42 @@ class ValueReport:
|
||||||
return
|
return
|
||||||
if isinstance(golden, torch.Tensor):
|
if isinstance(golden, torch.Tensor):
|
||||||
if not isinstance(value, torch.Tensor):
|
if not isinstance(value, torch.Tensor):
|
||||||
return self._record_mismatch_type_failure('torch.Tensor', value)
|
return self._record_mismatch_type_failure("torch.Tensor", value)
|
||||||
|
|
||||||
if value.shape != golden.shape:
|
if value.shape != golden.shape:
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'shape ({value.shape}) is not equal to golden shape ({golden.shape})'
|
f"shape ({value.shape}) is not equal to golden shape ({golden.shape})"
|
||||||
)
|
)
|
||||||
if value.dtype != golden.dtype:
|
if value.dtype != golden.dtype:
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'dtype ({value.dtype}) is not equal to golden dtype ({golden.dtype})'
|
f"dtype ({value.dtype}) is not equal to golden dtype ({golden.dtype})"
|
||||||
)
|
)
|
||||||
if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
|
if not torch.allclose(
|
||||||
|
value, golden, rtol=1e-03, atol=1e-07, equal_nan=True
|
||||||
|
):
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})'
|
f"value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'unexpected golden value of type `{golden.__class__.__name__}`')
|
f"unexpected golden value of type `{golden.__class__.__name__}`"
|
||||||
|
)
|
||||||
|
|
||||||
def _record_failure(self, s: str):
|
def _record_failure(self, s: str):
|
||||||
self.failure_reasons.append(self.context.format_error(s))
|
self.failure_reasons.append(self.context.format_error(s))
|
||||||
|
|
||||||
def _record_mismatch_type_failure(self, expected: str, actual: Any):
|
def _record_mismatch_type_failure(self, expected: str, actual: Any):
|
||||||
self._record_failure(
|
self._record_failure(
|
||||||
f'expected a value of type `{expected}` but got `{actual.__class__.__name__}`'
|
f"expected a value of type `{expected}` but got `{actual.__class__.__name__}`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TraceItemReport:
|
class TraceItemReport:
|
||||||
"""A report for a single trace item."""
|
"""A report for a single trace item."""
|
||||||
|
|
||||||
failure_reasons: List[str]
|
failure_reasons: List[str]
|
||||||
|
|
||||||
def __init__(self, item: TraceItem, golden_item: TraceItem,
|
def __init__(self, item: TraceItem, golden_item: TraceItem, context: ErrorContext):
|
||||||
context: ErrorContext):
|
|
||||||
self.item = item
|
self.item = item
|
||||||
self.golden_item = golden_item
|
self.golden_item = golden_item
|
||||||
self.context = context
|
self.context = context
|
||||||
|
@ -197,36 +203,43 @@ class TraceItemReport:
|
||||||
return len(self.failure_reasons) != 0
|
return len(self.failure_reasons) != 0
|
||||||
|
|
||||||
def error_str(self):
|
def error_str(self):
|
||||||
return '\n'.join(self.failure_reasons)
|
return "\n".join(self.failure_reasons)
|
||||||
|
|
||||||
def _evaluate_outcome(self):
|
def _evaluate_outcome(self):
|
||||||
if self.item.symbol != self.golden_item.symbol:
|
if self.item.symbol != self.golden_item.symbol:
|
||||||
self.failure_reasons.append(
|
self.failure_reasons.append(
|
||||||
self.context.format_error(
|
self.context.format_error(
|
||||||
f'not invoking the same symbol: got "{self.item.symbol}", expected "{self.golden_item.symbol}"'
|
f'not invoking the same symbol: got "{self.item.symbol}", expected "{self.golden_item.symbol}"'
|
||||||
))
|
)
|
||||||
|
)
|
||||||
if len(self.item.inputs) != len(self.golden_item.inputs):
|
if len(self.item.inputs) != len(self.golden_item.inputs):
|
||||||
self.failure_reasons.append(
|
self.failure_reasons.append(
|
||||||
self.context.format_error(
|
self.context.format_error(
|
||||||
f'different number of inputs: got "{len(self.item.inputs)}", expected "{len(self.golden_item.inputs)}"'
|
f'different number of inputs: got "{len(self.item.inputs)}", expected "{len(self.golden_item.inputs)}"'
|
||||||
))
|
)
|
||||||
|
)
|
||||||
for i, (input, golden_input) in enumerate(
|
for i, (input, golden_input) in enumerate(
|
||||||
zip(self.item.inputs, self.golden_item.inputs)):
|
zip(self.item.inputs, self.golden_item.inputs)
|
||||||
|
):
|
||||||
value_report = ValueReport(
|
value_report = ValueReport(
|
||||||
input, golden_input,
|
input,
|
||||||
self.context.chain(
|
golden_input,
|
||||||
f'input #{i} of call to "{self.item.symbol}"'))
|
self.context.chain(f'input #{i} of call to "{self.item.symbol}"'),
|
||||||
|
)
|
||||||
if value_report.failed:
|
if value_report.failed:
|
||||||
self.failure_reasons.append(value_report.error_str())
|
self.failure_reasons.append(value_report.error_str())
|
||||||
value_report = ValueReport(
|
value_report = ValueReport(
|
||||||
self.item.output, self.golden_item.output,
|
self.item.output,
|
||||||
self.context.chain(f'output of call to "{self.item.symbol}"'))
|
self.golden_item.output,
|
||||||
|
self.context.chain(f'output of call to "{self.item.symbol}"'),
|
||||||
|
)
|
||||||
if value_report.failed:
|
if value_report.failed:
|
||||||
self.failure_reasons.append(value_report.error_str())
|
self.failure_reasons.append(value_report.error_str())
|
||||||
|
|
||||||
|
|
||||||
class SingleTestReport:
|
class SingleTestReport:
|
||||||
"""A report for a single test."""
|
"""A report for a single test."""
|
||||||
|
|
||||||
item_reports: Optional[List[TraceItemReport]]
|
item_reports: Optional[List[TraceItemReport]]
|
||||||
|
|
||||||
def __init__(self, result: TestResult, context: ErrorContext):
|
def __init__(self, result: TestResult, context: ErrorContext):
|
||||||
|
@ -236,12 +249,15 @@ class SingleTestReport:
|
||||||
if result.compilation_error is None and result.runtime_error is None:
|
if result.compilation_error is None and result.runtime_error is None:
|
||||||
self.item_reports = []
|
self.item_reports = []
|
||||||
for i, (item, golden_item) in enumerate(
|
for i, (item, golden_item) in enumerate(
|
||||||
zip(result.trace, result.golden_trace)):
|
zip(result.trace, result.golden_trace)
|
||||||
|
):
|
||||||
self.item_reports.append(
|
self.item_reports.append(
|
||||||
TraceItemReport(
|
TraceItemReport(
|
||||||
item, golden_item,
|
item,
|
||||||
context.chain(
|
golden_item,
|
||||||
f'trace item #{i} - call to "{item.symbol}"')))
|
context.chain(f'trace item #{i} - call to "{item.symbol}"'),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def failed(self):
|
def failed(self):
|
||||||
|
@ -256,19 +272,21 @@ class SingleTestReport:
|
||||||
f = io.StringIO()
|
f = io.StringIO()
|
||||||
p = lambda *x: print(*x, file=f)
|
p = lambda *x: print(*x, file=f)
|
||||||
if self.result.compilation_error is not None:
|
if self.result.compilation_error is not None:
|
||||||
return 'Compilation error: ' + self.result.compilation_error
|
return "Compilation error: " + self.result.compilation_error
|
||||||
elif self.result.runtime_error is not None:
|
elif self.result.runtime_error is not None:
|
||||||
return 'Runtime error: ' + self.result.runtime_error
|
return "Runtime error: " + self.result.runtime_error
|
||||||
for report in self.item_reports:
|
for report in self.item_reports:
|
||||||
if report.failed:
|
if report.failed:
|
||||||
p(report.error_str())
|
p(report.error_str())
|
||||||
return f.getvalue()
|
return f.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def report_results(results: List[TestResult],
|
def report_results(
|
||||||
|
results: List[TestResult],
|
||||||
expected_failures: Set[str],
|
expected_failures: Set[str],
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
config: str = ""):
|
config: str = "",
|
||||||
|
):
|
||||||
"""Print a basic error report summarizing various TestResult's.
|
"""Print a basic error report summarizing various TestResult's.
|
||||||
|
|
||||||
This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's
|
This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's
|
||||||
|
@ -293,49 +311,50 @@ def report_results(results: List[TestResult],
|
||||||
if expected_failure:
|
if expected_failure:
|
||||||
if report.failed:
|
if report.failed:
|
||||||
print(f'XFAIL - "{result.unique_name}"')
|
print(f'XFAIL - "{result.unique_name}"')
|
||||||
results_by_outcome['XFAIL'].append((result, report))
|
results_by_outcome["XFAIL"].append((result, report))
|
||||||
else:
|
else:
|
||||||
print(f'XPASS - "{result.unique_name}"')
|
print(f'XPASS - "{result.unique_name}"')
|
||||||
results_by_outcome['XPASS'].append((result, report))
|
results_by_outcome["XPASS"].append((result, report))
|
||||||
else:
|
else:
|
||||||
if not report.failed:
|
if not report.failed:
|
||||||
print(f'PASS - "{result.unique_name}"')
|
print(f'PASS - "{result.unique_name}"')
|
||||||
results_by_outcome['PASS'].append((result, report))
|
results_by_outcome["PASS"].append((result, report))
|
||||||
else:
|
else:
|
||||||
print(f'FAIL - "{result.unique_name}"')
|
print(f'FAIL - "{result.unique_name}"')
|
||||||
results_by_outcome['FAIL'].append((result, report))
|
results_by_outcome["FAIL"].append((result, report))
|
||||||
|
|
||||||
OUTCOME_MEANINGS = collections.OrderedDict()
|
OUTCOME_MEANINGS = collections.OrderedDict()
|
||||||
OUTCOME_MEANINGS['PASS'] = 'Passed'
|
OUTCOME_MEANINGS["PASS"] = "Passed"
|
||||||
OUTCOME_MEANINGS['FAIL'] = 'Failed'
|
OUTCOME_MEANINGS["FAIL"] = "Failed"
|
||||||
OUTCOME_MEANINGS['XFAIL'] = 'Expectedly Failed'
|
OUTCOME_MEANINGS["XFAIL"] = "Expectedly Failed"
|
||||||
OUTCOME_MEANINGS['XPASS'] = 'Unexpectedly Passed'
|
OUTCOME_MEANINGS["XPASS"] = "Unexpectedly Passed"
|
||||||
|
|
||||||
had_unexpected_results = len(results_by_outcome['FAIL']) != 0 or len(
|
had_unexpected_results = (
|
||||||
results_by_outcome['XPASS']) != 0
|
len(results_by_outcome["FAIL"]) != 0 or len(results_by_outcome["XPASS"]) != 0
|
||||||
|
)
|
||||||
|
|
||||||
if had_unexpected_results:
|
if had_unexpected_results:
|
||||||
print(f'\nUnexpected outcome summary: ({config})')
|
print(f"\nUnexpected outcome summary: ({config})")
|
||||||
|
|
||||||
# For FAIL and XPASS (unexpected outcomes), print a summary.
|
# For FAIL and XPASS (unexpected outcomes), print a summary.
|
||||||
for outcome, results in results_by_outcome.items():
|
for outcome, results in results_by_outcome.items():
|
||||||
# PASS and XFAIL are "good"/"successful" outcomes.
|
# PASS and XFAIL are "good"/"successful" outcomes.
|
||||||
if outcome == 'PASS' or outcome == 'XFAIL':
|
if outcome == "PASS" or outcome == "XFAIL":
|
||||||
continue
|
continue
|
||||||
# If there is nothing to report, be quiet.
|
# If there is nothing to report, be quiet.
|
||||||
if len(results) == 0:
|
if len(results) == 0:
|
||||||
continue
|
continue
|
||||||
print(f'\n****** {OUTCOME_MEANINGS[outcome]} tests - {len(results)} tests')
|
print(f"\n****** {OUTCOME_MEANINGS[outcome]} tests - {len(results)} tests")
|
||||||
for result, report in results:
|
for result, report in results:
|
||||||
print(f' {outcome} - "{result.unique_name}"')
|
print(f' {outcome} - "{result.unique_name}"')
|
||||||
# If the test failed, print the error message.
|
# If the test failed, print the error message.
|
||||||
if outcome == 'FAIL' and verbose:
|
if outcome == "FAIL" and verbose:
|
||||||
print(textwrap.indent(report.error_str(), ' ' * 8))
|
print(textwrap.indent(report.error_str(), " " * 8))
|
||||||
|
|
||||||
# Print a summary for easy scanning.
|
# Print a summary for easy scanning.
|
||||||
print('\nSummary:')
|
print("\nSummary:")
|
||||||
|
|
||||||
for key in ['PASS', 'FAIL', 'XFAIL', 'XPASS']:
|
for key in ["PASS", "FAIL", "XFAIL", "XPASS"]:
|
||||||
if results_by_outcome[key]:
|
if results_by_outcome[key]:
|
||||||
print(f' {OUTCOME_MEANINGS[key]}: {len(results_by_outcome[key])}')
|
print(f" {OUTCOME_MEANINGS[key]}: {len(results_by_outcome[key])}")
|
||||||
return had_unexpected_results
|
return had_unexpected_results
|
||||||
|
|
|
@ -7,7 +7,9 @@ from torch_mlir.ir import *
|
||||||
from torch_mlir.passmanager import *
|
from torch_mlir.passmanager import *
|
||||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
||||||
|
RefBackendLinalgOnTensorsBackend,
|
||||||
|
)
|
||||||
|
|
||||||
from .abc import StablehloBackend
|
from .abc import StablehloBackend
|
||||||
|
|
||||||
|
@ -17,11 +19,13 @@ __all__ = [
|
||||||
|
|
||||||
# The pipeline of func.func passes that lower the STABLEHLO backend contract to the
|
# The pipeline of func.func passes that lower the STABLEHLO backend contract to the
|
||||||
# Linalg-on-Tensors backend contract accepted by RefBackend.
|
# Linalg-on-Tensors backend contract accepted by RefBackend.
|
||||||
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([
|
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join(
|
||||||
|
[
|
||||||
"func.func(stablehlo-aggressive-simplification)",
|
"func.func(stablehlo-aggressive-simplification)",
|
||||||
"stablehlo-legalize-to-linalg",
|
"stablehlo-legalize-to-linalg",
|
||||||
"canonicalize"
|
"canonicalize",
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LinalgOnTensorsStablehloBackend(StablehloBackend):
|
class LinalgOnTensorsStablehloBackend(StablehloBackend):
|
||||||
|
@ -47,7 +51,8 @@ class LinalgOnTensorsStablehloBackend(StablehloBackend):
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
imported_module,
|
imported_module,
|
||||||
f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})",
|
f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})",
|
||||||
"Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract")
|
"Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract",
|
||||||
|
)
|
||||||
|
|
||||||
return self.refbackend.compile(imported_module)
|
return self.refbackend.compile(imported_module)
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
||||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def register_all_tests():
|
def register_all_tests():
|
||||||
"""Registers all the built-in E2E tests that Torch-MLIR provides."""
|
"""Registers all the built-in E2E tests that Torch-MLIR provides."""
|
||||||
# Side-effecting import statements.
|
# Side-effecting import statements.
|
||||||
|
|
|
@ -17,13 +17,15 @@ class ArangeIntModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(5)
|
return torch.arange(5)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeIntModule())
|
@register_test_case(module_factory=lambda: ArangeIntModule())
|
||||||
def ArangeIntModule_basic(module, tu: TestUtils):
|
def ArangeIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
@ -34,13 +36,15 @@ class ArangeFloatModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(5.0)
|
return torch.arange(5.0)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeFloatModule())
|
@register_test_case(module_factory=lambda: ArangeFloatModule())
|
||||||
def ArangeFloatModule_basic(module, tu: TestUtils):
|
def ArangeFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
@ -51,31 +55,37 @@ class ArangeZeroElementOutputModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(0)
|
return torch.arange(0)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeZeroElementOutputModule())
|
@register_test_case(module_factory=lambda: ArangeZeroElementOutputModule())
|
||||||
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
|
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ArangeStartIntModule(torch.nn.Module):
|
class ArangeStartIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(0, 5)
|
return torch.arange(0, 5)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeStartIntModule())
|
@register_test_case(module_factory=lambda: ArangeStartIntModule())
|
||||||
def ArangeStartIntModule_basic(module, tu: TestUtils):
|
def ArangeStartIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
@ -86,13 +96,15 @@ class ArangeStartFloatModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(0.0, 5.0)
|
return torch.arange(0.0, 5.0)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeStartFloatModule())
|
@register_test_case(module_factory=lambda: ArangeStartFloatModule())
|
||||||
def ArangeStartFloatModule_basic(module, tu: TestUtils):
|
def ArangeStartFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
@ -103,13 +115,15 @@ class ArangeNegativeStartIntModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(-10, 5)
|
return torch.arange(-10, 5)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeNegativeStartIntModule())
|
@register_test_case(module_factory=lambda: ArangeNegativeStartIntModule())
|
||||||
def ArangeNegativeStartIntModule_basic(module, tu: TestUtils):
|
def ArangeNegativeStartIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
@ -120,31 +134,37 @@ class ArangeNegativeStartFloatModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(-1.4, 5.7)
|
return torch.arange(-1.4, 5.7)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeNegativeStartFloatModule())
|
@register_test_case(module_factory=lambda: ArangeNegativeStartFloatModule())
|
||||||
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
|
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ArangeStartStepIntModule(torch.nn.Module):
|
class ArangeStartStepIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(0, 5, 1)
|
return torch.arange(0, 5, 1)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeStartStepIntModule())
|
@register_test_case(module_factory=lambda: ArangeStartStepIntModule())
|
||||||
def ArangeStartStepIntModule_basic(module, tu: TestUtils):
|
def ArangeStartStepIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
@ -155,13 +175,15 @@ class ArangeStartStepFloatModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(-1, 5, 1.3)
|
return torch.arange(-1, 5, 1.3)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeStartStepFloatModule())
|
@register_test_case(module_factory=lambda: ArangeStartStepFloatModule())
|
||||||
def ArangeStartStepFloatModule_basic(module, tu: TestUtils):
|
def ArangeStartStepFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
@ -172,13 +194,15 @@ class ArangeStartNegativeStepIntModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(10, 1, -2)
|
return torch.arange(10, 1, -2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeStartNegativeStepIntModule())
|
@register_test_case(module_factory=lambda: ArangeStartNegativeStepIntModule())
|
||||||
def ArangeStartNegativeStepIntModule_basic(module, tu: TestUtils):
|
def ArangeStartNegativeStepIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
@ -189,31 +213,37 @@ class ArangeStartNegativeStepFloatModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(-1, -15, -3.4)
|
return torch.arange(-1, -15, -3.4)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeStartNegativeStepFloatModule())
|
@register_test_case(module_factory=lambda: ArangeStartNegativeStepFloatModule())
|
||||||
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
|
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ArangeDtypeFloatModule(torch.nn.Module):
|
class ArangeDtypeFloatModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(-1, 15, dtype=torch.float32)
|
return torch.arange(-1, 15, dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeDtypeFloatModule())
|
@register_test_case(module_factory=lambda: ArangeDtypeFloatModule())
|
||||||
def ArangeDtypeFloatModule_basic(module, tu: TestUtils):
|
def ArangeDtypeFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
@ -224,110 +254,137 @@ class ArangeDtypeIntModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(0.2, 5.0, dtype=torch.int64)
|
return torch.arange(0.2, 5.0, dtype=torch.int64)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeDtypeIntModule())
|
@register_test_case(module_factory=lambda: ArangeDtypeIntModule())
|
||||||
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
|
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ArangeFalsePinMemoryModule(torch.nn.Module):
|
class ArangeFalsePinMemoryModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.arange(5.0, dtype=torch.int64, pin_memory=False)
|
return torch.arange(5.0, dtype=torch.int64, pin_memory=False)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
|
@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
|
||||||
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
|
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ArangeStartOutModule(torch.nn.Module):
|
class ArangeStartOutModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([12], torch.int64, True),
|
([12], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.arange(start=0, end=12, out=x)
|
return torch.arange(start=0, end=12, out=x)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeStartOutModule())
|
@register_test_case(module_factory=lambda: ArangeStartOutModule())
|
||||||
def ArangeStartOutModule_basic(module, tu: TestUtils):
|
def ArangeStartOutModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.zeros(12).to(torch.int64))
|
module.forward(torch.zeros(12).to(torch.int64))
|
||||||
|
|
||||||
|
|
||||||
class ArangeStartOutViewModule(torch.nn.Module):
|
class ArangeStartOutViewModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([3, 4], torch.int64, True),
|
([3, 4], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.arange(start=1, end=13, out=x)
|
return torch.arange(start=1, end=13, out=x)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeStartOutViewModule())
|
@register_test_case(module_factory=lambda: ArangeStartOutViewModule())
|
||||||
def ArangeStartOutViewModule_basic(module, tu: TestUtils):
|
def ArangeStartOutViewModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.zeros(3, 4).to(torch.int64))
|
module.forward(torch.zeros(3, 4).to(torch.int64))
|
||||||
|
|
||||||
|
|
||||||
class ArangeStartOutDtypeModule(torch.nn.Module):
|
class ArangeStartOutDtypeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([12], torch.int64, True),
|
([12], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.arange(start=1.1, end=13.1, out=x)
|
return torch.arange(start=1.1, end=13.1, out=x)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
|
@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
|
||||||
def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
|
def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.zeros(12).to(torch.int64))
|
module.forward(torch.zeros(12).to(torch.int64))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class LinspaceModule(torch.nn.Module):
|
class LinspaceModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.linspace(-10.1, 10.1, 10)
|
return torch.linspace(-10.1, 10.1, 10)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: LinspaceModule())
|
@register_test_case(module_factory=lambda: LinspaceModule())
|
||||||
def LinspaceModule_basic(module, tu: TestUtils):
|
def LinspaceModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
class LinspaceDtypeModule(torch.nn.Module):
|
class LinspaceDtypeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.linspace(-10.1, 10.1, 10, dtype=torch.int64)
|
return torch.linspace(-10.1, 10.1, 10, dtype=torch.int64)
|
||||||
|
|
||||||
|
@ -336,47 +393,59 @@ class LinspaceDtypeModule(torch.nn.Module):
|
||||||
def LinspaceDtypeModule_basic(module, tu: TestUtils):
|
def LinspaceDtypeModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
class LinspaceEmptyModule(torch.nn.Module):
|
class LinspaceEmptyModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.linspace(-10.1, 10.1, 0)
|
return torch.linspace(-10.1, 10.1, 0)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: LinspaceEmptyModule())
|
@register_test_case(module_factory=lambda: LinspaceEmptyModule())
|
||||||
def LinspaceEmptyModule_basic(module, tu: TestUtils):
|
def LinspaceEmptyModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
class LinspaceOneSizeModule(torch.nn.Module):
|
class LinspaceOneSizeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.linspace(-10.1, 10.1, 1)
|
return torch.linspace(-10.1, 10.1, 1)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: LinspaceOneSizeModule())
|
@register_test_case(module_factory=lambda: LinspaceOneSizeModule())
|
||||||
def LinspaceOneSizeModule_basic(module, tu: TestUtils):
|
def LinspaceOneSizeModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
class LinspaceTwoSizeModule(torch.nn.Module):
|
class LinspaceTwoSizeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.linspace(-10.1, 10.1, 2)
|
return torch.linspace(-10.1, 10.1, 2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
|
@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
|
||||||
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
|
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
@ -387,12 +456,16 @@ class PrimsIotaModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.ops.prims.iota(77, start=0, step=1, dtype=torch.int64, device='cpu',
|
return torch.ops.prims.iota(
|
||||||
requires_grad=False)
|
77, start=0, step=1, dtype=torch.int64, device="cpu", requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: PrimsIotaModule())
|
@register_test_case(module_factory=lambda: PrimsIotaModule())
|
||||||
def PrimsIotaModule_basic(module, tu: TestUtils):
|
def PrimsIotaModule_basic(module, tu: TestUtils):
|
||||||
|
|
|
@ -13,21 +13,21 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
|
|
||||||
class SoftmaxBackwardModule(torch.nn.Module):
|
class SoftmaxBackwardModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, output):
|
def forward(self, grad_output, output):
|
||||||
return torch.ops.aten._softmax_backward_data(grad_output,
|
return torch.ops.aten._softmax_backward_data(
|
||||||
output,
|
grad_output, output, dim=1, input_dtype=6
|
||||||
dim=1,
|
)
|
||||||
input_dtype=6)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: SoftmaxBackwardModule())
|
@register_test_case(module_factory=lambda: SoftmaxBackwardModule())
|
||||||
|
@ -37,16 +37,17 @@ def SoftmaxBackwardModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
class TanhBackwardModule(torch.nn.Module):
|
class TanhBackwardModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_out, output):
|
def forward(self, grad_out, output):
|
||||||
return torch.ops.aten.tanh_backward(grad_out, output)
|
return torch.ops.aten.tanh_backward(grad_out, output)
|
||||||
|
|
||||||
|
@ -58,40 +59,46 @@ def TanhBackward_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class HardtanhBackwardModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class HardtanhBackwardModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_out, input):
|
def forward(self, grad_out, input):
|
||||||
return torch.ops.aten.hardtanh_backward(grad_out, input, min_val=0.2, max_val=0.5)
|
return torch.ops.aten.hardtanh_backward(
|
||||||
|
grad_out, input, min_val=0.2, max_val=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: HardtanhBackwardModule())
|
@register_test_case(module_factory=lambda: HardtanhBackwardModule())
|
||||||
def HardtanhBackward_basic(module, tu: TestUtils):
|
def HardtanhBackward_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(10, 20), tu.rand(10, 20))
|
module.forward(tu.rand(10, 20), tu.rand(10, 20))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ConvolutionBackwardModule2D(torch.nn.Module):
|
class ConvolutionBackwardModule2D(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_out, input_vec, weight):
|
def forward(self, grad_out, input_vec, weight):
|
||||||
return torch.ops.aten.convolution_backward(
|
return torch.ops.aten.convolution_backward(
|
||||||
grad_out,
|
grad_out,
|
||||||
|
@ -104,27 +111,29 @@ class ConvolutionBackwardModule2D(torch.nn.Module):
|
||||||
transposed=False,
|
transposed=False,
|
||||||
output_padding=[0],
|
output_padding=[0],
|
||||||
groups=1,
|
groups=1,
|
||||||
output_mask=[True, True, True])
|
output_mask=[True, True, True],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2D())
|
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2D())
|
||||||
def ConvolutionBackwardModule2D_basic(module, tu: TestUtils):
|
def ConvolutionBackwardModule2D_basic(module, tu: TestUtils):
|
||||||
with torch.backends.mkldnn.flags(enabled=False):
|
with torch.backends.mkldnn.flags(enabled=False):
|
||||||
module.forward(tu.rand(2, 2, 5, 5), tu.rand(2, 2, 6, 6),
|
module.forward(tu.rand(2, 2, 5, 5), tu.rand(2, 2, 6, 6), tu.rand(2, 2, 2, 2))
|
||||||
tu.rand(2, 2, 2, 2))
|
|
||||||
|
|
||||||
class ConvolutionBackwardModule2DStatic(torch.nn.Module):
|
class ConvolutionBackwardModule2DStatic(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 4, 64, 64], torch.float32, True),
|
([1, 4, 64, 64], torch.float32, True),
|
||||||
([1, 320, 64, 64], torch.float32, True),
|
([1, 320, 64, 64], torch.float32, True),
|
||||||
([4, 320, 3, 3], torch.float32, True),
|
([4, 320, 3, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_out, input_vec, weight):
|
def forward(self, grad_out, input_vec, weight):
|
||||||
return torch.ops.aten.convolution_backward(
|
return torch.ops.aten.convolution_backward(
|
||||||
grad_out,
|
grad_out,
|
||||||
|
@ -137,28 +146,31 @@ class ConvolutionBackwardModule2DStatic(torch.nn.Module):
|
||||||
transposed=False,
|
transposed=False,
|
||||||
output_padding=[0, 0],
|
output_padding=[0, 0],
|
||||||
groups=1,
|
groups=1,
|
||||||
output_mask=[True, True, True])
|
output_mask=[True, True, True],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStatic())
|
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStatic())
|
||||||
def ConvolutionBackwardModule2DStatic_basic(module, tu: TestUtils):
|
def ConvolutionBackwardModule2DStatic_basic(module, tu: TestUtils):
|
||||||
with torch.backends.mkldnn.flags(enabled=False):
|
with torch.backends.mkldnn.flags(enabled=False):
|
||||||
module.forward(tu.rand(1, 4, 64, 64), tu.rand(1, 320, 64, 64),
|
module.forward(
|
||||||
tu.rand(4, 320, 3, 3))
|
tu.rand(1, 4, 64, 64), tu.rand(1, 320, 64, 64), tu.rand(4, 320, 3, 3)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConvolutionBackwardModule2DPadded(torch.nn.Module):
|
class ConvolutionBackwardModule2DPadded(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_out, input_vec, weight):
|
def forward(self, grad_out, input_vec, weight):
|
||||||
return torch.ops.aten.convolution_backward(
|
return torch.ops.aten.convolution_backward(
|
||||||
grad_out,
|
grad_out,
|
||||||
|
@ -171,28 +183,29 @@ class ConvolutionBackwardModule2DPadded(torch.nn.Module):
|
||||||
transposed=False,
|
transposed=False,
|
||||||
output_padding=[0],
|
output_padding=[0],
|
||||||
groups=1,
|
groups=1,
|
||||||
output_mask=[True, True, True])
|
output_mask=[True, True, True],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DPadded())
|
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DPadded())
|
||||||
def ConvolutionBackwardModule2DPadded_basic(module, tu: TestUtils):
|
def ConvolutionBackwardModule2DPadded_basic(module, tu: TestUtils):
|
||||||
with torch.backends.mkldnn.flags(enabled=False):
|
with torch.backends.mkldnn.flags(enabled=False):
|
||||||
module.forward(tu.rand(2, 2, 8, 8), tu.rand(2, 2, 6, 6),
|
module.forward(tu.rand(2, 2, 8, 8), tu.rand(2, 2, 6, 6), tu.rand(2, 2, 3, 3))
|
||||||
tu.rand(2, 2, 3, 3))
|
|
||||||
|
|
||||||
|
|
||||||
class ConvolutionBackwardModule2DStrided(torch.nn.Module):
|
class ConvolutionBackwardModule2DStrided(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 2, 4, 4], torch.float32, True),
|
([1, 2, 4, 4], torch.float32, True),
|
||||||
([1, 2, 8, 8], torch.float32, True),
|
([1, 2, 8, 8], torch.float32, True),
|
||||||
([2, 2, 3, 3], torch.float32, True),
|
([2, 2, 3, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_out, input_vec, weight):
|
def forward(self, grad_out, input_vec, weight):
|
||||||
return torch.ops.aten.convolution_backward(
|
return torch.ops.aten.convolution_backward(
|
||||||
grad_out,
|
grad_out,
|
||||||
|
@ -205,30 +218,31 @@ class ConvolutionBackwardModule2DStrided(torch.nn.Module):
|
||||||
transposed=False,
|
transposed=False,
|
||||||
output_padding=[0, 0],
|
output_padding=[0, 0],
|
||||||
groups=1,
|
groups=1,
|
||||||
output_mask=[True, True, True])
|
output_mask=[True, True, True],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStrided())
|
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStrided())
|
||||||
def ConvolutionBackwardModule2DStrided_basic(module, tu: TestUtils):
|
def ConvolutionBackwardModule2DStrided_basic(module, tu: TestUtils):
|
||||||
with torch.backends.mkldnn.flags(enabled=False):
|
with torch.backends.mkldnn.flags(enabled=False):
|
||||||
module.forward(tu.rand(1, 2, 4, 4), tu.rand(1, 2, 8, 8),
|
module.forward(tu.rand(1, 2, 4, 4), tu.rand(1, 2, 8, 8), tu.rand(2, 2, 3, 3))
|
||||||
tu.rand(2, 2, 3, 3))
|
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class GeluBackwardModule(torch.nn.Module):
|
class GeluBackwardModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.gelu_backward(grad, input)
|
return torch.ops.aten.gelu_backward(grad, input)
|
||||||
|
|
||||||
|
@ -239,21 +253,21 @@ def GeluBackwardModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class LogSoftmaxBackwardModule(torch.nn.Module):
|
class LogSoftmaxBackwardModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, output):
|
def forward(self, grad_output, output):
|
||||||
return torch.ops.aten._log_softmax_backward_data(grad_output,
|
return torch.ops.aten._log_softmax_backward_data(
|
||||||
output,
|
grad_output, output, dim=1, input_dtype=6
|
||||||
dim=1,
|
)
|
||||||
input_dtype=6)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: LogSoftmaxBackwardModule())
|
@register_test_case(module_factory=lambda: LogSoftmaxBackwardModule())
|
||||||
|
@ -265,18 +279,21 @@ def LogSoftmaxBackwardModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class LeakyReluBackwardModule(torch.nn.Module):
|
class LeakyReluBackwardModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.leaky_relu_backward(grad, input, negative_slope=0.1, self_is_result=False)
|
return torch.ops.aten.leaky_relu_backward(
|
||||||
|
grad, input, negative_slope=0.1, self_is_result=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: LeakyReluBackwardModule())
|
@register_test_case(module_factory=lambda: LeakyReluBackwardModule())
|
||||||
|
@ -285,18 +302,21 @@ def LeakyReluBackwardModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class LeakyReluBackwardStaticModule(torch.nn.Module):
|
class LeakyReluBackwardStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([3, 4, 5], torch.float32, True),
|
([3, 4, 5], torch.float32, True),
|
||||||
([3, 4, 5], torch.float32, True),
|
([3, 4, 5], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.leaky_relu_backward(grad, input, negative_slope=0.1, self_is_result=False)
|
return torch.ops.aten.leaky_relu_backward(
|
||||||
|
grad, input, negative_slope=0.1, self_is_result=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule())
|
@register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule())
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -11,15 +11,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TensorToIntZeroRank(torch.nn.Module):
|
class TensorToIntZeroRank(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return int(x)
|
return int(x)
|
||||||
|
|
||||||
|
@ -28,17 +31,21 @@ class TensorToIntZeroRank(torch.nn.Module):
|
||||||
def TensorToIntZeroRank_basic(module, tu: TestUtils):
|
def TensorToIntZeroRank_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(high=10))
|
module.forward(tu.randint(high=10))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TensorToInt(torch.nn.Module):
|
class TensorToInt(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return int(x)
|
return int(x)
|
||||||
|
|
||||||
|
@ -47,17 +54,21 @@ class TensorToInt(torch.nn.Module):
|
||||||
def TensorToInt_basic(module, tu: TestUtils):
|
def TensorToInt_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(1, 1, high=10))
|
module.forward(tu.randint(1, 1, high=10))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TensorToFloatZeroRank(torch.nn.Module):
|
class TensorToFloatZeroRank(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return float(x)
|
return float(x)
|
||||||
|
|
||||||
|
@ -66,17 +77,21 @@ class TensorToFloatZeroRank(torch.nn.Module):
|
||||||
def TensorToFloatZeroRank_basic(module, tu: TestUtils):
|
def TensorToFloatZeroRank_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand().to(torch.float64))
|
module.forward(tu.rand().to(torch.float64))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TensorToFloat(torch.nn.Module):
|
class TensorToFloat(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return float(x)
|
return float(x)
|
||||||
|
|
||||||
|
@ -85,17 +100,21 @@ class TensorToFloat(torch.nn.Module):
|
||||||
def TensorToFloat_basic(module, tu: TestUtils):
|
def TensorToFloat_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 1).to(torch.float64))
|
module.forward(tu.rand(1, 1).to(torch.float64))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TensorToBoolZeroRank(torch.nn.Module):
|
class TensorToBoolZeroRank(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.bool, True),
|
([], torch.bool, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return bool(x)
|
return bool(x)
|
||||||
|
|
||||||
|
@ -104,17 +123,21 @@ class TensorToBoolZeroRank(torch.nn.Module):
|
||||||
def TensorToBoolZeroRank_basic(module, tu: TestUtils):
|
def TensorToBoolZeroRank_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.tensor(1, dtype=torch.bool))
|
module.forward(torch.tensor(1, dtype=torch.bool))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TensorToBool(torch.nn.Module):
|
class TensorToBool(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.bool, True),
|
([-1, -1], torch.bool, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return bool(x)
|
return bool(x)
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -13,15 +13,13 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TorchPrimLoopForLikeModule(torch.nn.Module):
|
class TorchPrimLoopForLikeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args([None, ([-1, -1], torch.int64, True)])
|
||||||
None,
|
|
||||||
([-1, -1], torch.int64, True)
|
|
||||||
])
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x_val = x.size(0)
|
x_val = x.size(0)
|
||||||
sum = 0
|
sum = 0
|
||||||
|
@ -34,20 +32,18 @@ class TorchPrimLoopForLikeModule(torch.nn.Module):
|
||||||
def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils):
|
def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(6, 8, high=10))
|
module.forward(tu.randint(6, 8, high=10))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
class TorchPrimLoopWhileLikeModule(torch.nn.Module):
|
class TorchPrimLoopWhileLikeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args([None, ([-1, -1], torch.int64, True)])
|
||||||
None,
|
|
||||||
([-1, -1], torch.int64, True)
|
|
||||||
])
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x_val = x.size(0)
|
x_val = x.size(0)
|
||||||
sum = 0
|
sum = 0
|
||||||
while(x_val > sum):
|
while x_val > sum:
|
||||||
sum += 1
|
sum += 1
|
||||||
return sum
|
return sum
|
||||||
|
|
||||||
|
@ -59,20 +55,24 @@ def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TorchPrimLoopForLikeTensorArgModule(torch.nn.Module):
|
class TorchPrimLoopForLikeTensorArgModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([7,9], torch.float32, True),
|
([7, 9], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
for i in range(50):
|
for i in range(50):
|
||||||
x = x + i
|
x = x + i
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: TorchPrimLoopForLikeTensorArgModule())
|
@register_test_case(module_factory=lambda: TorchPrimLoopForLikeTensorArgModule())
|
||||||
def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils):
|
def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils):
|
||||||
x_test = torch.zeros([7, 9]).float()
|
x_test = torch.zeros([7, 9]).float()
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -17,15 +17,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
# the PyTorch op registry permanently.
|
# the PyTorch op registry permanently.
|
||||||
import torch_mlir._torch_mlir_custom_op_example
|
import torch_mlir._torch_mlir_custom_op_example
|
||||||
|
|
||||||
|
|
||||||
class CustomOpExampleModule(torch.nn.Module):
|
class CustomOpExampleModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.ops._torch_mlir_custom_op_example.identity(a)
|
return torch.ops._torch_mlir_custom_op_example.identity(a)
|
||||||
|
|
||||||
|
|
|
@ -10,16 +10,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class DiagonalModule(torch.nn.Module):
|
class DiagonalModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.ops.aten.diagonal(a)
|
return torch.ops.aten.diagonal(a)
|
||||||
|
|
||||||
|
@ -28,96 +30,122 @@ class DiagonalModule(torch.nn.Module):
|
||||||
def DiagonalModule_basic(module, tu: TestUtils):
|
def DiagonalModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 3))
|
module.forward(tu.rand(3, 3))
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: DiagonalModule())
|
@register_test_case(module_factory=lambda: DiagonalModule())
|
||||||
def DiagonalModule_nonsquare(module, tu: TestUtils):
|
def DiagonalModule_nonsquare(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4))
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class DiagonalTransposedModule(torch.nn.Module):
|
class DiagonalTransposedModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.ops.aten.diagonal(a, dim1=1, dim2=0)
|
return torch.ops.aten.diagonal(a, dim1=1, dim2=0)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: DiagonalTransposedModule())
|
@register_test_case(module_factory=lambda: DiagonalTransposedModule())
|
||||||
def DiagonalModule_transposed(module, tu: TestUtils):
|
def DiagonalModule_transposed(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4))
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class DiagonalWithDimsModule(torch.nn.Module):
|
class DiagonalWithDimsModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.ops.aten.diagonal(a, dim1=0, dim2=1)
|
return torch.ops.aten.diagonal(a, dim1=0, dim2=1)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: DiagonalWithDimsModule())
|
@register_test_case(module_factory=lambda: DiagonalWithDimsModule())
|
||||||
def DiagonalModule_with_dims(module, tu: TestUtils):
|
def DiagonalModule_with_dims(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4, 5))
|
module.forward(tu.rand(3, 4, 5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class DiagonalWithNegativeDimsModule(torch.nn.Module):
|
class DiagonalWithNegativeDimsModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.ops.aten.diagonal(a, dim1=-2, dim2=-1)
|
return torch.ops.aten.diagonal(a, dim1=-2, dim2=-1)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: DiagonalWithNegativeDimsModule())
|
@register_test_case(module_factory=lambda: DiagonalWithNegativeDimsModule())
|
||||||
def DiagonalModule_with_negative_dims(module, tu: TestUtils):
|
def DiagonalModule_with_negative_dims(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4, 5))
|
module.forward(tu.rand(3, 4, 5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class DiagonalWithOffsetModule(torch.nn.Module):
|
class DiagonalWithOffsetModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.ops.aten.diagonal(a, offset=1)
|
return torch.ops.aten.diagonal(a, offset=1)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: DiagonalWithOffsetModule())
|
@register_test_case(module_factory=lambda: DiagonalWithOffsetModule())
|
||||||
def DiagonalModule_with_offset(module, tu: TestUtils):
|
def DiagonalModule_with_offset(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 6))
|
module.forward(tu.rand(4, 6))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class DiagonalWithDimsOffsetModule(torch.nn.Module):
|
class DiagonalWithDimsOffsetModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.ops.aten.diagonal(a, dim1=0, dim2=1, offset=-1)
|
return torch.ops.aten.diagonal(a, dim1=0, dim2=1, offset=-1)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: DiagonalWithDimsOffsetModule())
|
@register_test_case(module_factory=lambda: DiagonalWithDimsOffsetModule())
|
||||||
def DiagonalModule_with_dims_and_offset(module, tu: TestUtils):
|
def DiagonalModule_with_dims_and_offset(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4, 5))
|
module.forward(tu.rand(3, 4, 5))
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -11,15 +11,18 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGtFloatScalarModule(torch.nn.Module):
|
class ElementwiseGtFloatScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.gt(x, 0.6)
|
return torch.gt(x, 0.6)
|
||||||
|
|
||||||
|
@ -28,17 +31,21 @@ class ElementwiseGtFloatScalarModule(torch.nn.Module):
|
||||||
def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 5))
|
module.forward(tu.rand(3, 5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGtIntScalarModule(torch.nn.Module):
|
class ElementwiseGtIntScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.gt(x, 10)
|
return torch.gt(x, 10)
|
||||||
|
|
||||||
|
@ -47,17 +54,21 @@ class ElementwiseGtIntScalarModule(torch.nn.Module):
|
||||||
def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
|
class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int32, True),
|
([-1, -1], torch.int32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.gt(x, 7)
|
return torch.gt(x, 7)
|
||||||
|
|
||||||
|
@ -66,17 +77,21 @@ class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
|
||||||
def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGeFloatScalarModule(torch.nn.Module):
|
class ElementwiseGeFloatScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ge(x, 0.6)
|
return torch.ge(x, 0.6)
|
||||||
|
|
||||||
|
@ -85,17 +100,21 @@ class ElementwiseGeFloatScalarModule(torch.nn.Module):
|
||||||
def ElementwiseGeFloatScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseGeFloatScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 5))
|
module.forward(tu.rand(3, 5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGeIntScalarModule(torch.nn.Module):
|
class ElementwiseGeIntScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ge(x, 10)
|
return torch.ge(x, 10)
|
||||||
|
|
||||||
|
@ -104,17 +123,21 @@ class ElementwiseGeIntScalarModule(torch.nn.Module):
|
||||||
def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
|
class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int32, True),
|
([-1, -1], torch.int32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ge(x, 7)
|
return torch.ge(x, 7)
|
||||||
|
|
||||||
|
@ -123,17 +146,21 @@ class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
|
||||||
def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
|
class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ge(x, 7)
|
return torch.ge(x, 7)
|
||||||
|
|
||||||
|
@ -142,18 +169,22 @@ class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
|
||||||
def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 5))
|
module.forward(tu.rand(3, 5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGeFloatTensorModule(torch.nn.Module):
|
class ElementwiseGeFloatTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ge(x, y)
|
return torch.ge(x, y)
|
||||||
|
|
||||||
|
@ -162,20 +193,25 @@ class ElementwiseGeFloatTensorModule(torch.nn.Module):
|
||||||
def ElementwiseGeFloatTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseGeFloatTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32))
|
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGeIntTensorModule(torch.nn.Module):
|
class ElementwiseGeIntTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ge(x, y)
|
return torch.ge(x, y)
|
||||||
|
|
||||||
|
@ -184,18 +220,22 @@ class ElementwiseGeIntTensorModule(torch.nn.Module):
|
||||||
def ElementwiseGeIntTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseGeIntTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGtFloatTensorModule(torch.nn.Module):
|
class ElementwiseGtFloatTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.gt(x, y)
|
return torch.gt(x, y)
|
||||||
|
|
||||||
|
@ -204,20 +244,25 @@ class ElementwiseGtFloatTensorModule(torch.nn.Module):
|
||||||
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32))
|
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGtIntTensorModule(torch.nn.Module):
|
class ElementwiseGtIntTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.gt(x, y)
|
return torch.gt(x, y)
|
||||||
|
|
||||||
|
@ -226,17 +271,21 @@ class ElementwiseGtIntTensorModule(torch.nn.Module):
|
||||||
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLtFloatScalarModule(torch.nn.Module):
|
class ElementwiseLtFloatScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.lt(x, 0.6)
|
return torch.lt(x, 0.6)
|
||||||
|
|
||||||
|
@ -245,17 +294,21 @@ class ElementwiseLtFloatScalarModule(torch.nn.Module):
|
||||||
def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 5))
|
module.forward(tu.rand(3, 5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLtIntScalarModule(torch.nn.Module):
|
class ElementwiseLtIntScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.lt(x, 0)
|
return torch.lt(x, 0)
|
||||||
|
|
||||||
|
@ -264,37 +317,44 @@ class ElementwiseLtIntScalarModule(torch.nn.Module):
|
||||||
def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
|
class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int32, True),
|
([-1, -1], torch.int32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.lt(x, 2)
|
return torch.lt(x, 2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
|
||||||
module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
|
|
||||||
def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLeFloatScalarModule(torch.nn.Module):
|
class ElementwiseLeFloatScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.le(x, 0.6)
|
return torch.le(x, 0.6)
|
||||||
|
|
||||||
|
@ -303,17 +363,21 @@ class ElementwiseLeFloatScalarModule(torch.nn.Module):
|
||||||
def ElementwiseLeFloatScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseLeFloatScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 5))
|
module.forward(tu.rand(3, 5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLeIntScalarModule(torch.nn.Module):
|
class ElementwiseLeIntScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.le(x, 10)
|
return torch.le(x, 10)
|
||||||
|
|
||||||
|
@ -322,17 +386,21 @@ class ElementwiseLeIntScalarModule(torch.nn.Module):
|
||||||
def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
|
class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int32, True),
|
([-1, -1], torch.int32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.le(x, 7)
|
return torch.le(x, 7)
|
||||||
|
|
||||||
|
@ -341,17 +409,21 @@ class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
|
||||||
def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
|
class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.le(x, 7)
|
return torch.le(x, 7)
|
||||||
|
|
||||||
|
@ -360,18 +432,22 @@ class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
|
||||||
def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 5))
|
module.forward(tu.rand(3, 5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLeFloatTensorModule(torch.nn.Module):
|
class ElementwiseLeFloatTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.le(x, y)
|
return torch.le(x, y)
|
||||||
|
|
||||||
|
@ -380,18 +456,22 @@ class ElementwiseLeFloatTensorModule(torch.nn.Module):
|
||||||
def ElementwiseLeFloatTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseLeFloatTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 5), tu.rand(5))
|
module.forward(tu.rand(3, 5), tu.rand(5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLeFloatTensorNanModule(torch.nn.Module):
|
class ElementwiseLeFloatTensorNanModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.le(x, y)
|
return torch.le(x, y)
|
||||||
|
|
||||||
|
@ -400,20 +480,25 @@ class ElementwiseLeFloatTensorNanModule(torch.nn.Module):
|
||||||
def ElementwiseLeFloatTensorNanModule_basic(module, tu: TestUtils):
|
def ElementwiseLeFloatTensorNanModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32))
|
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLeIntTensorModule(torch.nn.Module):
|
class ElementwiseLeIntTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.le(x, y)
|
return torch.le(x, y)
|
||||||
|
|
||||||
|
@ -422,18 +507,22 @@ class ElementwiseLeIntTensorModule(torch.nn.Module):
|
||||||
def ElementwiseLeIntTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseLeIntTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLtFloatTensorModule(torch.nn.Module):
|
class ElementwiseLtFloatTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.lt(x, y)
|
return torch.lt(x, y)
|
||||||
|
|
||||||
|
@ -442,20 +531,25 @@ class ElementwiseLtFloatTensorModule(torch.nn.Module):
|
||||||
def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||||
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32))
|
torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLtIntTensorModule(torch.nn.Module):
|
class ElementwiseLtIntTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.lt(x, y)
|
return torch.lt(x, y)
|
||||||
|
|
||||||
|
@ -464,17 +558,21 @@ class ElementwiseLtIntTensorModule(torch.nn.Module):
|
||||||
def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseEqFloatScalarModule(torch.nn.Module):
|
class ElementwiseEqFloatScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.eq(x, 6.0)
|
return torch.eq(x, 6.0)
|
||||||
|
|
||||||
|
@ -482,19 +580,24 @@ class ElementwiseEqFloatScalarModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
|
@register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
|
||||||
def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32))
|
torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseEqIntScalarModule(torch.nn.Module):
|
class ElementwiseEqIntScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.eq(x, 2)
|
return torch.eq(x, 2)
|
||||||
|
|
||||||
|
@ -503,17 +606,21 @@ class ElementwiseEqIntScalarModule(torch.nn.Module):
|
||||||
def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(5, 8, low=2, high=4))
|
module.forward(tu.randint(5, 8, low=2, high=4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseEqBoolScalarModule(torch.nn.Module):
|
class ElementwiseEqBoolScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.bool, True),
|
([-1, -1], torch.bool, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.eq(x, 1)
|
return torch.eq(x, 1)
|
||||||
|
|
||||||
|
@ -525,36 +632,42 @@ def ElementwiseEqBoolScalarModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
|
class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int32, True),
|
([-1, -1], torch.int32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.eq(x, 2)
|
return torch.eq(x, 2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
|
||||||
module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
|
|
||||||
def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(5, 8, low=2, high=4).to(torch.int32))
|
module.forward(tu.randint(5, 8, low=2, high=4).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseEqFloatTensorModule(torch.nn.Module):
|
class ElementwiseEqFloatTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.eq(x, y)
|
return torch.eq(x, y)
|
||||||
|
|
||||||
|
@ -563,20 +676,25 @@ class ElementwiseEqFloatTensorModule(torch.nn.Module):
|
||||||
def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
torch.tensor([[1.0, 2.2, 6.0], [torch.nan, 2.0, 3.1]]).to(torch.float32),
|
torch.tensor([[1.0, 2.2, 6.0], [torch.nan, 2.0, 3.1]]).to(torch.float32),
|
||||||
torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
|
torch.tensor([1.0, 2.4, 6.0]).to(torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseEqIntTensorModule(torch.nn.Module):
|
class ElementwiseEqIntTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.eq(x, y)
|
return torch.eq(x, y)
|
||||||
|
|
||||||
|
@ -585,17 +703,21 @@ class ElementwiseEqIntTensorModule(torch.nn.Module):
|
||||||
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
|
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseNeFloatScalarModule(torch.nn.Module):
|
class ElementwiseNeFloatScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ne(x, 2.0)
|
return torch.ne(x, 2.0)
|
||||||
|
|
||||||
|
@ -603,19 +725,24 @@ class ElementwiseNeFloatScalarModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule())
|
@register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule())
|
||||||
def ElementwiseNeFloatScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseNeFloatScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
torch.tensor([[1.0, 2.2, 2.0], [torch.nan, 2.0, 3.1]]).to(torch.float32))
|
torch.tensor([[1.0, 2.2, 2.0], [torch.nan, 2.0, 3.1]]).to(torch.float32)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseNeIntScalarModule(torch.nn.Module):
|
class ElementwiseNeIntScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ne(x, 3)
|
return torch.ne(x, 3)
|
||||||
|
|
||||||
|
@ -624,18 +751,22 @@ class ElementwiseNeIntScalarModule(torch.nn.Module):
|
||||||
def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(8, 5, low=2, high=4))
|
module.forward(tu.randint(8, 5, low=2, high=4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseNeFloatTensorModule(torch.nn.Module):
|
class ElementwiseNeFloatTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ne(x, y)
|
return torch.ne(x, y)
|
||||||
|
|
||||||
|
@ -644,20 +775,25 @@ class ElementwiseNeFloatTensorModule(torch.nn.Module):
|
||||||
def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
|
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||||
torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32))
|
torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseNeIntTensorModule(torch.nn.Module):
|
class ElementwiseNeIntTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ne(x, y)
|
return torch.ne(x, y)
|
||||||
|
|
||||||
|
@ -666,18 +802,22 @@ class ElementwiseNeIntTensorModule(torch.nn.Module):
|
||||||
def ElementwiseNeIntTensorModule_basic(module, tu: TestUtils):
|
def ElementwiseNeIntTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
|
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseNeFloatTensorStaticModule(torch.nn.Module):
|
class ElementwiseNeFloatTensorStaticModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 3], torch.float32, True),
|
([2, 3], torch.float32, True),
|
||||||
([2, 3], torch.float32, True),
|
([2, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ne(x, y)
|
return torch.ne(x, y)
|
||||||
|
|
||||||
|
@ -686,20 +826,25 @@ class ElementwiseNeFloatTensorStaticModule(torch.nn.Module):
|
||||||
def ElementwiseNeFloatTensorStaticModule_basic(module, tu: TestUtils):
|
def ElementwiseNeFloatTensorStaticModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
|
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||||
torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32))
|
torch.tensor([[1.0, 2.4, 6.0], [torch.nan, 2.0, 6.0]]).to(torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseNeIntTensorStaticModule(torch.nn.Module):
|
class ElementwiseNeIntTensorStaticModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([8, 5], torch.int64, True),
|
([8, 5], torch.int64, True),
|
||||||
([5], torch.int64, True),
|
([5], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ne(x, y)
|
return torch.ne(x, y)
|
||||||
|
|
||||||
|
@ -708,16 +853,20 @@ class ElementwiseNeIntTensorStaticModule(torch.nn.Module):
|
||||||
def ElementwiseNeIntTensorStaticModule_basic(module, tu: TestUtils):
|
def ElementwiseNeIntTensorStaticModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
|
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AnyBoolTrueModule(torch.nn.Module):
|
class AnyBoolTrueModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
input = [False, False, True]
|
input = [False, False, True]
|
||||||
return torch.ops.aten.any(input)
|
return torch.ops.aten.any(input)
|
||||||
|
@ -733,9 +882,11 @@ class AnyBoolFalseModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
input = [False, False, False]
|
input = [False, False, False]
|
||||||
return torch.ops.aten.any(input)
|
return torch.ops.aten.any(input)
|
||||||
|
@ -745,17 +896,20 @@ class AnyBoolFalseModule(torch.nn.Module):
|
||||||
def AnyBoolFalseModule_basic(module, tu: TestUtils):
|
def AnyBoolFalseModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# =================================================================================
|
# =================================================================================
|
||||||
|
|
||||||
class AllBoolTrueModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AllBoolTrueModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
input = [True, True, True, True, True]
|
input = [True, True, True, True, True]
|
||||||
return torch.ops.aten.all(input)
|
return torch.ops.aten.all(input)
|
||||||
|
@ -765,36 +919,44 @@ class AllBoolTrueModule(torch.nn.Module):
|
||||||
def AllBoolTrueModule_basic(module, tu: TestUtils):
|
def AllBoolTrueModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# =================================================================================
|
# =================================================================================
|
||||||
|
|
||||||
class AllBoolFalseModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AllBoolFalseModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
input = [True, False, True, True, False]
|
input = [True, False, True, True, False]
|
||||||
return torch.ops.aten.all(input)
|
return torch.ops.aten.all(input)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AllBoolFalseModule())
|
@register_test_case(module_factory=lambda: AllBoolFalseModule())
|
||||||
def AllBoolFalseModule_basic(module, tu: TestUtils):
|
def AllBoolFalseModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseIsnanModule(torch.nn.Module):
|
class ElementwiseIsnanModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.isnan(x)
|
return torch.ops.aten.isnan(x)
|
||||||
|
|
||||||
|
@ -804,17 +966,21 @@ def ElementwiseIsnanModule_basic(module, tu: TestUtils):
|
||||||
x = torch.tensor([1.0, torch.nan, torch.inf, -torch.inf])
|
x = torch.tensor([1.0, torch.nan, torch.inf, -torch.inf])
|
||||||
module.forward(x)
|
module.forward(x)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseIsinfModule(torch.nn.Module):
|
class ElementwiseIsinfModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.isinf(x)
|
return torch.ops.aten.isinf(x)
|
||||||
|
|
||||||
|
|
|
@ -11,92 +11,107 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class GridSamplerBasic1(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class GridSamplerBasic1(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([7, 8, 12, 4], torch.float32, True),
|
([7, 8, 12, 4], torch.float32, True),
|
||||||
([7, 11, 13, 2], torch.float32, True)
|
([7, 11, 13, 2], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, g):
|
def forward(self, x, g):
|
||||||
interpolation_mode=0,
|
interpolation_mode = (0,)
|
||||||
padding_mode=0,
|
padding_mode = (0,)
|
||||||
align_corners=True,
|
align_corners = (True,)
|
||||||
tRes = torch.ops.aten.grid_sampler(x, g, interpolation_mode[0],
|
tRes = torch.ops.aten.grid_sampler(
|
||||||
padding_mode[0], align_corners[0])
|
x, g, interpolation_mode[0], padding_mode[0], align_corners[0]
|
||||||
|
)
|
||||||
return tRes
|
return tRes
|
||||||
|
|
||||||
@register_test_case(
|
|
||||||
module_factory=lambda: GridSamplerBasic1())
|
@register_test_case(module_factory=lambda: GridSamplerBasic1())
|
||||||
def GridSamplerBasic1_basic(
|
def GridSamplerBasic1_basic(module, tu: TestUtils):
|
||||||
module, tu: TestUtils):
|
inp = torch.rand(7, 8, 12, 4)
|
||||||
inp = torch.rand(7,8,12,4)
|
grd = torch.rand(7, 11, 13, 2) * 2.0 - 1.0
|
||||||
grd = torch.rand(7,11,13,2)*2.0-1.0
|
|
||||||
module.forward(inp, grd)
|
module.forward(inp, grd)
|
||||||
|
|
||||||
|
|
||||||
class GridSamplerBasic2(torch.nn.Module):
|
class GridSamplerBasic2(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
None,
|
[None, ([1, 1, 4, 4], torch.float32, True), ([1, 1, 3, 2], torch.float32, True)]
|
||||||
([1, 1, 4, 4], torch.float32, True),
|
)
|
||||||
([1, 1, 3, 2], torch.float32, True)
|
|
||||||
])
|
|
||||||
def forward(self, x, g):
|
def forward(self, x, g):
|
||||||
interpolation_mode=0,
|
interpolation_mode = (0,)
|
||||||
padding_mode=0,
|
padding_mode = (0,)
|
||||||
align_corners=True,
|
align_corners = (True,)
|
||||||
tRes = torch.ops.aten.grid_sampler(x, g, interpolation_mode[0],
|
tRes = torch.ops.aten.grid_sampler(
|
||||||
padding_mode[0], align_corners[0])
|
x, g, interpolation_mode[0], padding_mode[0], align_corners[0]
|
||||||
|
)
|
||||||
return tRes
|
return tRes
|
||||||
|
|
||||||
@register_test_case(
|
|
||||||
module_factory=lambda: GridSamplerBasic2())
|
@register_test_case(module_factory=lambda: GridSamplerBasic2())
|
||||||
def GridSamplerBasic2_basic(
|
def GridSamplerBasic2_basic(module, tu: TestUtils):
|
||||||
module, tu: TestUtils):
|
inp = torch.tensor(
|
||||||
inp = torch.tensor([[[[0.4963, 0.7682, 0.0885, 0.1320],
|
[
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[0.4963, 0.7682, 0.0885, 0.1320],
|
||||||
[0.3074, 0.6341, 0.4901, 0.8964],
|
[0.3074, 0.6341, 0.4901, 0.8964],
|
||||||
[0.4556, 0.6323, 0.3489, 0.4017],
|
[0.4556, 0.6323, 0.3489, 0.4017],
|
||||||
[0.0223, 0.1689, 0.2939, 0.5185]]]]).type(torch.FloatTensor)
|
[0.0223, 0.1689, 0.2939, 0.5185],
|
||||||
grd = torch.tensor([[[[-0.3498, -0.8196],[-0.2127, 0.2138],[-0.6515, -0.0513]]]]).type(torch.FloatTensor)
|
]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
).type(torch.FloatTensor)
|
||||||
|
grd = torch.tensor(
|
||||||
|
[[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]]
|
||||||
|
).type(torch.FloatTensor)
|
||||||
module.forward(inp, grd)
|
module.forward(inp, grd)
|
||||||
|
|
||||||
|
|
||||||
class GridSamplerBasic3(torch.nn.Module):
|
class GridSamplerBasic3(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
None,
|
[None, ([1, 1, 4, 4], torch.float32, True), ([1, 1, 3, 2], torch.float32, True)]
|
||||||
([1, 1, 4, 4], torch.float32, True),
|
)
|
||||||
([1, 1, 3, 2], torch.float32, True)
|
|
||||||
])
|
|
||||||
def forward(self, x, g):
|
def forward(self, x, g):
|
||||||
interpolation_mode=0,
|
interpolation_mode = (0,)
|
||||||
padding_mode=0,
|
padding_mode = (0,)
|
||||||
align_corners=False,
|
align_corners = (False,)
|
||||||
tRes = torch.ops.aten.grid_sampler(x, g, interpolation_mode[0],
|
tRes = torch.ops.aten.grid_sampler(
|
||||||
padding_mode[0], align_corners[0])
|
x, g, interpolation_mode[0], padding_mode[0], align_corners[0]
|
||||||
|
)
|
||||||
return tRes
|
return tRes
|
||||||
|
|
||||||
@register_test_case(
|
|
||||||
module_factory=lambda: GridSamplerBasic3())
|
@register_test_case(module_factory=lambda: GridSamplerBasic3())
|
||||||
def GridSamplerBasic3_basic(
|
def GridSamplerBasic3_basic(module, tu: TestUtils):
|
||||||
module, tu: TestUtils):
|
inp = torch.tensor(
|
||||||
inp = torch.tensor([[[[0.4963, 0.7682, 0.0885, 0.1320],
|
[
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[0.4963, 0.7682, 0.0885, 0.1320],
|
||||||
[0.3074, 0.6341, 0.4901, 0.8964],
|
[0.3074, 0.6341, 0.4901, 0.8964],
|
||||||
[0.4556, 0.6323, 0.3489, 0.4017],
|
[0.4556, 0.6323, 0.3489, 0.4017],
|
||||||
[0.0223, 0.1689, 0.2939, 0.5185]]]]).type(torch.FloatTensor)
|
[0.0223, 0.1689, 0.2939, 0.5185],
|
||||||
grd = torch.tensor([[[[-0.3498, -0.8196],[-0.2127, 0.2138],[-0.6515, -0.0513]]]]).type(torch.FloatTensor)
|
]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
).type(torch.FloatTensor)
|
||||||
|
grd = torch.tensor(
|
||||||
|
[[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]]
|
||||||
|
).type(torch.FloatTensor)
|
||||||
module.forward(inp, grd)
|
module.forward(inp, grd)
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ NUM_SEGMENTS = 42
|
||||||
NUM_BINS = 5000
|
NUM_BINS = 5000
|
||||||
NUM_LOGITS = 5000
|
NUM_LOGITS = 5000
|
||||||
|
|
||||||
|
|
||||||
class HistogramBinningCalibrationByFeature(torch.nn.Module):
|
class HistogramBinningCalibrationByFeature(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -45,31 +46,34 @@ class HistogramBinningCalibrationByFeature(torch.nn.Module):
|
||||||
self._iteration = 0
|
self._iteration = 0
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.int32, True),
|
([-1], torch.int32, True),
|
||||||
([-1], torch.int32, True),
|
([-1], torch.int32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, segment_value, segment_lengths, logit):
|
def forward(self, segment_value, segment_lengths, logit):
|
||||||
origin_prediction = torch.sigmoid(
|
origin_prediction = torch.sigmoid(logit + torch.log(self.positive_weight))
|
||||||
logit + torch.log(self.positive_weight))
|
|
||||||
# TODO: If in the future this test is removed from xfail for LTC, we will probably hit some device related
|
# TODO: If in the future this test is removed from xfail for LTC, we will probably hit some device related
|
||||||
# issues below when new tensors are created on the CPU, which is currently the default behaviour.
|
# issues below when new tensors are created on the CPU, which is currently the default behaviour.
|
||||||
# The solution would be to move these tensors to ensure they are on the same device as the existing ones.
|
# The solution would be to move these tensors to ensure they are on the same device as the existing ones.
|
||||||
dense_segment_value = torch.zeros(logit.numel(), dtype=torch.int32)
|
dense_segment_value = torch.zeros(logit.numel(), dtype=torch.int32)
|
||||||
validoffsets = torch.gt(
|
validoffsets = torch.gt(
|
||||||
segment_lengths[1:self._num_logits+1], segment_lengths[0:self._num_logits])
|
segment_lengths[1 : self._num_logits + 1],
|
||||||
|
segment_lengths[0 : self._num_logits],
|
||||||
|
)
|
||||||
gathered_segment_values = (
|
gathered_segment_values = (
|
||||||
segment_value[segment_lengths[0:self._num_logits].long()]+1).int()
|
segment_value[segment_lengths[0 : self._num_logits].long()] + 1
|
||||||
|
).int()
|
||||||
dense_segment_value = torch.where(
|
dense_segment_value = torch.where(
|
||||||
validoffsets, gathered_segment_values, dense_segment_value)
|
validoffsets, gathered_segment_values, dense_segment_value
|
||||||
zeros = torch.empty_like(
|
)
|
||||||
dense_segment_value, dtype=torch.int32).fill_(0)
|
zeros = torch.empty_like(dense_segment_value, dtype=torch.int32).fill_(0)
|
||||||
isnotvalid = torch.gt(dense_segment_value, self._num_segments)
|
isnotvalid = torch.gt(dense_segment_value, self._num_segments)
|
||||||
dense_segment_value = torch.where(
|
dense_segment_value = torch.where(isnotvalid, zeros, dense_segment_value)
|
||||||
isnotvalid, zeros, dense_segment_value)
|
bin_ids_data = torch.ceil(origin_prediction / self.step) - 1
|
||||||
bin_ids_data = torch.ceil(origin_prediction/self.step)-1
|
|
||||||
bin_ids_data = bin_ids_data.long()
|
bin_ids_data = bin_ids_data.long()
|
||||||
curr_segment_value = dense_segment_value * self._num_bins
|
curr_segment_value = dense_segment_value * self._num_bins
|
||||||
bin_ids_data2 = bin_ids_data
|
bin_ids_data2 = bin_ids_data
|
||||||
|
@ -78,12 +82,14 @@ class HistogramBinningCalibrationByFeature(torch.nn.Module):
|
||||||
curr_bin_num_examples = self._bin_num_examples[bin_ids_data]
|
curr_bin_num_examples = self._bin_num_examples[bin_ids_data]
|
||||||
curr_segment_value = curr_segment_value / curr_bin_num_examples
|
curr_segment_value = curr_segment_value / curr_bin_num_examples
|
||||||
curr_segment_value = curr_segment_value.float()
|
curr_segment_value = curr_segment_value.float()
|
||||||
curr_segment_value = curr_segment_value * self.bin_ctr_weight_value + \
|
curr_segment_value = (
|
||||||
origin_prediction * self.oneminusbin_ctr_weight_value
|
curr_segment_value * self.bin_ctr_weight_value
|
||||||
isvalid = torch.gt(curr_bin_num_examples,
|
+ origin_prediction * self.oneminusbin_ctr_weight_value
|
||||||
self.bin_ctr_in_use_after)
|
)
|
||||||
|
isvalid = torch.gt(curr_bin_num_examples, self.bin_ctr_in_use_after)
|
||||||
calibrated_prediction_data = torch.where(
|
calibrated_prediction_data = torch.where(
|
||||||
isvalid, curr_segment_value, origin_prediction.float())
|
isvalid, curr_segment_value, origin_prediction.float()
|
||||||
|
)
|
||||||
return calibrated_prediction_data, bin_ids_data
|
return calibrated_prediction_data, bin_ids_data
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,11 +98,11 @@ def HBC_basic(module, tu: TestUtils):
|
||||||
logits = tu.rand(NUM_LOGITS)
|
logits = tu.rand(NUM_LOGITS)
|
||||||
segment_lengths: Tensor = tu.randint(NUM_LOGITS, high=2).to(torch.int)
|
segment_lengths: Tensor = tu.randint(NUM_LOGITS, high=2).to(torch.int)
|
||||||
segment_offsets: Tensor = torch.cumsum(segment_lengths, 0)
|
segment_offsets: Tensor = torch.cumsum(segment_lengths, 0)
|
||||||
segment_offsets: Tensor = torch.cat(
|
segment_offsets: Tensor = torch.cat((torch.tensor([0]), segment_offsets), 0)
|
||||||
(torch.tensor([0]), segment_offsets), 0)
|
|
||||||
num_values: int = int(torch.sum(segment_lengths).item())
|
num_values: int = int(torch.sum(segment_lengths).item())
|
||||||
segment_values: Tensor = tu.randint(num_values, high=NUM_SEGMENTS)
|
segment_values: Tensor = tu.randint(num_values, high=NUM_SEGMENTS)
|
||||||
segment_values = torch.cat(
|
segment_values = torch.cat(
|
||||||
(segment_values, torch.zeros(NUM_LOGITS-segment_values.numel())), 0)
|
(segment_values, torch.zeros(NUM_LOGITS - segment_values.numel())), 0
|
||||||
|
)
|
||||||
module.forward(segment_values.int(), segment_offsets.int(), logits)
|
module.forward(segment_values.int(), segment_offsets.int(), logits)
|
||||||
#input shape (5000, 5001, 5000)
|
# input shape (5000, 5001, 5000)
|
||||||
|
|
|
@ -17,15 +17,17 @@ class IndexSelectSingleIdxModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, 5, 6], torch.float32, True),
|
([4, 5, 6], torch.float32, True),
|
||||||
([1], torch.int64, True),
|
([1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input, indices):
|
def forward(self, input, indices):
|
||||||
return torch.index_select(input, 1, indices)
|
return torch.index_select(input, 1, indices)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: IndexSelectSingleIdxModule())
|
@register_test_case(module_factory=lambda: IndexSelectSingleIdxModule())
|
||||||
def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
|
def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6), torch.tensor([2]))
|
module.forward(tu.rand(4, 5, 6), torch.tensor([2]))
|
||||||
|
@ -36,33 +38,38 @@ class IndexSelectRank0IdxModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, 5, 6], torch.float32, True),
|
([4, 5, 6], torch.float32, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input, indices):
|
def forward(self, input, indices):
|
||||||
return torch.index_select(input, 1, indices)
|
return torch.index_select(input, 1, indices)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: IndexSelectRank0IdxModule())
|
@register_test_case(module_factory=lambda: IndexSelectRank0IdxModule())
|
||||||
def IndexSelectRank0IdxModule_basic(module, tu: TestUtils):
|
def IndexSelectRank0IdxModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6), torch.tensor(2))
|
module.forward(tu.rand(4, 5, 6), torch.tensor(2))
|
||||||
|
|
||||||
|
|
||||||
class IndexSelectNegativeDimModule(torch.nn.Module):
|
class IndexSelectNegativeDimModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, 5, 6], torch.float32, True),
|
([4, 5, 6], torch.float32, True),
|
||||||
([1], torch.int64, True),
|
([1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input, indices):
|
def forward(self, input, indices):
|
||||||
return torch.index_select(input, -1, indices)
|
return torch.index_select(input, -1, indices)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: IndexSelectNegativeDimModule())
|
@register_test_case(module_factory=lambda: IndexSelectNegativeDimModule())
|
||||||
def IndexSelectNegativeDimModule_basic(module, tu: TestUtils):
|
def IndexSelectNegativeDimModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6), torch.tensor([2]))
|
module.forward(tu.rand(4, 5, 6), torch.tensor([2]))
|
||||||
|
@ -73,15 +80,17 @@ class IndexSelectTwoIdxModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, 5, 6], torch.float32, True),
|
([4, 5, 6], torch.float32, True),
|
||||||
([2], torch.int64, True),
|
([2], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input, indices):
|
def forward(self, input, indices):
|
||||||
return torch.index_select(input, 2, indices)
|
return torch.index_select(input, 2, indices)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: IndexSelectTwoIdxModule())
|
@register_test_case(module_factory=lambda: IndexSelectTwoIdxModule())
|
||||||
def IndexSelectTwoIdxModule_basic(module, tu: TestUtils):
|
def IndexSelectTwoIdxModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6), torch.tensor([2, 4]))
|
module.forward(tu.rand(4, 5, 6), torch.tensor([2, 4]))
|
||||||
|
@ -92,15 +101,17 @@ class IndexSelectWholeDimensionModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, 5, 6], torch.float32, True),
|
([4, 5, 6], torch.float32, True),
|
||||||
([4], torch.int64, True),
|
([4], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input, indices):
|
def forward(self, input, indices):
|
||||||
return torch.index_select(input, 0, indices)
|
return torch.index_select(input, 0, indices)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule())
|
@register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule())
|
||||||
def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils):
|
def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6), torch.tensor([0, 1, 2, 3]))
|
module.forward(tu.rand(4, 5, 6), torch.tensor([0, 1, 2, 3]))
|
||||||
|
@ -111,15 +122,17 @@ class IndexSelectWholeTensorModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([3], torch.float32, True),
|
([3], torch.float32, True),
|
||||||
([3], torch.int64, True),
|
([3], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input, indices):
|
def forward(self, input, indices):
|
||||||
return torch.index_select(input, 0, indices)
|
return torch.index_select(input, 0, indices)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: IndexSelectWholeTensorModule())
|
@register_test_case(module_factory=lambda: IndexSelectWholeTensorModule())
|
||||||
def IndexSelectWholeTensorModule_basic(module, tu: TestUtils):
|
def IndexSelectWholeTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3), torch.tensor([0, 1, 2]))
|
module.forward(tu.rand(3), torch.tensor([0, 1, 2]))
|
||||||
|
@ -130,15 +143,17 @@ class IndexSelectDynamicModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input, indices):
|
def forward(self, input, indices):
|
||||||
return torch.index_select(input, 2, indices)
|
return torch.index_select(input, 2, indices)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: IndexSelectDynamicModule())
|
@register_test_case(module_factory=lambda: IndexSelectDynamicModule())
|
||||||
def IndexSelectDynamicModulebasic(module, tu: TestUtils):
|
def IndexSelectDynamicModulebasic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6), torch.tensor([0, 4]))
|
module.forward(tu.rand(4, 5, 6), torch.tensor([0, 4]))
|
||||||
|
@ -149,15 +164,17 @@ class IndexSelectDynamicInputSizeModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
([2], torch.int64, True),
|
([2], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input, indices):
|
def forward(self, input, indices):
|
||||||
return torch.index_select(input, 2, indices)
|
return torch.index_select(input, 2, indices)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule())
|
@register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule())
|
||||||
def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils):
|
def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6), torch.tensor([0, 2]))
|
module.forward(tu.rand(4, 5, 6), torch.tensor([0, 2]))
|
||||||
|
@ -168,15 +185,17 @@ class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, 5, 6], torch.float32, True),
|
([4, 5, 6], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input, indices):
|
def forward(self, input, indices):
|
||||||
return torch.index_select(input, 1, indices)
|
return torch.index_select(input, 1, indices)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule())
|
@register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule())
|
||||||
def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils):
|
def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6), torch.tensor([1, 2]))
|
module.forward(tu.rand(4, 5, 6), torch.tensor([1, 2]))
|
||||||
|
|
|
@ -11,16 +11,19 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MatmulDot(torch.nn.Module):
|
class MatmulDot(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.matmul(lhs, rhs)
|
return torch.matmul(lhs, rhs)
|
||||||
|
|
||||||
|
@ -29,18 +32,22 @@ class MatmulDot(torch.nn.Module):
|
||||||
def Matmul_dot(module, tu: TestUtils):
|
def Matmul_dot(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3), tu.rand(3))
|
module.forward(tu.rand(3), tu.rand(3))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class Matmul2D(torch.nn.Module):
|
class Matmul2D(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.matmul(lhs, rhs)
|
return torch.matmul(lhs, rhs)
|
||||||
|
|
||||||
|
@ -49,18 +56,22 @@ class Matmul2D(torch.nn.Module):
|
||||||
def Matmul_2d(module, tu: TestUtils):
|
def Matmul_2d(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4), tu.rand(4, 5))
|
module.forward(tu.rand(3, 4), tu.rand(4, 5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MatmulVecMat(torch.nn.Module):
|
class MatmulVecMat(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.matmul(lhs, rhs)
|
return torch.matmul(lhs, rhs)
|
||||||
|
|
||||||
|
@ -69,18 +80,22 @@ class MatmulVecMat(torch.nn.Module):
|
||||||
def Matmul_vecmat(module, tu: TestUtils):
|
def Matmul_vecmat(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4), tu.rand(4, 5))
|
module.forward(tu.rand(4), tu.rand(4, 5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MatmulMatVec(torch.nn.Module):
|
class MatmulMatVec(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.matmul(lhs, rhs)
|
return torch.matmul(lhs, rhs)
|
||||||
|
|
||||||
|
@ -89,18 +104,22 @@ class MatmulMatVec(torch.nn.Module):
|
||||||
def Matmul_matvec(module, tu: TestUtils):
|
def Matmul_matvec(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5), tu.rand(5))
|
module.forward(tu.rand(4, 5), tu.rand(5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class Matmul3D(torch.nn.Module):
|
class Matmul3D(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.matmul(lhs, rhs)
|
return torch.matmul(lhs, rhs)
|
||||||
|
|
||||||
|
@ -109,18 +128,22 @@ class Matmul3D(torch.nn.Module):
|
||||||
def Matmul_3d(module, tu: TestUtils):
|
def Matmul_3d(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
|
module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class Matmul4d(torch.nn.Module):
|
class Matmul4d(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.matmul(lhs, rhs)
|
return torch.matmul(lhs, rhs)
|
||||||
|
|
||||||
|
@ -129,18 +152,22 @@ class Matmul4d(torch.nn.Module):
|
||||||
def Matmul_4d(module, tu: TestUtils):
|
def Matmul_4d(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class Matmul4dStatic(torch.nn.Module):
|
class Matmul4dStatic(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, 5, 6, 7], torch.float32, True),
|
([4, 5, 6, 7], torch.float32, True),
|
||||||
([4, 5, 7, 6], torch.float32, True),
|
([4, 5, 7, 6], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.matmul(lhs, rhs)
|
return torch.matmul(lhs, rhs)
|
||||||
|
|
||||||
|
@ -149,18 +176,22 @@ class Matmul4dStatic(torch.nn.Module):
|
||||||
def Matmul4dStatic_basic(module, tu: TestUtils):
|
def Matmul4dStatic_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MatmulStaticBroadcast(torch.nn.Module):
|
class MatmulStaticBroadcast(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, 1, 6, 7], torch.float32, True),
|
([4, 1, 6, 7], torch.float32, True),
|
||||||
([8, 1, 5, 7, 6], torch.float32, True),
|
([8, 1, 5, 7, 6], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.matmul(lhs, rhs)
|
return torch.matmul(lhs, rhs)
|
||||||
|
|
||||||
|
@ -169,18 +200,22 @@ class MatmulStaticBroadcast(torch.nn.Module):
|
||||||
def MatmulStaticBroadcast_basic(module, tu: TestUtils):
|
def MatmulStaticBroadcast_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
|
module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MatmulSingleDynamicBatchDim(torch.nn.Module):
|
class MatmulSingleDynamicBatchDim(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, -1, -1, -1], torch.float32, True),
|
([4, -1, -1, -1], torch.float32, True),
|
||||||
([4, -1, -1, -1], torch.float32, True),
|
([4, -1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.matmul(lhs, rhs)
|
return torch.matmul(lhs, rhs)
|
||||||
|
|
||||||
|
@ -189,18 +224,22 @@ class MatmulSingleDynamicBatchDim(torch.nn.Module):
|
||||||
def MatmulSingleDynamicBatchDim_basic(module, tu: TestUtils):
|
def MatmulSingleDynamicBatchDim_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MatmulBroadcastBatchDim(torch.nn.Module):
|
class MatmulBroadcastBatchDim(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, -1, -1, -1], torch.float32, True),
|
([4, -1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.matmul(lhs, rhs)
|
return torch.matmul(lhs, rhs)
|
||||||
|
|
||||||
|
@ -209,16 +248,19 @@ class MatmulBroadcastBatchDim(torch.nn.Module):
|
||||||
def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
|
def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
|
module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class Mv(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class Mv(torch.nn.Module):
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, m, v):
|
def forward(self, m, v):
|
||||||
return torch.mv(m, v)
|
return torch.mv(m, v)
|
||||||
|
|
||||||
|
@ -227,16 +269,19 @@ class Mv(torch.nn.Module):
|
||||||
def Mv_basic(module, tu: TestUtils):
|
def Mv_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 2), tu.rand(2))
|
module.forward(tu.rand(2, 2), tu.rand(2))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenMmFloatTypes(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenMmFloatTypes(torch.nn.Module):
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.ops.aten.mm(a, b)
|
return torch.ops.aten.mm(a, b)
|
||||||
|
|
||||||
|
@ -245,16 +290,19 @@ class AtenMmFloatTypes(torch.nn.Module):
|
||||||
def AtenMmFloatTypes_basic(module, tu: TestUtils):
|
def AtenMmFloatTypes_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(8, 8), tu.rand(8, 8))
|
module.forward(tu.rand(8, 8), tu.rand(8, 8))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenMmIntTypes(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenMmIntTypes(torch.nn.Module):
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.ops.aten.mm(a, b)
|
return torch.ops.aten.mm(a, b)
|
||||||
|
|
||||||
|
@ -266,17 +314,19 @@ def AtenMmIntTypes_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenMmQint8(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenMmQint8(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([3, 4], torch.int8, True),
|
([3, 4], torch.int8, True),
|
||||||
([4, 3], torch.int8, True),
|
([4, 3], torch.int8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
||||||
qx = torch.dequantize(qx)
|
qx = torch.dequantize(qx)
|
||||||
|
@ -285,24 +335,30 @@ class AtenMmQint8(torch.nn.Module):
|
||||||
qz = torch.mm(qx, qy)
|
qz = torch.mm(qx, qy)
|
||||||
return qz
|
return qz
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMmQint8())
|
@register_test_case(module_factory=lambda: AtenMmQint8())
|
||||||
def AtenMmQint8_basic(module, tu: TestUtils):
|
def AtenMmQint8_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
module.forward(
|
||||||
tu.randint(4, 3, low=-128, high=127).to(torch.int8))
|
tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
||||||
|
tu.randint(4, 3, low=-128, high=127).to(torch.int8),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenMmQuint8(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenMmQuint8(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([3, 4], torch.uint8, True),
|
([3, 4], torch.uint8, True),
|
||||||
([4, 3], torch.uint8, True),
|
([4, 3], torch.uint8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65)
|
qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65)
|
||||||
qx = torch.dequantize(qx)
|
qx = torch.dequantize(qx)
|
||||||
|
@ -311,24 +367,30 @@ class AtenMmQuint8(torch.nn.Module):
|
||||||
qz = torch.mm(qx, qy)
|
qz = torch.mm(qx, qy)
|
||||||
return qz
|
return qz
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMmQuint8())
|
@register_test_case(module_factory=lambda: AtenMmQuint8())
|
||||||
def AtenMmQuint8_basic(module, tu: TestUtils):
|
def AtenMmQuint8_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=0, high=255).to(torch.uint8),
|
module.forward(
|
||||||
tu.randint(4, 3, low=0, high=255).to(torch.uint8))
|
tu.randint(3, 4, low=0, high=255).to(torch.uint8),
|
||||||
|
tu.randint(4, 3, low=0, high=255).to(torch.uint8),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenMmQMixedSigni8(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenMmQMixedSigni8(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([3, 4], torch.int8, True),
|
([3, 4], torch.int8, True),
|
||||||
([4, 3], torch.uint8, True),
|
([4, 3], torch.uint8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
||||||
qx = torch.dequantize(qx)
|
qx = torch.dequantize(qx)
|
||||||
|
@ -337,24 +399,30 @@ class AtenMmQMixedSigni8(torch.nn.Module):
|
||||||
qz = torch.mm(qx, qy)
|
qz = torch.mm(qx, qy)
|
||||||
return qz
|
return qz
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMmQMixedSigni8())
|
@register_test_case(module_factory=lambda: AtenMmQMixedSigni8())
|
||||||
def AtenMmQMixedSigni8_basic(module, tu: TestUtils):
|
def AtenMmQMixedSigni8_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
module.forward(
|
||||||
tu.randint(4, 3, low=0, high=255).to(torch.uint8))
|
tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
||||||
|
tu.randint(4, 3, low=0, high=255).to(torch.uint8),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenMatmulQint8VM(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenMatmulQint8VM(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.int8, True),
|
([-1], torch.int8, True),
|
||||||
([-1,-1], torch.int8, True),
|
([-1, -1], torch.int8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
||||||
qx = torch.dequantize(qx)
|
qx = torch.dequantize(qx)
|
||||||
|
@ -363,23 +431,28 @@ class AtenMatmulQint8VM(torch.nn.Module):
|
||||||
qz = torch.matmul(qx, qy)
|
qz = torch.matmul(qx, qy)
|
||||||
return qz
|
return qz
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQint8VM())
|
@register_test_case(module_factory=lambda: AtenMatmulQint8VM())
|
||||||
def AtenMatmulQint8VM_basic(module, tu: TestUtils):
|
def AtenMatmulQint8VM_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(9, low=-128, high=127).to(torch.int8),
|
module.forward(
|
||||||
tu.randint(9, 4, low=-128, high=127).to(torch.int8))
|
tu.randint(9, low=-128, high=127).to(torch.int8),
|
||||||
|
tu.randint(9, 4, low=-128, high=127).to(torch.int8),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
class AtenMatmulQint8VV(torch.nn.Module):
|
class AtenMatmulQint8VV(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.int8, True),
|
([-1], torch.int8, True),
|
||||||
([-1], torch.int8, True),
|
([-1], torch.int8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
||||||
qx = torch.dequantize(qx)
|
qx = torch.dequantize(qx)
|
||||||
|
@ -388,23 +461,28 @@ class AtenMatmulQint8VV(torch.nn.Module):
|
||||||
qz = torch.matmul(qx, qy)
|
qz = torch.matmul(qx, qy)
|
||||||
return qz
|
return qz
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQint8VV())
|
@register_test_case(module_factory=lambda: AtenMatmulQint8VV())
|
||||||
def AtenMatmulQint8VV_basic(module, tu: TestUtils):
|
def AtenMatmulQint8VV_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(9, low=-128, high=127).to(torch.int8),
|
module.forward(
|
||||||
tu.randint(9, low=-128, high=127).to(torch.int8))
|
tu.randint(9, low=-128, high=127).to(torch.int8),
|
||||||
|
tu.randint(9, low=-128, high=127).to(torch.int8),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
class AtenMatmulQint8MV(torch.nn.Module):
|
class AtenMatmulQint8MV(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int8, True),
|
([-1, -1], torch.int8, True),
|
||||||
([-1], torch.int8, True),
|
([-1], torch.int8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
||||||
qx = torch.dequantize(qx)
|
qx = torch.dequantize(qx)
|
||||||
|
@ -413,23 +491,28 @@ class AtenMatmulQint8MV(torch.nn.Module):
|
||||||
qz = torch.matmul(qx, qy)
|
qz = torch.matmul(qx, qy)
|
||||||
return qz
|
return qz
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQint8MV())
|
@register_test_case(module_factory=lambda: AtenMatmulQint8MV())
|
||||||
def AtenMatmulQint8MV_basic(module, tu: TestUtils):
|
def AtenMatmulQint8MV_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
module.forward(
|
||||||
tu.randint(4, low=-128, high=127).to(torch.int8))
|
tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
||||||
|
tu.randint(4, low=-128, high=127).to(torch.int8),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
class AtenMatmulQint8(torch.nn.Module):
|
class AtenMatmulQint8(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, -1, 3, 4], torch.int8, True),
|
([4, -1, 3, 4], torch.int8, True),
|
||||||
([-1, 4, 3], torch.int8, True),
|
([-1, 4, 3], torch.int8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
||||||
qx = torch.dequantize(qx)
|
qx = torch.dequantize(qx)
|
||||||
|
@ -438,24 +521,30 @@ class AtenMatmulQint8(torch.nn.Module):
|
||||||
qz = torch.matmul(qx, qy)
|
qz = torch.matmul(qx, qy)
|
||||||
return qz
|
return qz
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQint8())
|
@register_test_case(module_factory=lambda: AtenMatmulQint8())
|
||||||
def AtenMatmulQint8_basic(module, tu: TestUtils):
|
def AtenMatmulQint8_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(4, 7, 3, 4, low=-128, high=127).to(torch.int8),
|
module.forward(
|
||||||
tu.randint(7, 4, 3, low=-128, high=127).to(torch.int8))
|
tu.randint(4, 7, 3, 4, low=-128, high=127).to(torch.int8),
|
||||||
|
tu.randint(7, 4, 3, low=-128, high=127).to(torch.int8),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenMatmulQMixedSigni8(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenMatmulQMixedSigni8(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([7, -1, -1, -1], torch.int8, True),
|
([7, -1, -1, -1], torch.int8, True),
|
||||||
([-1, -1, -1], torch.uint8, True),
|
([-1, -1, -1], torch.uint8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
||||||
qx = torch.dequantize(qx)
|
qx = torch.dequantize(qx)
|
||||||
|
@ -464,24 +553,30 @@ class AtenMatmulQMixedSigni8(torch.nn.Module):
|
||||||
qz = torch.matmul(qx, qy)
|
qz = torch.matmul(qx, qy)
|
||||||
return qz
|
return qz
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8())
|
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8())
|
||||||
def AtenMatmulQMixedSigni8_basic(module, tu: TestUtils):
|
def AtenMatmulQMixedSigni8_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8),
|
module.forward(
|
||||||
tu.randint(2, 4, 3, low=0, high=255).to(torch.uint8))
|
tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8),
|
||||||
|
tu.randint(2, 4, 3, low=0, high=255).to(torch.uint8),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenMatmulQMixedSigni8Transpose(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenMatmulQMixedSigni8Transpose(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([7, -1, -1, -1], torch.int8, True),
|
([7, -1, -1, -1], torch.int8, True),
|
||||||
([-1, -1, -1], torch.uint8, True),
|
([-1, -1, -1], torch.uint8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
||||||
qx = torch.dequantize(qx)
|
qx = torch.dequantize(qx)
|
||||||
|
@ -491,21 +586,27 @@ class AtenMatmulQMixedSigni8Transpose(torch.nn.Module):
|
||||||
qz = torch.matmul(qx, qy)
|
qz = torch.matmul(qx, qy)
|
||||||
return qz
|
return qz
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose())
|
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose())
|
||||||
def AtenMatmulQMixedSigni8Transpose_basic(module, tu: TestUtils):
|
def AtenMatmulQMixedSigni8Transpose_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8),
|
module.forward(
|
||||||
tu.randint(2, 6, 4, low=0, high=255).to(torch.uint8))
|
tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8),
|
||||||
|
tu.randint(2, 6, 4, low=0, high=255).to(torch.uint8),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenLinalgCrossInt(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenLinalgCrossInt(torch.nn.Module):
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 3], torch.int64, True),
|
([2, 3], torch.int64, True),
|
||||||
([2, 3], torch.int64, True),
|
([2, 3], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.ops.aten.linalg_cross(a, b)
|
return torch.ops.aten.linalg_cross(a, b)
|
||||||
|
|
||||||
|
@ -514,16 +615,19 @@ class AtenLinalgCrossInt(torch.nn.Module):
|
||||||
def AtenLinalgCrossInt_basic(module, tu: TestUtils):
|
def AtenLinalgCrossInt_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(2, 3), tu.randint(2, 3))
|
module.forward(tu.randint(2, 3), tu.randint(2, 3))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenLinalgCrossFloat(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenLinalgCrossFloat(torch.nn.Module):
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 3], torch.float32, True),
|
([2, 3], torch.float32, True),
|
||||||
([2, 3], torch.float32, True),
|
([2, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.ops.aten.linalg_cross(a, b)
|
return torch.ops.aten.linalg_cross(a, b)
|
||||||
|
|
||||||
|
@ -535,14 +639,16 @@ def AtenLinalgCrossFloat_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenLinalgCrossBroadcast(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenLinalgCrossBroadcast(torch.nn.Module):
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 4, 3], torch.float32, True),
|
([1, 4, 3], torch.float32, True),
|
||||||
([5, 4, 3], torch.float32, True),
|
([5, 4, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.ops.aten.linalg_cross(a, b)
|
return torch.ops.aten.linalg_cross(a, b)
|
||||||
|
|
||||||
|
@ -551,16 +657,19 @@ class AtenLinalgCrossBroadcast(torch.nn.Module):
|
||||||
def AtenLinalgCrossBroadcast_basic(module, tu: TestUtils):
|
def AtenLinalgCrossBroadcast_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 4, 3), tu.rand(5, 4, 3))
|
module.forward(tu.rand(1, 4, 3), tu.rand(5, 4, 3))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenLinalgCrossCustomDim(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenLinalgCrossCustomDim(torch.nn.Module):
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 4, 3, 2, 2], torch.float32, True),
|
([1, 4, 3, 2, 2], torch.float32, True),
|
||||||
([5, 4, 3, 2, 1], torch.float32, True),
|
([5, 4, 3, 2, 1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.ops.aten.linalg_cross(a, b, dim=2)
|
return torch.ops.aten.linalg_cross(a, b, dim=2)
|
||||||
|
|
||||||
|
@ -569,16 +678,19 @@ class AtenLinalgCrossCustomDim(torch.nn.Module):
|
||||||
def AtenLinalgCrossCustomDim_basic(module, tu: TestUtils):
|
def AtenLinalgCrossCustomDim_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1))
|
module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class AtenLinalgCrossNegativeDim(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class AtenLinalgCrossNegativeDim(torch.nn.Module):
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 4, 3, 2, 2], torch.float32, True),
|
([1, 4, 3, 2, 2], torch.float32, True),
|
||||||
([5, 4, 3, 2, 1], torch.float32, True),
|
([5, 4, 3, 2, 1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.ops.aten.linalg_cross(a, b, dim=-3)
|
return torch.ops.aten.linalg_cross(a, b, dim=-3)
|
||||||
|
|
||||||
|
@ -587,18 +699,22 @@ class AtenLinalgCrossNegativeDim(torch.nn.Module):
|
||||||
def AtenLinalgCrossNegativeDim_basic(module, tu: TestUtils):
|
def AtenLinalgCrossNegativeDim_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1))
|
module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AtenLinalgCrossDynamic(torch.nn.Module):
|
class AtenLinalgCrossDynamic(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.ops.aten.linalg_cross(a, b, dim=1)
|
return torch.ops.aten.linalg_cross(a, b, dim=1)
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
# Multi-layer perceptron (MLP) models.
|
# Multi-layer perceptron (MLP) models.
|
||||||
|
|
||||||
|
|
||||||
class Mlp1LayerModule(torch.nn.Module):
|
class Mlp1LayerModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -21,18 +22,23 @@ class Mlp1LayerModule(torch.nn.Module):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
self.fc0 = nn.Linear(3, 5)
|
self.fc0 = nn.Linear(3, 5)
|
||||||
self.tanh0 = nn.Tanh()
|
self.tanh0 = nn.Tanh()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.tanh0(self.fc0(x))
|
return self.tanh0(self.fc0(x))
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Mlp1LayerModule())
|
@register_test_case(module_factory=lambda: Mlp1LayerModule())
|
||||||
def Mlp1LayerModule_basic(module, tu: TestUtils):
|
def Mlp1LayerModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(5, 3))
|
module.forward(tu.rand(5, 3))
|
||||||
|
|
||||||
|
|
||||||
class Mlp2LayerModule(torch.nn.Module):
|
class Mlp2LayerModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -43,20 +49,25 @@ class Mlp2LayerModule(torch.nn.Module):
|
||||||
self.tanh0 = nn.Tanh()
|
self.tanh0 = nn.Tanh()
|
||||||
self.fc1 = nn.Linear(N_HIDDEN, 2)
|
self.fc1 = nn.Linear(N_HIDDEN, 2)
|
||||||
self.tanh1 = nn.Tanh()
|
self.tanh1 = nn.Tanh()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.tanh0(self.fc0(x))
|
x = self.tanh0(self.fc0(x))
|
||||||
x = self.tanh1(self.fc1(x))
|
x = self.tanh1(self.fc1(x))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Mlp2LayerModule())
|
@register_test_case(module_factory=lambda: Mlp2LayerModule())
|
||||||
def Mlp2LayerModule_basic(module, tu: TestUtils):
|
def Mlp2LayerModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(5, 3))
|
module.forward(tu.rand(5, 3))
|
||||||
|
|
||||||
|
|
||||||
class Mlp2LayerModuleNoBias(torch.nn.Module):
|
class Mlp2LayerModuleNoBias(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -67,20 +78,25 @@ class Mlp2LayerModuleNoBias(torch.nn.Module):
|
||||||
self.tanh0 = nn.Tanh()
|
self.tanh0 = nn.Tanh()
|
||||||
self.fc1 = nn.Linear(N_HIDDEN, 2, bias=False)
|
self.fc1 = nn.Linear(N_HIDDEN, 2, bias=False)
|
||||||
self.tanh1 = nn.Tanh()
|
self.tanh1 = nn.Tanh()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.tanh0(self.fc0(x))
|
x = self.tanh0(self.fc0(x))
|
||||||
x = self.tanh1(self.fc1(x))
|
x = self.tanh1(self.fc1(x))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Mlp2LayerModuleNoBias())
|
@register_test_case(module_factory=lambda: Mlp2LayerModuleNoBias())
|
||||||
def Mlp2LayerModuleNoBias_basic(module, tu: TestUtils):
|
def Mlp2LayerModuleNoBias_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(5, 3))
|
module.forward(tu.rand(5, 3))
|
||||||
|
|
||||||
|
|
||||||
class BatchMlpLayerModule(torch.nn.Module):
|
class BatchMlpLayerModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -88,14 +104,18 @@ class BatchMlpLayerModule(torch.nn.Module):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
self.fc0 = nn.Linear(3, 5)
|
self.fc0 = nn.Linear(3, 5)
|
||||||
self.tanh0 = nn.Tanh()
|
self.tanh0 = nn.Tanh()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.tanh0(self.fc0(x))
|
return self.tanh0(self.fc0(x))
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: BatchMlpLayerModule())
|
@register_test_case(module_factory=lambda: BatchMlpLayerModule())
|
||||||
def BatchMlpLayerModule_basic(module, tu: TestUtils):
|
def BatchMlpLayerModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(7, 5, 3))
|
module.forward(tu.rand(7, 5, 3))
|
||||||
|
|
|
@ -13,23 +13,22 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule(torch.nn.Module):
|
class NllLossModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
# Here the 2nd index is ignored.
|
# Here the 2nd index is ignored.
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ops.aten.nll_loss_forward(x,
|
return torch.ops.aten.nll_loss_forward(
|
||||||
target=y,
|
x, target=y, weight=None, reduction=0, ignore_index=2
|
||||||
weight=None,
|
)
|
||||||
reduction=0,
|
|
||||||
ignore_index=2)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule())
|
@register_test_case(module_factory=lambda: NllLossModule())
|
||||||
|
@ -42,18 +41,18 @@ class NllLossModule_mean(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
# Here the 2nd index is ignored.
|
# Here the 2nd index is ignored.
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ops.aten.nll_loss_forward(x,
|
return torch.ops.aten.nll_loss_forward(
|
||||||
target=y,
|
x, target=y, weight=None, reduction=1, ignore_index=2
|
||||||
weight=None,
|
)
|
||||||
reduction=1,
|
|
||||||
ignore_index=2)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_mean())
|
@register_test_case(module_factory=lambda: NllLossModule_mean())
|
||||||
|
@ -66,18 +65,18 @@ class NllLossModule_sum(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
# Here the 2nd index is ignored.
|
# Here the 2nd index is ignored.
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ops.aten.nll_loss_forward(x,
|
return torch.ops.aten.nll_loss_forward(
|
||||||
target=y,
|
x, target=y, weight=None, reduction=2, ignore_index=2
|
||||||
weight=None,
|
)
|
||||||
reduction=2,
|
|
||||||
ignore_index=2)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_sum())
|
@register_test_case(module_factory=lambda: NllLossModule_sum())
|
||||||
|
@ -90,18 +89,18 @@ class NllLossModule_1D(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
# Here the 2nd index is ignored.
|
# Here the 2nd index is ignored.
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ops.aten.nll_loss_forward(x,
|
return torch.ops.aten.nll_loss_forward(
|
||||||
target=y,
|
x, target=y, weight=None, reduction=0, ignore_index=2
|
||||||
weight=None,
|
)
|
||||||
reduction=0,
|
|
||||||
ignore_index=2)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_1D())
|
@register_test_case(module_factory=lambda: NllLossModule_1D())
|
||||||
|
@ -110,409 +109,465 @@ def NllLossModule_1D_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
|
class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
# None of the index is ignored here, since the ignored index is out of bounds.
|
# None of the index is ignored here, since the ignored index is out of bounds.
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ops.aten.nll_loss_forward(x,
|
return torch.ops.aten.nll_loss_forward(
|
||||||
target=y,
|
x, target=y, weight=None, reduction=0, ignore_index=10
|
||||||
weight=None,
|
)
|
||||||
reduction=0,
|
|
||||||
ignore_index=10)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
|
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
|
||||||
def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils):
|
def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||||
|
|
||||||
class NllLossModule_backward(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class NllLossModule_backward(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, total_weight):
|
def forward(self, grad_output, input, target, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=None,
|
weight=None,
|
||||||
reduction=0,
|
reduction=0,
|
||||||
ignore_index=10,
|
ignore_index=10,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backward())
|
@register_test_case(module_factory=lambda: NllLossModule_backward())
|
||||||
def NllLossModuleBackward_basic(module, tu: TestUtils):
|
def NllLossModuleBackward_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
module.forward(
|
||||||
torch.tensor(3.))
|
tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backwardWeight(torch.nn.Module):
|
class NllLossModule_backwardWeight(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, weight, total_weight):
|
def forward(self, grad_output, input, target, weight, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
reduction=0,
|
reduction=0,
|
||||||
ignore_index=10,
|
ignore_index=10,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backwardWeight())
|
@register_test_case(module_factory=lambda: NllLossModule_backwardWeight())
|
||||||
def NllLossModuleBackwardWeight_basic(module, tu: TestUtils):
|
def NllLossModuleBackwardWeight_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
module.forward(
|
||||||
tu.rand(4), torch.tensor(3.))
|
tu.rand(3),
|
||||||
|
tu.rand(3, 4),
|
||||||
|
torch.tensor([2, 3, 0]),
|
||||||
|
tu.rand(4),
|
||||||
|
torch.tensor(3.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backward_ignore_index(torch.nn.Module):
|
class NllLossModule_backward_ignore_index(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, total_weight):
|
def forward(self, grad_output, input, target, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=None,
|
weight=None,
|
||||||
reduction=0,
|
reduction=0,
|
||||||
ignore_index=1,
|
ignore_index=1,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: NllLossModule_backward_ignore_index())
|
||||||
module_factory=lambda: NllLossModule_backward_ignore_index())
|
|
||||||
def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
|
def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
module.forward(
|
||||||
torch.tensor(3.))
|
tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backwardMean(torch.nn.Module):
|
class NllLossModule_backwardMean(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, total_weight):
|
def forward(self, grad_output, input, target, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=None,
|
weight=None,
|
||||||
reduction=1,
|
reduction=1,
|
||||||
ignore_index=1,
|
ignore_index=1,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backwardMean())
|
@register_test_case(module_factory=lambda: NllLossModule_backwardMean())
|
||||||
def NllLossModuleBackwardMean_basic(module, tu: TestUtils):
|
def NllLossModuleBackwardMean_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
module.forward(
|
||||||
torch.tensor(3.))
|
tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backwardMeanWeight(torch.nn.Module):
|
class NllLossModule_backwardMeanWeight(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, weight, total_weight):
|
def forward(self, grad_output, input, target, weight, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
reduction=1,
|
reduction=1,
|
||||||
ignore_index=1,
|
ignore_index=1,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backwardMeanWeight())
|
@register_test_case(module_factory=lambda: NllLossModule_backwardMeanWeight())
|
||||||
def NllLossModuleBackwardMeanWeight_basic(module, tu: TestUtils):
|
def NllLossModuleBackwardMeanWeight_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
module.forward(
|
||||||
tu.rand(4), torch.tensor(3.))
|
tu.rand(1),
|
||||||
|
tu.rand(3, 4),
|
||||||
|
torch.tensor([2, 3, 0]),
|
||||||
|
tu.rand(4),
|
||||||
|
torch.tensor(3.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backwardSum(torch.nn.Module):
|
class NllLossModule_backwardSum(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, total_weight):
|
def forward(self, grad_output, input, target, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=None,
|
weight=None,
|
||||||
reduction=2,
|
reduction=2,
|
||||||
ignore_index=1,
|
ignore_index=1,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backwardSum())
|
@register_test_case(module_factory=lambda: NllLossModule_backwardSum())
|
||||||
def NllLossModuleBackwardSum_basic(module, tu: TestUtils):
|
def NllLossModuleBackwardSum_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
module.forward(
|
||||||
torch.tensor(3.))
|
tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), torch.tensor(3.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backwardSumWeight(torch.nn.Module):
|
class NllLossModule_backwardSumWeight(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, weight, total_weight):
|
def forward(self, grad_output, input, target, weight, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
reduction=2,
|
reduction=2,
|
||||||
ignore_index=1,
|
ignore_index=1,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backwardSumWeight())
|
@register_test_case(module_factory=lambda: NllLossModule_backwardSumWeight())
|
||||||
def NllLossModuleBackwardSumWeight_basic(module, tu: TestUtils):
|
def NllLossModuleBackwardSumWeight_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
module.forward(
|
||||||
tu.rand(4), torch.tensor(3.))
|
tu.rand(1),
|
||||||
|
tu.rand(3, 4),
|
||||||
|
torch.tensor([2, 3, 0]),
|
||||||
|
tu.rand(4),
|
||||||
|
torch.tensor(3.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backward1D(torch.nn.Module):
|
class NllLossModule_backward1D(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, total_weight):
|
def forward(self, grad_output, input, target, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=None,
|
weight=None,
|
||||||
reduction=0,
|
reduction=0,
|
||||||
ignore_index=10,
|
ignore_index=10,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backward1D())
|
@register_test_case(module_factory=lambda: NllLossModule_backward1D())
|
||||||
def NllLossModuleBackward1D_basic(module, tu: TestUtils):
|
def NllLossModuleBackward1D_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), torch.tensor(3.0))
|
||||||
torch.tensor(3.))
|
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backward1DWeight(torch.nn.Module):
|
class NllLossModule_backward1DWeight(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, weight, total_weight):
|
def forward(self, grad_output, input, target, weight, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
reduction=0,
|
reduction=0,
|
||||||
ignore_index=10,
|
ignore_index=10,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DWeight())
|
@register_test_case(module_factory=lambda: NllLossModule_backward1DWeight())
|
||||||
def NllLossModuleBackward1DWeight_basic(module, tu: TestUtils):
|
def NllLossModuleBackward1DWeight_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
module.forward(
|
||||||
tu.rand(3), torch.tensor(3.))
|
tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backward1DMean(torch.nn.Module):
|
class NllLossModule_backward1DMean(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, total_weight):
|
def forward(self, grad_output, input, target, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=None,
|
weight=None,
|
||||||
reduction=1,
|
reduction=1,
|
||||||
ignore_index=1,
|
ignore_index=1,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DMean())
|
@register_test_case(module_factory=lambda: NllLossModule_backward1DMean())
|
||||||
def NllLossModuleBackward1DMean_basic(module, tu: TestUtils):
|
def NllLossModuleBackward1DMean_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), torch.tensor(3.0))
|
||||||
torch.tensor(3.))
|
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backward1DMeanWeight(torch.nn.Module):
|
class NllLossModule_backward1DMeanWeight(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, weight, total_weight):
|
def forward(self, grad_output, input, target, weight, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
reduction=1,
|
reduction=1,
|
||||||
ignore_index=1,
|
ignore_index=1,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight())
|
@register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight())
|
||||||
def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils):
|
def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
module.forward(
|
||||||
tu.rand(3), torch.tensor(3.))
|
tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backward1DSum(torch.nn.Module):
|
class NllLossModule_backward1DSum(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, total_weight):
|
def forward(self, grad_output, input, target, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=None,
|
weight=None,
|
||||||
reduction=2,
|
reduction=2,
|
||||||
ignore_index=1,
|
ignore_index=1,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DSum())
|
@register_test_case(module_factory=lambda: NllLossModule_backward1DSum())
|
||||||
def NllLossModuleBackward1DSum_basic(module, tu: TestUtils):
|
def NllLossModuleBackward1DSum_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), torch.tensor(3.0))
|
||||||
torch.tensor(3.))
|
|
||||||
|
|
||||||
|
|
||||||
class NllLossModule_backward1DSumWeight(torch.nn.Module):
|
class NllLossModule_backward1DSumWeight(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_output, input, target, weight, total_weight):
|
def forward(self, grad_output, input, target, weight, total_weight):
|
||||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
return torch.ops.aten.nll_loss_backward(
|
||||||
|
grad_output,
|
||||||
input,
|
input,
|
||||||
target=target,
|
target=target,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
reduction=2,
|
reduction=2,
|
||||||
ignore_index=1,
|
ignore_index=1,
|
||||||
total_weight=total_weight)
|
total_weight=total_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DSumWeight())
|
@register_test_case(module_factory=lambda: NllLossModule_backward1DSumWeight())
|
||||||
def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
|
def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
module.forward(
|
||||||
tu.rand(3), torch.tensor(3.))
|
tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0)
|
||||||
|
)
|
||||||
|
|
|
@ -11,6 +11,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BatchNorm1DModule(torch.nn.Module):
|
class BatchNorm1DModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -18,15 +19,16 @@ class BatchNorm1DModule(torch.nn.Module):
|
||||||
self.bn1d.eval()
|
self.bn1d.eval()
|
||||||
self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
|
self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
|
||||||
self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
|
self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
|
||||||
self.bn1d.weight = torch.nn.Parameter(
|
self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0]))
|
||||||
torch.tensor([3.0, 2.0, 4.0, 5.0]))
|
|
||||||
self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))
|
self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([10, 4, 3], torch.float32, True),
|
([10, 4, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.bn1d(x)
|
return self.bn1d(x)
|
||||||
|
|
||||||
|
@ -35,8 +37,10 @@ class BatchNorm1DModule(torch.nn.Module):
|
||||||
def BatchNorm1DModule_basic(module, tu: TestUtils):
|
def BatchNorm1DModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(10, 4, 3))
|
module.forward(tu.rand(10, 4, 3))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BatchNorm1DWith2DInputModule(torch.nn.Module):
|
class BatchNorm1DWith2DInputModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -44,15 +48,16 @@ class BatchNorm1DWith2DInputModule(torch.nn.Module):
|
||||||
self.bn1d.eval()
|
self.bn1d.eval()
|
||||||
self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
|
self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
|
||||||
self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
|
self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
|
||||||
self.bn1d.weight = torch.nn.Parameter(
|
self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0]))
|
||||||
torch.tensor([3.0, 2.0, 4.0, 5.0]))
|
|
||||||
self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))
|
self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([10, 4], torch.float32, True),
|
([10, 4], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.bn1d(x)
|
return self.bn1d(x)
|
||||||
|
|
||||||
|
@ -61,8 +66,10 @@ class BatchNorm1DWith2DInputModule(torch.nn.Module):
|
||||||
def BatchNorm1DWith2DInputModule_basic(module, tu: TestUtils):
|
def BatchNorm1DWith2DInputModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(10, 4))
|
module.forward(tu.rand(10, 4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BatchNorm2DModule(torch.nn.Module):
|
class BatchNorm2DModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -74,10 +81,12 @@ class BatchNorm2DModule(torch.nn.Module):
|
||||||
self.bn2d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4]))
|
self.bn2d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4]))
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([10, 2, 3, 3], torch.float32, True),
|
([10, 2, 3, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.bn2d(x)
|
return self.bn2d(x)
|
||||||
|
|
||||||
|
@ -86,8 +95,10 @@ class BatchNorm2DModule(torch.nn.Module):
|
||||||
def BatchNorm2DModule_basic(module, tu: TestUtils):
|
def BatchNorm2DModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(10, 2, 3, 3))
|
module.forward(tu.rand(10, 2, 3, 3))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BatchNorm3DModule(torch.nn.Module):
|
class BatchNorm3DModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -95,16 +106,16 @@ class BatchNorm3DModule(torch.nn.Module):
|
||||||
self.bn3d.eval()
|
self.bn3d.eval()
|
||||||
self.bn3d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])
|
self.bn3d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])
|
||||||
self.bn3d.running_var = torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])
|
self.bn3d.running_var = torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])
|
||||||
self.bn3d.weight = torch.nn.Parameter(
|
self.bn3d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
|
||||||
torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
|
self.bn3d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))
|
||||||
self.bn3d.bias = torch.nn.Parameter(
|
|
||||||
torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))
|
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 5, 3, 6, 4], torch.float32, True),
|
([2, 5, 3, 6, 4], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.bn3d(x)
|
return self.bn3d(x)
|
||||||
|
|
||||||
|
@ -113,274 +124,361 @@ class BatchNorm3DModule(torch.nn.Module):
|
||||||
def BatchNorm3DModule_basic(module, tu: TestUtils):
|
def BatchNorm3DModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 5, 3, 6, 4))
|
module.forward(tu.rand(2, 5, 3, 6, 4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BatchNorm1DStaticShapeModule(torch.nn.Module):
|
class BatchNorm1DStaticShapeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 5], torch.float32, True),
|
([2, 5], torch.float32, True),
|
||||||
([5], torch.float32, True),
|
([5], torch.float32, True),
|
||||||
([5], torch.float32, True),
|
([5], torch.float32, True),
|
||||||
([5], torch.float32, True),
|
([5], torch.float32, True),
|
||||||
([5], torch.float32, True),
|
([5], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, weight, bias, running_mean, running_var):
|
def forward(self, x, weight, bias, running_mean, running_var):
|
||||||
return torch.ops.aten.batch_norm(
|
return torch.ops.aten.batch_norm(
|
||||||
x, weight, bias, running_mean, running_var, training=False,
|
x,
|
||||||
momentum=0.1, eps=0.00001, cudnn_enabled=False)
|
weight,
|
||||||
|
bias,
|
||||||
|
running_mean,
|
||||||
|
running_var,
|
||||||
|
training=False,
|
||||||
|
momentum=0.1,
|
||||||
|
eps=0.00001,
|
||||||
|
cudnn_enabled=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: BatchNorm1DStaticShapeModule())
|
@register_test_case(module_factory=lambda: BatchNorm1DStaticShapeModule())
|
||||||
def BatchNorm1DStaticShapeModule_basic(module, tu: TestUtils):
|
def BatchNorm1DStaticShapeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.rand(2, 5), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||||
tu.rand(2, 5), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class NativeBatchNorm1DModule(torch.nn.Module):
|
class NativeBatchNorm1DModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, weight, bias, running_mean, running_var):
|
def forward(self, x, weight, bias, running_mean, running_var):
|
||||||
return torch.ops.aten.native_batch_norm(
|
return torch.ops.aten.native_batch_norm(
|
||||||
x, weight, bias, running_mean, running_var, training=False,
|
x,
|
||||||
momentum=0.1, eps=0.00001)
|
weight,
|
||||||
|
bias,
|
||||||
|
running_mean,
|
||||||
|
running_var,
|
||||||
|
training=False,
|
||||||
|
momentum=0.1,
|
||||||
|
eps=0.00001,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NativeBatchNorm1DModule())
|
@register_test_case(module_factory=lambda: NativeBatchNorm1DModule())
|
||||||
def NativeBatchNorm1DModule_basic(module, tu: TestUtils):
|
def NativeBatchNorm1DModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.rand(2, 5, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||||
tu.rand(2, 5, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class NativeBatchNorm2DModule(torch.nn.Module):
|
class NativeBatchNorm2DModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, weight, bias, running_mean, running_var):
|
def forward(self, x, weight, bias, running_mean, running_var):
|
||||||
return torch.ops.aten.native_batch_norm(
|
return torch.ops.aten.native_batch_norm(
|
||||||
x, weight, bias, running_mean, running_var, training=False,
|
x,
|
||||||
momentum=0.1, eps=0.00001)
|
weight,
|
||||||
|
bias,
|
||||||
|
running_mean,
|
||||||
|
running_var,
|
||||||
|
training=False,
|
||||||
|
momentum=0.1,
|
||||||
|
eps=0.00001,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NativeBatchNorm2DModule())
|
@register_test_case(module_factory=lambda: NativeBatchNorm2DModule())
|
||||||
def NativeBatchNorm2DModule_basic(module, tu: TestUtils):
|
def NativeBatchNorm2DModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.rand(2, 5, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||||
tu.rand(2, 5, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class NativeBatchNorm3DModule(torch.nn.Module):
|
class NativeBatchNorm3DModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, weight, bias, running_mean, running_var):
|
def forward(self, x, weight, bias, running_mean, running_var):
|
||||||
return torch.ops.aten.native_batch_norm(
|
return torch.ops.aten.native_batch_norm(
|
||||||
x, weight, bias, running_mean, running_var, training=False,
|
x,
|
||||||
momentum=0.1, eps=0.00001)
|
weight,
|
||||||
|
bias,
|
||||||
|
running_mean,
|
||||||
|
running_var,
|
||||||
|
training=False,
|
||||||
|
momentum=0.1,
|
||||||
|
eps=0.00001,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NativeBatchNorm3DModule())
|
@register_test_case(module_factory=lambda: NativeBatchNorm3DModule())
|
||||||
def NativeBatchNorm3DModule_basic(module, tu: TestUtils):
|
def NativeBatchNorm3DModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class NativeBatchNormNoneWeightModule(torch.nn.Module):
|
class NativeBatchNormNoneWeightModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, bias, running_mean, running_var):
|
def forward(self, x, bias, running_mean, running_var):
|
||||||
return torch.ops.aten.native_batch_norm(
|
return torch.ops.aten.native_batch_norm(
|
||||||
x, None, bias, running_mean, running_var, training=False,
|
x,
|
||||||
momentum=0.1, eps=0.00001)
|
None,
|
||||||
|
bias,
|
||||||
|
running_mean,
|
||||||
|
running_var,
|
||||||
|
training=False,
|
||||||
|
momentum=0.1,
|
||||||
|
eps=0.00001,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NativeBatchNormNoneWeightModule())
|
@register_test_case(module_factory=lambda: NativeBatchNormNoneWeightModule())
|
||||||
def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
|
def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5))
|
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class GroupNormModule(torch.nn.Module):
|
class GroupNormModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 4, 6, 7], torch.float32, True),
|
([2, 4, 6, 7], torch.float32, True),
|
||||||
([4], torch.float32, True),
|
([4], torch.float32, True),
|
||||||
([4], torch.float32, True),
|
([4], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, weight, bias):
|
def forward(self, x, weight, bias):
|
||||||
return torch.ops.aten.group_norm(x, 2, weight, bias, 1.0000000000000001e-05, False)
|
return torch.ops.aten.group_norm(
|
||||||
|
x, 2, weight, bias, 1.0000000000000001e-05, False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: GroupNormModule())
|
@register_test_case(module_factory=lambda: GroupNormModule())
|
||||||
def GroupNormModule_basic(module, tu: TestUtils):
|
def GroupNormModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 4, 6, 7), tu.rand(4), tu.rand(4))
|
module.forward(tu.rand(2, 4, 6, 7), tu.rand(4), tu.rand(4))
|
||||||
|
|
||||||
|
|
||||||
class GroupNormNoWeightAndBiasModule(torch.nn.Module):
|
class GroupNormNoWeightAndBiasModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 4, 6, 7], torch.float32, True),
|
([2, 4, 6, 7], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.group_norm(x, 2, None, None, 1.0000000000000001e-05, False)
|
return torch.ops.aten.group_norm(
|
||||||
|
x, 2, None, None, 1.0000000000000001e-05, False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: GroupNormNoWeightAndBiasModule())
|
@register_test_case(module_factory=lambda: GroupNormNoWeightAndBiasModule())
|
||||||
def GroupNormNoWeightAndBiasModule_basic(module, tu: TestUtils):
|
def GroupNormNoWeightAndBiasModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 4, 6, 7))
|
module.forward(tu.rand(2, 4, 6, 7))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class NativeGroupNormModule(torch.nn.Module):
|
class NativeGroupNormModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 6, 2, 2], torch.float32, True),
|
([2, 6, 2, 2], torch.float32, True),
|
||||||
([6], torch.float32, True),
|
([6], torch.float32, True),
|
||||||
([6], torch.float32, True),
|
([6], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, weight, bias):
|
def forward(self, x, weight, bias):
|
||||||
return torch.ops.aten.native_group_norm(
|
return torch.ops.aten.native_group_norm(x, weight, bias, 2, 6, 4, 3, 0.000001)
|
||||||
x, weight, bias,
|
|
||||||
2, 6, 4, 3, 0.000001)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NativeGroupNormModule())
|
@register_test_case(module_factory=lambda: NativeGroupNormModule())
|
||||||
def NativeGroupNormModule_basic(module, tu: TestUtils):
|
def NativeGroupNormModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 6, 2, 2), tu.rand(6), tu.rand(6))
|
module.forward(tu.rand(2, 6, 2, 2), tu.rand(6), tu.rand(6))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class NativeGroupNormBackwardModule(torch.nn.Module):
|
class NativeGroupNormBackwardModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 6, 2, 2], torch.float32, True),
|
([2, 6, 2, 2], torch.float32, True),
|
||||||
([2, 6, 2, 2], torch.float32, True),
|
([2, 6, 2, 2], torch.float32, True),
|
||||||
([2, 3], torch.float32, True),
|
([2, 3], torch.float32, True),
|
||||||
([2, 3], torch.float32, True),
|
([2, 3], torch.float32, True),
|
||||||
([6], torch.float32, True),
|
([6], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad_out, x, mean, rstd, weight):
|
def forward(self, grad_out, x, mean, rstd, weight):
|
||||||
return torch.ops.aten.native_group_norm_backward(
|
return torch.ops.aten.native_group_norm_backward(
|
||||||
grad_out, x, mean, rstd, weight,
|
grad_out, x, mean, rstd, weight, 2, 6, 4, 3, [True, True, True]
|
||||||
2, 6, 4, 3, [True, True, True])
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NativeGroupNormBackwardModule())
|
@register_test_case(module_factory=lambda: NativeGroupNormBackwardModule())
|
||||||
def NativeGroupNormBackwardModule_basic(module, tu: TestUtils):
|
def NativeGroupNormBackwardModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 6, 2, 2), tu.rand(2, 6, 2, 2), tu.rand(2, 3),
|
module.forward(
|
||||||
tu.rand(2, 3), tu.rand(6))
|
tu.rand(2, 6, 2, 2),
|
||||||
|
tu.rand(2, 6, 2, 2),
|
||||||
|
tu.rand(2, 3),
|
||||||
|
tu.rand(2, 3),
|
||||||
|
tu.rand(6),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class NativeLayerNormModule(torch.nn.Module):
|
class NativeLayerNormModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 5, 2, 2, 3], torch.float32, True),
|
([2, 5, 2, 2, 3], torch.float32, True),
|
||||||
([2, 2, 3], torch.float32, True),
|
([2, 2, 3], torch.float32, True),
|
||||||
([2, 2, 3], torch.float32, True),
|
([2, 2, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, weight, bias):
|
def forward(self, x, weight, bias):
|
||||||
list = [2, 2, 3]
|
list = [2, 2, 3]
|
||||||
return torch.ops.aten.native_layer_norm(
|
return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)
|
||||||
x, list, weight, bias, eps=0.5)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NativeLayerNormModule())
|
@register_test_case(module_factory=lambda: NativeLayerNormModule())
|
||||||
def NativeLayerNormModule_basic(module, tu: TestUtils):
|
def NativeLayerNormModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||||
|
|
||||||
|
|
||||||
class NativeLayerNormDynamicModule(torch.nn.Module):
|
class NativeLayerNormDynamicModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, weight, bias):
|
def forward(self, x, weight, bias):
|
||||||
list = [2, 2, 3]
|
list = [2, 2, 3]
|
||||||
return torch.ops.aten.native_layer_norm(
|
return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)
|
||||||
x, list, weight, bias, eps=0.5)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NativeLayerNormDynamicModule())
|
@register_test_case(module_factory=lambda: NativeLayerNormDynamicModule())
|
||||||
def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):
|
def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class NormalizeModule(torch.nn.Module):
|
class NormalizeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([3, 3], torch.float32, True),
|
([3, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.nn.functional.normalize(x)
|
return torch.nn.functional.normalize(x)
|
||||||
|
|
||||||
|
@ -389,48 +487,59 @@ class NormalizeModule(torch.nn.Module):
|
||||||
def NormalizeModule_basic(module, tu: TestUtils):
|
def NormalizeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 3))
|
module.forward(tu.rand(3, 3))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class NativeLayerNormModule4D(torch.nn.Module):
|
class NativeLayerNormModule4D(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([5, 2, 2, 3], torch.float32, True),
|
([5, 2, 2, 3], torch.float32, True),
|
||||||
([2, 2, 3], torch.float32, True),
|
([2, 2, 3], torch.float32, True),
|
||||||
([2, 2, 3], torch.float32, True),
|
([2, 2, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, weight, bias):
|
def forward(self, x, weight, bias):
|
||||||
list = [2, 2, 3]
|
list = [2, 2, 3]
|
||||||
return torch.ops.aten.native_layer_norm(
|
return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)[0]
|
||||||
x, list, weight, bias, eps=0.5)[0]
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NativeLayerNormModule4D())
|
@register_test_case(module_factory=lambda: NativeLayerNormModule4D())
|
||||||
def NativeLayerNormModule4D_basic(module, tu: TestUtils):
|
def NativeLayerNormModule4D_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
module.forward(tu.rand(5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class LayerNormModule(torch.nn.Module):
|
class LayerNormModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ly = torch.nn.LayerNorm([2, 2, 3])
|
self.ly = torch.nn.LayerNorm([2, 2, 3])
|
||||||
self.ly.eval()
|
self.ly.eval()
|
||||||
self.ly.weight = torch.nn.Parameter(
|
self.ly.weight = torch.nn.Parameter(
|
||||||
torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]],
|
torch.tensor(
|
||||||
[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]))
|
[[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]], [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]
|
||||||
|
)
|
||||||
|
)
|
||||||
self.ly.bias = torch.nn.Parameter(
|
self.ly.bias = torch.nn.Parameter(
|
||||||
torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]],
|
torch.tensor(
|
||||||
[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]))
|
[[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]], [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 5, 2, 2, 3], torch.float32, True),
|
([2, 5, 2, 2, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.ly(x)
|
return self.ly(x)
|
||||||
|
|
||||||
|
@ -439,8 +548,10 @@ class LayerNormModule(torch.nn.Module):
|
||||||
def LayerNormModule_basic(module, tu: TestUtils):
|
def LayerNormModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 5, 2, 2, 3))
|
module.forward(tu.rand(2, 5, 2, 2, 3))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class LayerNormLastDimModule(torch.nn.Module):
|
class LayerNormLastDimModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -450,10 +561,12 @@ class LayerNormLastDimModule(torch.nn.Module):
|
||||||
self.ly.bias = torch.nn.Parameter(torch.tensor([0.2, 0.4, 0.3]))
|
self.ly.bias = torch.nn.Parameter(torch.tensor([0.2, 0.4, 0.3]))
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 5, 2, 2, 3], torch.float32, True),
|
([2, 5, 2, 2, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.ly(x)
|
return self.ly(x)
|
||||||
|
|
||||||
|
@ -462,25 +575,33 @@ class LayerNormLastDimModule(torch.nn.Module):
|
||||||
def LayerNormLastDimModule_basic(module, tu: TestUtils):
|
def LayerNormLastDimModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 5, 2, 2, 3))
|
module.forward(tu.rand(2, 5, 2, 2, 3))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
|
class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ly = torch.nn.LayerNorm([2, 2, 3])
|
self.ly = torch.nn.LayerNorm([2, 2, 3])
|
||||||
self.ly.eval()
|
self.ly.eval()
|
||||||
self.ly.weight = torch.nn.Parameter(
|
self.ly.weight = torch.nn.Parameter(
|
||||||
torch.tensor([[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]],
|
torch.tensor(
|
||||||
[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]))
|
[[[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]], [[3.0, 2.0, 4.0], [2.0, 3.0, 3.0]]]
|
||||||
|
)
|
||||||
|
)
|
||||||
self.ly.bias = torch.nn.Parameter(
|
self.ly.bias = torch.nn.Parameter(
|
||||||
torch.tensor([[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]],
|
torch.tensor(
|
||||||
[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]))
|
[[[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]], [[0.5, 0.4, 0.3], [0.2, 0.4, 0.3]]]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 2, 3], torch.float32, True),
|
([2, 2, 3], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.ly(x)
|
return self.ly(x)
|
||||||
|
|
||||||
|
@ -489,20 +610,25 @@ class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
|
||||||
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
|
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 2, 3))
|
module.forward(tu.rand(2, 2, 3))
|
||||||
|
|
||||||
|
|
||||||
class AtenInstanceNormModule(torch.nn.Module):
|
class AtenInstanceNormModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 2, 1, 3], torch.float32, True),
|
([1, 2, 1, 3], torch.float32, True),
|
||||||
([2], torch.float32, True),
|
([2], torch.float32, True),
|
||||||
([2], torch.float32, True)
|
([2], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, w, b):
|
def forward(self, x, w, b):
|
||||||
return torch.ops.aten.instance_norm(x, w, b, None,
|
return torch.ops.aten.instance_norm(
|
||||||
None, True, 0.0, 1e-05, False)
|
x, w, b, None, None, True, 0.0, 1e-05, False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenInstanceNormModule())
|
@register_test_case(module_factory=lambda: AtenInstanceNormModule())
|
||||||
def AtenInstanceNormModule_basic(module, tu: TestUtils):
|
def AtenInstanceNormModule_basic(module, tu: TestUtils):
|
||||||
|
|
|
@ -12,98 +12,112 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class ReflectionPad2dModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class ReflectionPad2dModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 20, 20], torch.float32, True),
|
([1, 20, 20], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.reflection_pad2d(x, (10,10,10,10))
|
return torch.ops.aten.reflection_pad2d(x, (10, 10, 10, 10))
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ReflectionPad2dModule())
|
@register_test_case(module_factory=lambda: ReflectionPad2dModule())
|
||||||
def ReflectionPad2dModule_basic(module, tu: TestUtils):
|
def ReflectionPad2dModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 20, 20, low=-1))
|
module.forward(tu.rand(1, 20, 20, low=-1))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class ReflectionPad2dModuleTop(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class ReflectionPad2dModuleTop(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 3, 4], torch.float32, True),
|
([1, 3, 4], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.reflection_pad2d(x, (0,0,2,0))
|
return torch.ops.aten.reflection_pad2d(x, (0, 0, 2, 0))
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ReflectionPad2dModuleTop())
|
@register_test_case(module_factory=lambda: ReflectionPad2dModuleTop())
|
||||||
def ReflectionPad2dModule_Top(module, tu: TestUtils):
|
def ReflectionPad2dModule_Top(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 3, 4))
|
module.forward(tu.rand(1, 3, 4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class ReflectionPad2dModuleBottom(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class ReflectionPad2dModuleBottom(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 3, 10, 10], torch.float32, True),
|
([2, 3, 10, 10], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.reflection_pad2d(x, (0,0,0,5))
|
return torch.ops.aten.reflection_pad2d(x, (0, 0, 0, 5))
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ReflectionPad2dModuleBottom())
|
@register_test_case(module_factory=lambda: ReflectionPad2dModuleBottom())
|
||||||
def ReflectionPad2dModule_Bottom(module, tu: TestUtils):
|
def ReflectionPad2dModule_Bottom(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 3, 10, 10))
|
module.forward(tu.rand(2, 3, 10, 10))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class ReflectionPad2dModuleLeft(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class ReflectionPad2dModuleLeft(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 3, 20, 20], torch.float32, True),
|
([2, 3, 20, 20], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.reflection_pad2d(x, (15,0,0,0))
|
return torch.ops.aten.reflection_pad2d(x, (15, 0, 0, 0))
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ReflectionPad2dModuleLeft())
|
@register_test_case(module_factory=lambda: ReflectionPad2dModuleLeft())
|
||||||
def ReflectionPad2dModule_Left(module, tu: TestUtils):
|
def ReflectionPad2dModule_Left(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 3, 20, 20))
|
module.forward(tu.rand(2, 3, 20, 20))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class ReflectionPad2dModuleRight(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class ReflectionPad2dModuleRight(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([2, 3, 20, 20], torch.float32, True),
|
([2, 3, 20, 20], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.reflection_pad2d(x, (0,11,0,0))
|
return torch.ops.aten.reflection_pad2d(x, (0, 11, 0, 0))
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ReflectionPad2dModuleRight())
|
@register_test_case(module_factory=lambda: ReflectionPad2dModuleRight())
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -12,12 +12,15 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
def get_quant_model_input():
|
def get_quant_model_input():
|
||||||
return 2 * torch.rand((1, 16)) - 1
|
return 2 * torch.rand((1, 16)) - 1
|
||||||
|
|
||||||
|
|
||||||
def get_batched_quant_model_input():
|
def get_batched_quant_model_input():
|
||||||
return 2 * torch.rand((1, 2, 16)) - 1
|
return 2 * torch.rand((1, 2, 16)) - 1
|
||||||
|
|
||||||
|
|
||||||
class QuantizedNoLayer(nn.Module):
|
class QuantizedNoLayer(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -26,15 +29,18 @@ class QuantizedNoLayer(nn.Module):
|
||||||
self.dequantize = torch.quantization.DeQuantStub()
|
self.dequantize = torch.quantization.DeQuantStub()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 16], torch.float32, True),
|
([1, 16], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.quantize(x)
|
x = self.quantize(x)
|
||||||
x = self.dequantize(x)
|
x = self.dequantize(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def get_quantized_no_layer():
|
def get_quantized_no_layer():
|
||||||
model = QuantizedNoLayer()
|
model = QuantizedNoLayer()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
@ -46,10 +52,12 @@ def get_quantized_no_layer():
|
||||||
torch.quantization.convert(model, inplace=True)
|
torch.quantization.convert(model, inplace=True)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=get_quantized_no_layer)
|
@register_test_case(module_factory=get_quantized_no_layer)
|
||||||
def QuantizedNoLayer_basic(module, tu: TestUtils):
|
def QuantizedNoLayer_basic(module, tu: TestUtils):
|
||||||
module.forward(get_quant_model_input())
|
module.forward(get_quant_model_input())
|
||||||
|
|
||||||
|
|
||||||
class QuantizedSingleLayer(nn.Module):
|
class QuantizedSingleLayer(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -61,16 +69,19 @@ class QuantizedSingleLayer(nn.Module):
|
||||||
self.dequantize = torch.quantization.DeQuantStub()
|
self.dequantize = torch.quantization.DeQuantStub()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 16], torch.float32, True),
|
([1, 16], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.quantize(x)
|
x = self.quantize(x)
|
||||||
x = self.layers(x)
|
x = self.layers(x)
|
||||||
x = self.dequantize(x)
|
x = self.dequantize(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def get_quantized_single_layer():
|
def get_quantized_single_layer():
|
||||||
model = QuantizedSingleLayer()
|
model = QuantizedSingleLayer()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
@ -82,10 +93,12 @@ def get_quantized_single_layer():
|
||||||
torch.quantization.convert(model, inplace=True)
|
torch.quantization.convert(model, inplace=True)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=get_quantized_single_layer)
|
@register_test_case(module_factory=get_quantized_single_layer)
|
||||||
def QuantizedSingleLayer_basic(module, tu: TestUtils):
|
def QuantizedSingleLayer_basic(module, tu: TestUtils):
|
||||||
module.forward(get_quant_model_input())
|
module.forward(get_quant_model_input())
|
||||||
|
|
||||||
|
|
||||||
class QuantizedBatchedInputSingleLayer(nn.Module):
|
class QuantizedBatchedInputSingleLayer(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -97,16 +110,19 @@ class QuantizedBatchedInputSingleLayer(nn.Module):
|
||||||
self.dequantize = torch.quantization.DeQuantStub()
|
self.dequantize = torch.quantization.DeQuantStub()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 2, 16], torch.float32, True),
|
([1, 2, 16], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.quantize(x)
|
x = self.quantize(x)
|
||||||
x = self.layers(x)
|
x = self.layers(x)
|
||||||
x = self.dequantize(x)
|
x = self.dequantize(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def get_batched_quantized_single_layer():
|
def get_batched_quantized_single_layer():
|
||||||
model = QuantizedBatchedInputSingleLayer()
|
model = QuantizedBatchedInputSingleLayer()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
@ -118,10 +134,12 @@ def get_batched_quantized_single_layer():
|
||||||
torch.quantization.convert(model, inplace=True)
|
torch.quantization.convert(model, inplace=True)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=get_batched_quantized_single_layer)
|
@register_test_case(module_factory=get_batched_quantized_single_layer)
|
||||||
def QuantizedBatchedInputSingleLayer_basic(module, tu: TestUtils):
|
def QuantizedBatchedInputSingleLayer_basic(module, tu: TestUtils):
|
||||||
module.forward(get_batched_quant_model_input())
|
module.forward(get_batched_quant_model_input())
|
||||||
|
|
||||||
|
|
||||||
class QuantizedMLP(nn.Module):
|
class QuantizedMLP(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -135,16 +153,19 @@ class QuantizedMLP(nn.Module):
|
||||||
self.dequantize = torch.quantization.DeQuantStub()
|
self.dequantize = torch.quantization.DeQuantStub()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 16], torch.float32, True),
|
([1, 16], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.quantize(x)
|
x = self.quantize(x)
|
||||||
x = self.layers(x)
|
x = self.layers(x)
|
||||||
x = self.dequantize(x)
|
x = self.dequantize(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def get_quantized_mlp():
|
def get_quantized_mlp():
|
||||||
model = QuantizedMLP()
|
model = QuantizedMLP()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
@ -156,6 +177,7 @@ def get_quantized_mlp():
|
||||||
torch.quantization.convert(model, inplace=True)
|
torch.quantization.convert(model, inplace=True)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=get_quantized_mlp)
|
@register_test_case(module_factory=get_quantized_mlp)
|
||||||
def QuantizedMLP_basic(module, tu: TestUtils):
|
def QuantizedMLP_basic(module, tu: TestUtils):
|
||||||
module.forward(get_quant_model_input())
|
module.forward(get_quant_model_input())
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -13,19 +13,20 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
|
|
||||||
class TestMultipleTensorReturn(torch.nn.Module):
|
class TestMultipleTensorReturn(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
([-1, -1], torch.int32, True),
|
([-1, -1], torch.int32, True),
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
([-1, -1], torch.bool, True),
|
([-1, -1], torch.bool, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b, c, d, e):
|
def forward(self, a, b, c, d, e):
|
||||||
return a, b, c, d, e
|
return a, b, c, d, e
|
||||||
|
|
||||||
|
@ -37,54 +38,56 @@ def TestMultipleTensorReturn_basic(module, tu: TestUtils):
|
||||||
tu.rand(2, 3).to(torch.float64),
|
tu.rand(2, 3).to(torch.float64),
|
||||||
tu.rand(2, 3).to(torch.int32),
|
tu.rand(2, 3).to(torch.int32),
|
||||||
tu.rand(2, 3).to(torch.int64),
|
tu.rand(2, 3).to(torch.int64),
|
||||||
tu.rand(2, 3).to(torch.bool))
|
tu.rand(2, 3).to(torch.bool),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestMultipleTensorAndPrimitiveTypesReturn(torch.nn.Module):
|
class TestMultipleTensorAndPrimitiveTypesReturn(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int32, True),
|
([-1, -1], torch.int32, True),
|
||||||
([-1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
([-1, -1], torch.bool, True),
|
([-1, -1], torch.bool, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b, c):
|
def forward(self, a, b, c):
|
||||||
d = 1
|
d = 1
|
||||||
e = 2.3
|
e = 2.3
|
||||||
return a, b, c, d, e
|
return a, b, c, d, e
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: TestMultipleTensorAndPrimitiveTypesReturn())
|
||||||
module_factory=lambda: TestMultipleTensorAndPrimitiveTypesReturn())
|
|
||||||
def TestMultipleTensorAndPrimitiveTypesReturn_basic(module, tu: TestUtils):
|
def TestMultipleTensorAndPrimitiveTypesReturn_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
tu.rand(3, 4).to(torch.int32),
|
tu.rand(3, 4).to(torch.int32),
|
||||||
tu.rand(2, 3).to(torch.float64),
|
tu.rand(2, 3).to(torch.float64),
|
||||||
tu.rand(2, 3).to(torch.bool))
|
tu.rand(2, 3).to(torch.bool),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestF16Return(torch.nn.Module):
|
class TestF16Return(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float16, True),
|
([-1, -1], torch.float16, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return a
|
return a
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: TestF16Return())
|
||||||
module_factory=lambda: TestF16Return())
|
|
||||||
def TestF16Return_basic(module, tu: TestUtils):
|
def TestF16Return_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.rand(3, 4).to(torch.float16))
|
||||||
tu.rand(3, 4).to(torch.float16))
|
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
|
@ -6,16 +6,13 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class RandModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class RandModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args([None, ([1024, 512], torch.float, True)])
|
||||||
None,
|
|
||||||
([1024, 512], torch.float, True)
|
|
||||||
])
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
size = x.size()
|
size = x.size()
|
||||||
a = torch.rand(size)
|
a = torch.rand(size)
|
||||||
|
@ -26,34 +23,41 @@ class RandModule(torch.nn.Module):
|
||||||
def RandModule_basic(module, tu: TestUtils):
|
def RandModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1024, 512))
|
module.forward(tu.rand(1024, 512))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class UniformModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class UniformModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1, -1], torch.float64, True),
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1, -1], torch.float64, True),
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y, z):
|
def forward(self, x, y, z):
|
||||||
a = torch.ops.aten.uniform_(x, 1.0, 10.0)
|
a = torch.ops.aten.uniform_(x, 1.0, 10.0)
|
||||||
b = torch.ops.aten.uniform_(y, -20.0, -5.0)
|
b = torch.ops.aten.uniform_(y, -20.0, -5.0)
|
||||||
c = torch.ops.aten.uniform_(z, -15.0, 3.0)
|
c = torch.ops.aten.uniform_(z, -15.0, 3.0)
|
||||||
std = torch.cat([
|
std = torch.cat(
|
||||||
|
[
|
||||||
torch.flatten(torch.std(a)),
|
torch.flatten(torch.std(a)),
|
||||||
torch.flatten(torch.std(b)),
|
torch.flatten(torch.std(b)),
|
||||||
torch.flatten(torch.std(c))
|
torch.flatten(torch.std(c)),
|
||||||
])
|
]
|
||||||
mean = torch.cat([
|
)
|
||||||
|
mean = torch.cat(
|
||||||
|
[
|
||||||
torch.flatten(torch.mean(a)),
|
torch.flatten(torch.mean(a)),
|
||||||
torch.flatten(torch.mean(b)),
|
torch.flatten(torch.mean(b)),
|
||||||
torch.flatten(torch.mean(c))
|
torch.flatten(torch.mean(c)),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
return std, mean
|
return std, mean
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,36 +66,44 @@ def UniformModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
tu.rand(256, 512, 12).double(),
|
tu.rand(256, 512, 12).double(),
|
||||||
tu.rand(512, 1024, 12).double(),
|
tu.rand(512, 1024, 12).double(),
|
||||||
tu.rand(512, 256, 12).double())
|
tu.rand(512, 256, 12).double(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class UniformStaticShapeModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class UniformStaticShapeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([256, 512, 12], torch.float64, True),
|
([256, 512, 12], torch.float64, True),
|
||||||
([512, 1024, 12], torch.float64, True),
|
([512, 1024, 12], torch.float64, True),
|
||||||
([512, 256, 12], torch.float64, True),
|
([512, 256, 12], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y, z):
|
def forward(self, x, y, z):
|
||||||
a = torch.ops.aten.uniform_(x, 1.0, 10.0)
|
a = torch.ops.aten.uniform_(x, 1.0, 10.0)
|
||||||
b = torch.ops.aten.uniform_(y, -20.0, -5.0)
|
b = torch.ops.aten.uniform_(y, -20.0, -5.0)
|
||||||
c = torch.ops.aten.uniform_(z, -15.0, 3.0)
|
c = torch.ops.aten.uniform_(z, -15.0, 3.0)
|
||||||
std = torch.cat([
|
std = torch.cat(
|
||||||
|
[
|
||||||
torch.flatten(torch.std(a)),
|
torch.flatten(torch.std(a)),
|
||||||
torch.flatten(torch.std(b)),
|
torch.flatten(torch.std(b)),
|
||||||
torch.flatten(torch.std(c))
|
torch.flatten(torch.std(c)),
|
||||||
])
|
]
|
||||||
mean = torch.cat([
|
)
|
||||||
|
mean = torch.cat(
|
||||||
|
[
|
||||||
torch.flatten(torch.mean(a)),
|
torch.flatten(torch.mean(a)),
|
||||||
torch.flatten(torch.mean(b)),
|
torch.flatten(torch.mean(b)),
|
||||||
torch.flatten(torch.mean(c))
|
torch.flatten(torch.mean(c)),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
return std, mean
|
return std, mean
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,12 +112,14 @@ def UniformStaticShapeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
tu.rand(256, 512, 12).double(),
|
tu.rand(256, 512, 12).double(),
|
||||||
tu.rand(512, 1024, 12).double(),
|
tu.rand(512, 1024, 12).double(),
|
||||||
tu.rand(512, 256, 12).double())
|
tu.rand(512, 256, 12).double(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class UniformNoCorrelationModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class UniformNoCorrelationModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -119,10 +133,12 @@ class UniformNoCorrelationModule(torch.nn.Module):
|
||||||
return cov[0, 1] / torch.sqrt(cov[0, 0] * cov[1, 1])
|
return cov[0, 1] / torch.sqrt(cov[0, 0] * cov[1, 1])
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1000], torch.float64, True),
|
([1000], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Correlation of two independent uniforms
|
# Correlation of two independent uniforms
|
||||||
a = torch.ops.aten.uniform(x)
|
a = torch.ops.aten.uniform(x)
|
||||||
|
@ -145,27 +161,32 @@ class UniformNoCorrelationModule(torch.nn.Module):
|
||||||
# than `atol + rtol * correlation = 1E-6`, which is too strict.
|
# than `atol + rtol * correlation = 1E-6`, which is too strict.
|
||||||
# Instead, the correlations are explicitly required to be less than
|
# Instead, the correlations are explicitly required to be less than
|
||||||
# 0.001.
|
# 0.001.
|
||||||
return torch.where(torch.abs(corr_a_b) < 0.001, 1, 2), \
|
return (
|
||||||
torch.where(torch.abs(corr_major) < 0.001, 1, 2), \
|
torch.where(torch.abs(corr_a_b) < 0.001, 1, 2),
|
||||||
torch.where(torch.abs(corr_minor) < 0.001, 1, 2)
|
torch.where(torch.abs(corr_major) < 0.001, 1, 2),
|
||||||
|
torch.where(torch.abs(corr_minor) < 0.001, 1, 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: UniformNoCorrelationModule())
|
@register_test_case(module_factory=lambda: UniformNoCorrelationModule())
|
||||||
def UniformNoCorrelationModule_basic(module, tu: TestUtils):
|
def UniformNoCorrelationModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.rand(1000).double())
|
||||||
tu.rand(1000).double())
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ExponentialModule(torch.nn.Module):
|
class ExponentialModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
a = torch.ops.aten.exponential(x, 3.0)
|
a = torch.ops.aten.exponential(x, 3.0)
|
||||||
mean = torch.mean(a)
|
mean = torch.mean(a)
|
||||||
|
@ -175,20 +196,23 @@ class ExponentialModule(torch.nn.Module):
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ExponentialModule())
|
@register_test_case(module_factory=lambda: ExponentialModule())
|
||||||
def ExponentialModule_basic(module, tu: TestUtils):
|
def ExponentialModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.rand(512, 512, 16).double())
|
||||||
tu.rand(512, 512, 16).double())
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BernoulliModule(torch.nn.Module):
|
class BernoulliModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
a = torch.bernoulli(x)
|
a = torch.bernoulli(x)
|
||||||
mean = torch.mean(a)
|
mean = torch.mean(a)
|
||||||
|
@ -198,20 +222,23 @@ class BernoulliModule(torch.nn.Module):
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: BernoulliModule())
|
@register_test_case(module_factory=lambda: BernoulliModule())
|
||||||
def BernoulliModule_basic(module, tu: TestUtils):
|
def BernoulliModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.rand(512, 512, 16).double())
|
||||||
tu.rand(512, 512, 16).double())
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BernoulliZerosModule(torch.nn.Module):
|
class BernoulliZerosModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.bernoulli(x)
|
return torch.bernoulli(x)
|
||||||
|
|
||||||
|
@ -220,17 +247,21 @@ class BernoulliZerosModule(torch.nn.Module):
|
||||||
def BernoulliZerosModule_basic(module, tu: TestUtils):
|
def BernoulliZerosModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.zeros(4, 8).double())
|
module.forward(torch.zeros(4, 8).double())
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BernoulliOnesModule(torch.nn.Module):
|
class BernoulliOnesModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.bernoulli(x)
|
return torch.bernoulli(x)
|
||||||
|
|
||||||
|
@ -239,50 +270,60 @@ class BernoulliOnesModule(torch.nn.Module):
|
||||||
def BernoulliOnesModule_basic(module, tu: TestUtils):
|
def BernoulliOnesModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.ones(4, 8).double())
|
module.forward(torch.ones(4, 8).double())
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BernoulliFloatModule(torch.nn.Module):
|
class BernoulliFloatModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1, -1], torch.float64, True),
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
a = torch.ops.aten.bernoulli_(x, 0.4)
|
a = torch.ops.aten.bernoulli_(x, 0.4)
|
||||||
b = torch.ops.aten.bernoulli_(y, 0.7)
|
b = torch.ops.aten.bernoulli_(y, 0.7)
|
||||||
mean = torch.cat([
|
mean = torch.cat(
|
||||||
|
[
|
||||||
torch.flatten(torch.mean(a)),
|
torch.flatten(torch.mean(a)),
|
||||||
torch.flatten(torch.mean(b)),
|
torch.flatten(torch.mean(b)),
|
||||||
])
|
]
|
||||||
std = torch.cat([
|
)
|
||||||
|
std = torch.cat(
|
||||||
|
[
|
||||||
torch.flatten(torch.std(a)),
|
torch.flatten(torch.std(a)),
|
||||||
torch.flatten(torch.std(b)),
|
torch.flatten(torch.std(b)),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
return mean, std
|
return mean, std
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: BernoulliFloatModule())
|
@register_test_case(module_factory=lambda: BernoulliFloatModule())
|
||||||
def BernoulliFloatModule_basic(module, tu: TestUtils):
|
def BernoulliFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.rand(512, 512, 16).double(), tu.rand(512, 512, 16).double())
|
||||||
tu.rand(512, 512, 16).double(),
|
|
||||||
tu.rand(512, 512, 16).double())
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BernoulliTensorModule(torch.nn.Module):
|
class BernoulliTensorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
([-1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, px):
|
def forward(self, x, px):
|
||||||
a = torch.ops.aten.bernoulli_(x, px)
|
a = torch.ops.aten.bernoulli_(x, px)
|
||||||
mean = torch.mean(a)
|
mean = torch.mean(a)
|
||||||
|
@ -292,53 +333,61 @@ class BernoulliTensorModule(torch.nn.Module):
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: BernoulliTensorModule())
|
@register_test_case(module_factory=lambda: BernoulliTensorModule())
|
||||||
def BernoulliTensorModule_basic(module, tu: TestUtils):
|
def BernoulliTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.rand(512, 512).double(), tu.rand(512, 512).double())
|
||||||
tu.rand(512, 512).double(),
|
|
||||||
tu.rand(512, 512).double())
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BernoulliPModule(torch.nn.Module):
|
class BernoulliPModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1, -1], torch.float64, True),
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
a = torch.ops.aten.bernoulli(x, 0.4)
|
a = torch.ops.aten.bernoulli(x, 0.4)
|
||||||
b = torch.ops.aten.bernoulli(y, 0.7)
|
b = torch.ops.aten.bernoulli(y, 0.7)
|
||||||
mean = torch.cat([
|
mean = torch.cat(
|
||||||
|
[
|
||||||
torch.flatten(torch.mean(a)),
|
torch.flatten(torch.mean(a)),
|
||||||
torch.flatten(torch.mean(b)),
|
torch.flatten(torch.mean(b)),
|
||||||
])
|
]
|
||||||
std = torch.cat([
|
)
|
||||||
|
std = torch.cat(
|
||||||
|
[
|
||||||
torch.flatten(torch.std(a)),
|
torch.flatten(torch.std(a)),
|
||||||
torch.flatten(torch.std(b)),
|
torch.flatten(torch.std(b)),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
return mean, std
|
return mean, std
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: BernoulliPModule())
|
@register_test_case(module_factory=lambda: BernoulliPModule())
|
||||||
def BernoulliPModule_basic(module, tu: TestUtils):
|
def BernoulliPModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.rand(512, 512, 16).double(), tu.rand(512, 512, 16).double())
|
||||||
tu.rand(512, 512, 16).double(),
|
|
||||||
tu.rand(512, 512, 16).double())
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class RandLikeModule(torch.nn.Module):
|
class RandLikeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
a = torch.ops.aten.rand_like(x)
|
a = torch.ops.aten.rand_like(x)
|
||||||
mean = torch.mean(a)
|
mean = torch.mean(a)
|
||||||
|
@ -349,17 +398,21 @@ class RandLikeModule(torch.nn.Module):
|
||||||
def RandLikeModule_basic(module, tu: TestUtils):
|
def RandLikeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1024, 1024).double())
|
module.forward(tu.rand(1024, 1024).double())
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class RandLikeDtypeModule(torch.nn.Module):
|
class RandLikeDtypeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
a = torch.ops.aten.rand_like(x, dtype=torch.float32)
|
a = torch.ops.aten.rand_like(x, dtype=torch.float32)
|
||||||
mean = torch.mean(a)
|
mean = torch.mean(a)
|
||||||
|
@ -370,16 +423,20 @@ class RandLikeDtypeModule(torch.nn.Module):
|
||||||
def RandLikeDtypeModule_basic(module, tu: TestUtils):
|
def RandLikeDtypeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1024, 1024).double())
|
module.forward(tu.rand(1024, 1024).double())
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class RandIntLowModule(torch.nn.Module):
|
class RandIntLowModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
a = torch.ops.aten.randint(low=1, high=1000, size=[1024, 1024])
|
a = torch.ops.aten.randint(low=1, high=1000, size=[1024, 1024])
|
||||||
mean = torch.mean(a.to(torch.float32))
|
mean = torch.mean(a.to(torch.float32))
|
||||||
|
@ -390,18 +447,24 @@ class RandIntLowModule(torch.nn.Module):
|
||||||
def RandIntLowModule_basic(module, tu: TestUtils):
|
def RandIntLowModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class RandIntLowDtypeModule(torch.nn.Module):
|
class RandIntLowDtypeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
a = torch.ops.aten.randint(low=1, high=1000, size=[128, 256, 512], dtype=torch.float64)
|
a = torch.ops.aten.randint(
|
||||||
|
low=1, high=1000, size=[128, 256, 512], dtype=torch.float64
|
||||||
|
)
|
||||||
mean = torch.mean(a)
|
mean = torch.mean(a)
|
||||||
return mean
|
return mean
|
||||||
|
|
||||||
|
@ -410,16 +473,20 @@ class RandIntLowDtypeModule(torch.nn.Module):
|
||||||
def RandIntLowDtypeModule_basic(module, tu: TestUtils):
|
def RandIntLowDtypeModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class RandIntModule(torch.nn.Module):
|
class RandIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
a = torch.ops.aten.randint(high=1000, size=[1024, 1024])
|
a = torch.ops.aten.randint(high=1000, size=[1024, 1024])
|
||||||
mean = torch.mean(a.to(torch.float32))
|
mean = torch.mean(a.to(torch.float32))
|
||||||
|
@ -430,16 +497,20 @@ class RandIntModule(torch.nn.Module):
|
||||||
def RandIntModule_basic(module, tu: TestUtils):
|
def RandIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class RandIntDtypeModule(torch.nn.Module):
|
class RandIntDtypeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], dtype=torch.float64)
|
a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], dtype=torch.float64)
|
||||||
mean = torch.mean(a.to(torch.float32))
|
mean = torch.mean(a.to(torch.float32))
|
||||||
|
@ -450,16 +521,20 @@ class RandIntDtypeModule(torch.nn.Module):
|
||||||
def RandIntDtypeModule_basic(module, tu: TestUtils):
|
def RandIntDtypeModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class RandIntPinMemoryModule(torch.nn.Module):
|
class RandIntPinMemoryModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], pin_memory=False)
|
a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], pin_memory=False)
|
||||||
mean = torch.mean(a.to(torch.float32))
|
mean = torch.mean(a.to(torch.float32))
|
||||||
|
@ -470,17 +545,20 @@ class RandIntPinMemoryModule(torch.nn.Module):
|
||||||
def RandIntPinMemoryModule_basic(module, tu: TestUtils):
|
def RandIntPinMemoryModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class RandnModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class RandnModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
a = torch.ops.aten.randn([4, 512, 1024])
|
a = torch.ops.aten.randn([4, 512, 1024])
|
||||||
std = torch.std(a.to(dtype=torch.float64))
|
std = torch.std(a.to(dtype=torch.float64))
|
||||||
|
@ -496,18 +574,19 @@ def RandnModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class RandnDtypeDeviceModule(torch.nn.Module):
|
class RandnDtypeDeviceModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
a = torch.ops.aten.randn([4, 512, 1024],
|
a = torch.ops.aten.randn(
|
||||||
dtype=torch.float64,
|
[4, 512, 1024], dtype=torch.float64, device=torch.device("cpu")
|
||||||
device=torch.device("cpu"))
|
)
|
||||||
std = torch.std(a)
|
std = torch.std(a)
|
||||||
return std
|
return std
|
||||||
|
|
||||||
|
@ -521,14 +600,15 @@ def RandnDtypeDeviceModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class RandnGeneratorModule(torch.nn.Module):
|
class RandnGeneratorModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
a = torch.ops.aten.randn([4, 512, 1024], generator=None)
|
a = torch.ops.aten.randn([4, 512, 1024], generator=None)
|
||||||
std = torch.std(a.to(dtype=torch.float64))
|
std = torch.std(a.to(dtype=torch.float64))
|
||||||
|
@ -544,14 +624,15 @@ def RandnGeneratorModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class RandnGeneratorF64Module(torch.nn.Module):
|
class RandnGeneratorF64Module(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
a = torch.ops.aten.randn([4, 512, 1024], generator=None, dtype=torch.float64)
|
a = torch.ops.aten.randn([4, 512, 1024], generator=None, dtype=torch.float64)
|
||||||
std = torch.std(a)
|
std = torch.std(a)
|
||||||
|
@ -571,10 +652,12 @@ class RandnLikeModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
a = torch.ops.aten.randn_like(x)
|
a = torch.ops.aten.randn_like(x)
|
||||||
std = torch.std(a)
|
std = torch.std(a)
|
||||||
|
@ -585,17 +668,21 @@ class RandnLikeModule(torch.nn.Module):
|
||||||
def RandnLikeModule_basic(module, tu: TestUtils):
|
def RandnLikeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 512, 1024).double())
|
module.forward(tu.rand(4, 512, 1024).double())
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class RandnLikeDtypeModule(torch.nn.Module):
|
class RandnLikeDtypeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
a = torch.ops.aten.randn_like(x, dtype=torch.float32)
|
a = torch.ops.aten.randn_like(x, dtype=torch.float32)
|
||||||
std = torch.std(a)
|
std = torch.std(a)
|
||||||
|
@ -605,17 +692,22 @@ class RandnLikeDtypeModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: RandnLikeDtypeModule())
|
@register_test_case(module_factory=lambda: RandnLikeDtypeModule())
|
||||||
def RandnLikeDtypeModule_basic(module, tu: TestUtils):
|
def RandnLikeDtypeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(256, 1024).double())
|
module.forward(tu.rand(256, 1024).double())
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class NormalFunctionalModule(torch.nn.Module):
|
class NormalFunctionalModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)
|
a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)
|
||||||
mean = torch.mean(a)
|
mean = torch.mean(a)
|
||||||
|
|
|
@ -13,16 +13,17 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
|
|
||||||
class AddIntModule(torch.nn.Module):
|
class AddIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return int(lhs) + int(rhs)
|
return int(lhs) + int(rhs)
|
||||||
|
|
||||||
|
@ -36,16 +37,17 @@ def AddIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class SubIntModule(torch.nn.Module):
|
class SubIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return int(lhs) - int(rhs)
|
return int(lhs) - int(rhs)
|
||||||
|
|
||||||
|
@ -59,16 +61,17 @@ def SubIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class SubFloatModule(torch.nn.Module):
|
class SubFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return float(lhs) - float(rhs)
|
return float(lhs) - float(rhs)
|
||||||
|
|
||||||
|
@ -80,17 +83,19 @@ def SubFloatModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class MulFloatModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class MulFloatModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return float(lhs) * float(rhs)
|
return float(lhs) * float(rhs)
|
||||||
|
|
||||||
|
@ -104,16 +109,17 @@ def MulFloatModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class MulIntModule(torch.nn.Module):
|
class MulIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return int(lhs) * int(rhs)
|
return int(lhs) * int(rhs)
|
||||||
|
|
||||||
|
@ -127,16 +133,17 @@ def MulIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class DivIntModule(torch.nn.Module):
|
class DivIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
# Cast the result to float to make e2e test baseline result to be a float.
|
# Cast the result to float to make e2e test baseline result to be a float.
|
||||||
# Without the cast, baseline result is a Tensor which is unexpected.
|
# Without the cast, baseline result is a Tensor which is unexpected.
|
||||||
|
@ -152,16 +159,17 @@ def DivIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class DivFloatModule(torch.nn.Module):
|
class DivFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return float(lhs) / float(rhs)
|
return float(lhs) / float(rhs)
|
||||||
|
|
||||||
|
@ -175,16 +183,17 @@ def DivFloatModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class CeilFloatModule(torch.nn.Module):
|
class CeilFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
sub = float(lhs) - float(rhs)
|
sub = float(lhs) - float(rhs)
|
||||||
# Cast the result to int to make e2e test baseline result to be an int.
|
# Cast the result to int to make e2e test baseline result to be an int.
|
||||||
|
@ -204,15 +213,16 @@ def CeilFloatModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class SqrtIntModule(torch.nn.Module):
|
class SqrtIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return float(torch.ops.aten.sqrt(int(a)))
|
return float(torch.ops.aten.sqrt(int(a)))
|
||||||
|
|
||||||
|
@ -223,14 +233,15 @@ def SqrtIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class SqrtIntConstantModule(torch.nn.Module):
|
class SqrtIntConstantModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return float(torch.ops.aten.sqrt(5))
|
return float(torch.ops.aten.sqrt(5))
|
||||||
|
|
||||||
|
@ -244,15 +255,16 @@ def SqrtIntConstantModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class BoolFloatFalseModule(torch.nn.Module):
|
class BoolFloatFalseModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
sub = float(a) - float(a)
|
sub = float(a) - float(a)
|
||||||
return bool(torch.ops.aten.Bool(float(sub)))
|
return bool(torch.ops.aten.Bool(float(sub)))
|
||||||
|
@ -264,15 +276,16 @@ def BoolFloatFalseModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class BoolFloatTrueModule(torch.nn.Module):
|
class BoolFloatTrueModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return bool(torch.ops.aten.Bool(float(a)))
|
return bool(torch.ops.aten.Bool(float(a)))
|
||||||
|
|
||||||
|
@ -283,14 +296,15 @@ def BoolFloatTrueModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class BoolFloatConstantModule(torch.nn.Module):
|
class BoolFloatConstantModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return bool(torch.ops.aten.Bool(5.0))
|
return bool(torch.ops.aten.Bool(5.0))
|
||||||
|
|
||||||
|
@ -304,15 +318,16 @@ def BoolFloatConstantModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class BoolIntFalseModule(torch.nn.Module):
|
class BoolIntFalseModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
sub = int(a) - int(a)
|
sub = int(a) - int(a)
|
||||||
return bool(torch.ops.aten.Bool(int(sub)))
|
return bool(torch.ops.aten.Bool(int(sub)))
|
||||||
|
@ -324,15 +339,16 @@ def BoolIntFalseModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class BoolIntTrueModule(torch.nn.Module):
|
class BoolIntTrueModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return bool(torch.ops.aten.Bool(int(a)))
|
return bool(torch.ops.aten.Bool(int(a)))
|
||||||
|
|
||||||
|
@ -343,14 +359,15 @@ def BoolIntTrueModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class BoolIntConstantModule(torch.nn.Module):
|
class BoolIntConstantModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return bool(torch.ops.aten.Bool(5))
|
return bool(torch.ops.aten.Bool(5))
|
||||||
|
|
||||||
|
@ -359,17 +376,21 @@ class BoolIntConstantModule(torch.nn.Module):
|
||||||
def BoolIntConstantModule_basic(module, tu: TestUtils):
|
def BoolIntConstantModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AtenIntBoolOpModule(torch.nn.Module):
|
class AtenIntBoolOpModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.bool, True),
|
([], torch.bool, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return int(torch.ops.aten.Int(x))
|
return int(torch.ops.aten.Int(x))
|
||||||
|
|
||||||
|
@ -384,9 +405,11 @@ class AtenIntBoolOpConstTrueModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return int(torch.ops.aten.Int(True))
|
return int(torch.ops.aten.Int(True))
|
||||||
|
|
||||||
|
@ -401,9 +424,11 @@ class AtenIntBoolOpConstFalseModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return int(torch.ops.aten.Int(False))
|
return int(torch.ops.aten.Int(False))
|
||||||
|
|
||||||
|
@ -412,21 +437,25 @@ class AtenIntBoolOpConstFalseModule(torch.nn.Module):
|
||||||
def AtenIntBoolOpConstFalseModule_basic(module, tu: TestUtils):
|
def AtenIntBoolOpConstFalseModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AtenIntTensorByteDtypeModule(torch.nn.Module):
|
class AtenIntTensorByteDtypeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.uint8, True),
|
([], torch.uint8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, val):
|
def forward(self, val):
|
||||||
return int(val)
|
return int(val)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenIntTensorByteDtypeModule())
|
@register_test_case(module_factory=lambda: AtenIntTensorByteDtypeModule())
|
||||||
def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
|
def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.uint8))
|
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.uint8))
|
||||||
|
@ -434,57 +463,68 @@ def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AtenIntTensorCharDtypeModule(torch.nn.Module):
|
class AtenIntTensorCharDtypeModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int8, True),
|
([], torch.int8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, val):
|
def forward(self, val):
|
||||||
return int(val)
|
return int(val)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule())
|
@register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule())
|
||||||
def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils):
|
def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
|
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AtenItemIntOpModule(torch.nn.Module):
|
class AtenItemIntOpModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int8, True),
|
([], torch.int8, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, val):
|
def forward(self, val):
|
||||||
return int(val)
|
return int(val)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenItemIntOpModule())
|
@register_test_case(module_factory=lambda: AtenItemIntOpModule())
|
||||||
def AtenItemIntOpModule_basic(module, tu: TestUtils):
|
def AtenItemIntOpModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
|
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AtenItemFpOpModule(torch.nn.Module):
|
class AtenItemFpOpModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float, True),
|
([], torch.float, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, val):
|
def forward(self, val):
|
||||||
return float(val)
|
return float(val)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenItemFpOpModule())
|
@register_test_case(module_factory=lambda: AtenItemFpOpModule())
|
||||||
def AtenItemFpOpModule_basic(module, tu: TestUtils):
|
def AtenItemFpOpModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1))
|
module.forward(tu.rand(1))
|
||||||
|
|
|
@ -13,16 +13,17 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
|
|
||||||
class NeIntModule(torch.nn.Module):
|
class NeIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return int(lhs) != int(rhs)
|
return int(lhs) != int(rhs)
|
||||||
|
|
||||||
|
@ -36,16 +37,17 @@ def NeIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class EqIntModule(torch.nn.Module):
|
class EqIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return int(lhs) == int(rhs)
|
return int(lhs) == int(rhs)
|
||||||
|
|
||||||
|
@ -59,16 +61,17 @@ def EqIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class GtIntModule(torch.nn.Module):
|
class GtIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return int(lhs) > int(rhs)
|
return int(lhs) > int(rhs)
|
||||||
|
|
||||||
|
@ -82,16 +85,17 @@ def GtIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class GeIntModule(torch.nn.Module):
|
class GeIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return torch.ops.aten.ge(int(lhs), int(rhs))
|
return torch.ops.aten.ge(int(lhs), int(rhs))
|
||||||
|
|
||||||
|
@ -105,16 +109,17 @@ def GeIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class GeFloatModule(torch.nn.Module):
|
class GeFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return float(lhs) >= float(rhs)
|
return float(lhs) >= float(rhs)
|
||||||
|
|
||||||
|
@ -128,16 +133,17 @@ def GeFloatModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class GeFloatIntModule(torch.nn.Module):
|
class GeFloatIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return float(lhs) >= int(rhs)
|
return float(lhs) >= int(rhs)
|
||||||
|
|
||||||
|
@ -151,16 +157,17 @@ def GeFloatIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class NeFloatIntModule(torch.nn.Module):
|
class NeFloatIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return float(lhs) != int(rhs)
|
return float(lhs) != int(rhs)
|
||||||
|
|
||||||
|
@ -174,16 +181,17 @@ def NeFloatIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class GtFloatIntModule(torch.nn.Module):
|
class GtFloatIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return float(lhs) > int(rhs)
|
return float(lhs) > int(rhs)
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -17,16 +17,17 @@ class SqueezeStaticModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 7, 1, 3, 1], torch.float32, True),
|
([1, 7, 1, 3, 1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.squeeze(a)
|
return torch.squeeze(a)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: SqueezeStaticModule())
|
||||||
module_factory=lambda: SqueezeStaticModule())
|
|
||||||
def SqueezeModule_static(module, tu: TestUtils):
|
def SqueezeModule_static(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 7, 1, 3, 1))
|
module.forward(tu.rand(1, 7, 1, 3, 1))
|
||||||
|
|
||||||
|
@ -39,16 +40,17 @@ class SqueezeAllUnitDimModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 1], torch.float32, True),
|
([1, 1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.squeeze(a)
|
return torch.squeeze(a)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: SqueezeAllUnitDimModule())
|
||||||
module_factory=lambda: SqueezeAllUnitDimModule())
|
|
||||||
def SqueezeModule_allUnitDim(module, tu: TestUtils):
|
def SqueezeModule_allUnitDim(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 1))
|
module.forward(tu.rand(1, 1))
|
||||||
|
|
||||||
|
@ -61,17 +63,18 @@ class SqueezeBroadcastModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return a * b.squeeze()
|
return a * b.squeeze()
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: SqueezeBroadcastModule())
|
||||||
module_factory=lambda: SqueezeBroadcastModule())
|
|
||||||
def SqueezeModule_broadcast(module, tu: TestUtils):
|
def SqueezeModule_broadcast(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 3), tu.rand())
|
module.forward(tu.rand(4, 3), tu.rand())
|
||||||
|
|
||||||
|
@ -84,16 +87,17 @@ class SqueezeDimStaticModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 7], torch.float32, True),
|
([1, 7], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.squeeze(a, 0)
|
return torch.squeeze(a, 0)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: SqueezeDimStaticModule())
|
||||||
module_factory=lambda: SqueezeDimStaticModule())
|
|
||||||
def SqueezeDimModule_static(module, tu: TestUtils):
|
def SqueezeDimModule_static(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 7))
|
module.forward(tu.rand(1, 7))
|
||||||
|
|
||||||
|
@ -106,16 +110,17 @@ class SqueezeDimDynamicModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, 1, 384, -1, 1], torch.float32, True),
|
([-1, 1, 384, -1, 1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.squeeze(a, 4)
|
return torch.squeeze(a, 4)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: SqueezeDimDynamicModule())
|
||||||
module_factory=lambda: SqueezeDimDynamicModule())
|
|
||||||
def SqueezeDimModule_dynamic(module, tu: TestUtils):
|
def SqueezeDimModule_dynamic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(8, 1, 384, 12, 1))
|
module.forward(tu.rand(8, 1, 384, 12, 1))
|
||||||
|
|
||||||
|
@ -128,16 +133,17 @@ class SqueezeDimNegDimModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, -1, 1, 384, -1, 1], torch.float32, True),
|
([1, -1, 1, 384, -1, 1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.squeeze(a, -6)
|
return torch.squeeze(a, -6)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: SqueezeDimNegDimModule())
|
||||||
module_factory=lambda: SqueezeDimNegDimModule())
|
|
||||||
def SqueezeDimModule_negDim(module, tu: TestUtils):
|
def SqueezeDimModule_negDim(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 8, 1, 384, 12, 1))
|
module.forward(tu.rand(1, 8, 1, 384, 12, 1))
|
||||||
|
|
||||||
|
@ -150,16 +156,17 @@ class SqueezeDimIdentityModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([4, 1, -1], torch.float32, True),
|
([4, 1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.squeeze(a, 0)
|
return torch.squeeze(a, 0)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: SqueezeDimIdentityModule())
|
||||||
module_factory=lambda: SqueezeDimIdentityModule())
|
|
||||||
def SqueezeDimModule_identity(module, tu: TestUtils):
|
def SqueezeDimModule_identity(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 1, 3))
|
module.forward(tu.rand(4, 1, 3))
|
||||||
|
|
||||||
|
@ -172,16 +179,17 @@ class SqueezeDimUnitDimModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1], torch.float32, True),
|
([1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.squeeze(a, 0)
|
return torch.squeeze(a, 0)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: SqueezeDimUnitDimModule())
|
||||||
module_factory=lambda: SqueezeDimUnitDimModule())
|
|
||||||
def SqueezeDimModule_unitDim(module, tu: TestUtils):
|
def SqueezeDimModule_unitDim(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1))
|
module.forward(tu.rand(1))
|
||||||
|
|
||||||
|
@ -194,16 +202,17 @@ class PrimsSqueezeModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 1, 2, 3, 1, 4], torch.float32, True),
|
([1, 1, 2, 3, 1, 4], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.ops.prims.squeeze(a, dimensions=[0, 4, 1])
|
return torch.ops.prims.squeeze(a, dimensions=[0, 4, 1])
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: PrimsSqueezeModule())
|
||||||
module_factory=lambda: PrimsSqueezeModule())
|
|
||||||
def PrimsSqueezeModule_basic(module, tu: TestUtils):
|
def PrimsSqueezeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 1, 2, 3, 1, 4))
|
module.forward(tu.rand(1, 1, 2, 3, 1, 4))
|
||||||
|
|
||||||
|
@ -213,15 +222,16 @@ class PrimsSqueezeEmptyDimensionsModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 2, 1, 4], torch.float32, True),
|
([1, 2, 1, 4], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a):
|
def forward(self, a):
|
||||||
return torch.ops.prims.squeeze(a, dimensions=[])
|
return torch.ops.prims.squeeze(a, dimensions=[])
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: PrimsSqueezeEmptyDimensionsModule())
|
||||||
module_factory=lambda: PrimsSqueezeEmptyDimensionsModule())
|
|
||||||
def PrimsSqueezeEmptyDimensionsModule_basic(module, tu: TestUtils):
|
def PrimsSqueezeEmptyDimensionsModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 2, 1, 4))
|
module.forward(tu.rand(1, 2, 1, 4))
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -17,14 +17,16 @@ class Threshold1dIntI32Module(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.int32, True),
|
([-1], torch.int32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return torch.ops.aten.threshold(input, 1, 2)
|
return torch.ops.aten.threshold(input, 1, 2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Threshold1dIntI32Module())
|
@register_test_case(module_factory=lambda: Threshold1dIntI32Module())
|
||||||
def Threshold1dIntI32Module_basic(module, tu: TestUtils):
|
def Threshold1dIntI32Module_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(4, high=10).to(torch.int32))
|
module.forward(tu.randint(4, high=10).to(torch.int32))
|
||||||
|
@ -35,14 +37,16 @@ class Threshold1dIntModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return torch.ops.aten.threshold(input, 1, 2)
|
return torch.ops.aten.threshold(input, 1, 2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Threshold1dIntModule())
|
@register_test_case(module_factory=lambda: Threshold1dIntModule())
|
||||||
def Threshold1dIntModule_basic(module, tu: TestUtils):
|
def Threshold1dIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(4, high=10))
|
module.forward(tu.randint(4, high=10))
|
||||||
|
@ -53,14 +57,16 @@ class Threshold2dIntModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return torch.ops.aten.threshold(input, 0.5, 2)
|
return torch.ops.aten.threshold(input, 0.5, 2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Threshold2dIntModule())
|
@register_test_case(module_factory=lambda: Threshold2dIntModule())
|
||||||
def Threshold2dIntModule_basic(module, tu: TestUtils):
|
def Threshold2dIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(4, 5, high=10))
|
module.forward(tu.randint(4, 5, high=10))
|
||||||
|
@ -71,14 +77,16 @@ class Threshold3dIntModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.int64, True),
|
([-1, -1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return torch.ops.aten.threshold(input, 1, 2.2)
|
return torch.ops.aten.threshold(input, 1, 2.2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Threshold3dIntModule())
|
@register_test_case(module_factory=lambda: Threshold3dIntModule())
|
||||||
def Threshold3dIntModule_basic(module, tu: TestUtils):
|
def Threshold3dIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(4, 5, 6, high=10))
|
module.forward(tu.randint(4, 5, 6, high=10))
|
||||||
|
@ -89,14 +97,16 @@ class Threshold1dFloatModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return torch.ops.aten.threshold(input, 1, 2)
|
return torch.ops.aten.threshold(input, 1, 2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Threshold1dFloatModule())
|
@register_test_case(module_factory=lambda: Threshold1dFloatModule())
|
||||||
def Threshold1dFloatModule_basic(module, tu: TestUtils):
|
def Threshold1dFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4))
|
module.forward(tu.rand(4))
|
||||||
|
@ -107,14 +117,16 @@ class Threshold2dFloatModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return torch.ops.aten.threshold(input, 0.5, 2)
|
return torch.ops.aten.threshold(input, 0.5, 2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Threshold2dFloatModule())
|
@register_test_case(module_factory=lambda: Threshold2dFloatModule())
|
||||||
def Threshold2dFloatModule_basic(module, tu: TestUtils):
|
def Threshold2dFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5))
|
module.forward(tu.rand(4, 5))
|
||||||
|
@ -125,14 +137,16 @@ class Threshold3dFloatModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return torch.ops.aten.threshold(input, 1.4, 2.0)
|
return torch.ops.aten.threshold(input, 1.4, 2.0)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Threshold3dFloatModule())
|
@register_test_case(module_factory=lambda: Threshold3dFloatModule())
|
||||||
def Threshold3dFloatModule_basic(module, tu: TestUtils):
|
def Threshold3dFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6))
|
module.forward(tu.rand(4, 5, 6))
|
||||||
|
@ -143,15 +157,17 @@ class ThresholdBackward1dIntModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
|
@register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
|
||||||
def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
|
def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(4, high=10), tu.randint(4, high=8))
|
module.forward(tu.randint(4, high=10), tu.randint(4, high=8))
|
||||||
|
@ -162,15 +178,17 @@ class ThresholdBackward2dIntModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
|
@register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
|
||||||
def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
|
def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(4, 5, high=10), tu.randint(4, 5, high=8))
|
module.forward(tu.randint(4, 5, high=10), tu.randint(4, 5, high=8))
|
||||||
|
@ -181,15 +199,17 @@ class ThresholdBackward3dIntModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.int64, True),
|
([-1, -1, -1], torch.int64, True),
|
||||||
([-1, -1, -1], torch.int64, True),
|
([-1, -1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
|
@register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
|
||||||
def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
|
def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(4, 5, 6, high=10), tu.randint(4, 5, 6, high=8))
|
module.forward(tu.randint(4, 5, 6, high=10), tu.randint(4, 5, 6, high=8))
|
||||||
|
@ -200,15 +220,17 @@ class ThresholdBackward1dFloatModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule())
|
@register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule())
|
||||||
def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils):
|
def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4), tu.rand(4))
|
module.forward(tu.rand(4), tu.rand(4))
|
||||||
|
@ -219,15 +241,17 @@ class ThresholdBackward2dFloatModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule())
|
@register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule())
|
||||||
def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils):
|
def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5), tu.rand(4, 5))
|
module.forward(tu.rand(4, 5), tu.rand(4, 5))
|
||||||
|
@ -238,15 +262,17 @@ class ThresholdBackward3dFloatModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule())
|
@register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule())
|
||||||
def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils):
|
def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6), tu.rand(4, 5, 6))
|
module.forward(tu.rand(4, 5, 6), tu.rand(4, 5, 6))
|
||||||
|
@ -257,15 +283,17 @@ class ThresholdBackward1dMixedModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
|
@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
|
||||||
def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
|
def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4), tu.randint(4, high=10))
|
module.forward(tu.rand(4), tu.randint(4, high=10))
|
||||||
|
@ -276,15 +304,17 @@ class ThresholdBackward2dMixedModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
|
@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
|
||||||
def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
|
def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(4, 5, high=20), tu.rand(4, 5))
|
module.forward(tu.randint(4, 5, high=20), tu.rand(4, 5))
|
||||||
|
@ -295,15 +325,17 @@ class ThresholdBackward3dMixedModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1], torch.int64, True),
|
([-1, -1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, grad, input):
|
def forward(self, grad, input):
|
||||||
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
|
@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
|
||||||
def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
|
def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6), tu.randint(4, 5, 6, high=10))
|
module.forward(tu.rand(4, 5, 6), tu.randint(4, 5, 6, high=10))
|
||||||
|
|
|
@ -13,7 +13,6 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
|
|
||||||
|
|
||||||
class TypeConversionF32ToF64Module(torch.nn.Module):
|
class TypeConversionF32ToF64Module(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -29,7 +28,6 @@ def TypeConversionF32ToF64Module_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class TypeConversionF64ToF32Module(torch.nn.Module):
|
class TypeConversionF64ToF32Module(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -45,7 +43,6 @@ def TypeConversionF64ToF32Module_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class TypeConversionI32ToI64Module(torch.nn.Module):
|
class TypeConversionI32ToI64Module(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -61,7 +58,6 @@ def TypeConversionI32ToI64Module_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class TypeConversionI64ToI32Module(torch.nn.Module):
|
class TypeConversionI64ToI32Module(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -77,7 +73,6 @@ def TypeConversionI64ToI32Module_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class TypeConversionI1ToI32Module(torch.nn.Module):
|
class TypeConversionI1ToI32Module(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -94,7 +89,6 @@ def TypeConversionI1ToI32Module_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class TypeConversionI1ToI64Module(torch.nn.Module):
|
class TypeConversionI1ToI64Module(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -111,7 +105,6 @@ def TypeConversionI1ToI64Module_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class TypeConversionI1ToF32Module(torch.nn.Module):
|
class TypeConversionI1ToF32Module(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -128,7 +121,6 @@ def TypeConversionI1ToF32Module_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class TypeConversionI1ToF64Module(torch.nn.Module):
|
class TypeConversionI1ToF64Module(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -148,43 +140,46 @@ def TypeConversionI1ToF64Module_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class ToDtypeLayoutNoneModule(torch.nn.Module):
|
class ToDtypeLayoutNoneModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.to(x,
|
return torch.ops.aten.to(
|
||||||
|
x,
|
||||||
dtype=torch.float64,
|
dtype=torch.float64,
|
||||||
layout=None,
|
layout=None,
|
||||||
device=None,
|
device=None,
|
||||||
pin_memory=None,
|
pin_memory=None,
|
||||||
non_blocking=False,
|
non_blocking=False,
|
||||||
copy=False,
|
copy=False,
|
||||||
memory_format=None)
|
memory_format=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ToDtypeLayoutNoneModule())
|
@register_test_case(module_factory=lambda: ToDtypeLayoutNoneModule())
|
||||||
def ToDtypeLayoutNoneModule_basic(module, tu: TestUtils):
|
def ToDtypeLayoutNoneModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 5))
|
module.forward(tu.rand(3, 5))
|
||||||
|
|
||||||
class ToDtypeLayoutCPUModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class ToDtypeLayoutCPUModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.to(x,
|
return torch.ops.aten.to(
|
||||||
|
x,
|
||||||
dtype=torch.float64,
|
dtype=torch.float64,
|
||||||
layout=None,
|
layout=None,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=None,
|
pin_memory=None,
|
||||||
non_blocking=False,
|
non_blocking=False,
|
||||||
copy=False,
|
copy=False,
|
||||||
memory_format=None)
|
memory_format=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ToDtypeLayoutCPUModule())
|
@register_test_case(module_factory=lambda: ToDtypeLayoutCPUModule())
|
||||||
|
@ -193,21 +188,22 @@ def ToDtypeLayoutCPUModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class ToDtypeLayoutStridedModule(torch.nn.Module):
|
class ToDtypeLayoutStridedModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.to(x,
|
return torch.ops.aten.to(
|
||||||
|
x,
|
||||||
dtype=torch.float64,
|
dtype=torch.float64,
|
||||||
layout=torch.strided,
|
layout=torch.strided,
|
||||||
device=None,
|
device=None,
|
||||||
pin_memory=None,
|
pin_memory=None,
|
||||||
non_blocking=False,
|
non_blocking=False,
|
||||||
copy=False,
|
copy=False,
|
||||||
memory_format=None)
|
memory_format=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ToDtypeLayoutStridedModule())
|
@register_test_case(module_factory=lambda: ToDtypeLayoutStridedModule())
|
||||||
|
@ -216,21 +212,22 @@ def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class ToDtypeBoolLayoutNoneStaticModule(torch.nn.Module):
|
class ToDtypeBoolLayoutNoneStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([None, ([3, 5], torch.int64, True)])
|
@annotate_args([None, ([3, 5], torch.int64, True)])
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.to(x,
|
return torch.ops.aten.to(
|
||||||
|
x,
|
||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
layout=None,
|
layout=None,
|
||||||
device=None,
|
device=None,
|
||||||
pin_memory=None,
|
pin_memory=None,
|
||||||
non_blocking=False,
|
non_blocking=False,
|
||||||
copy=False,
|
copy=False,
|
||||||
memory_format=None)
|
memory_format=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneStaticModule())
|
@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneStaticModule())
|
||||||
|
@ -239,16 +236,17 @@ def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
class TypeAsSameModule(torch.nn.Module):
|
class TypeAsSameModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ops.aten.type_as(x, y)
|
return torch.ops.aten.type_as(x, y)
|
||||||
|
|
||||||
|
@ -257,17 +255,19 @@ class TypeAsSameModule(torch.nn.Module):
|
||||||
def TypeAsSameModule_basic(module, tu: TestUtils):
|
def TypeAsSameModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 5), tu.rand(3, 5))
|
module.forward(tu.rand(3, 5), tu.rand(3, 5))
|
||||||
|
|
||||||
class TypeAsDifferentModule(torch.nn.Module):
|
|
||||||
|
|
||||||
|
class TypeAsDifferentModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.int, True),
|
([-1, -1], torch.int, True),
|
||||||
([-1, -1], torch.int64, True),
|
([-1, -1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return torch.ops.aten.type_as(x, y)
|
return torch.ops.aten.type_as(x, y)
|
||||||
|
|
||||||
|
@ -276,14 +276,14 @@ class TypeAsDifferentModule(torch.nn.Module):
|
||||||
def TypeAsDifferentModule_basic(module, tu: TestUtils):
|
def TypeAsDifferentModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
tu.randint(3, 5, low=0, high=10, dtype=torch.int),
|
tu.randint(3, 5, low=0, high=10, dtype=torch.int),
|
||||||
tu.randint(3, 5, low=0, high=10, dtype=torch.int64)
|
tu.randint(3, 5, low=0, high=10, dtype=torch.int64),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class PrimsConvertElementTypeModule(torch.nn.Module):
|
class PrimsConvertElementTypeModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
|
@ -17,21 +17,22 @@ class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.int32, True),
|
([-1], torch.int32, True),
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.add(a, b, alpha=3)
|
return torch.add(a, b, alpha=3)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(
|
||||||
module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule())
|
module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule()
|
||||||
|
)
|
||||||
def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils):
|
def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.randint(4, high=10).type(torch.int32), tu.randint(4, high=10))
|
||||||
tu.randint(4, high=10).type(torch.int32),
|
|
||||||
tu.randint(4, high=10))
|
|
||||||
|
|
||||||
|
|
||||||
class TypePromotionDifferentCategoryModule(torch.nn.Module):
|
class TypePromotionDifferentCategoryModule(torch.nn.Module):
|
||||||
|
@ -39,17 +40,18 @@ class TypePromotionDifferentCategoryModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.add(a, b, alpha=3)
|
return torch.add(a, b, alpha=3)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: TypePromotionDifferentCategoryModule())
|
||||||
module_factory=lambda: TypePromotionDifferentCategoryModule())
|
|
||||||
def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils):
|
def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(4, high=10), tu.rand(4))
|
module.forward(tu.randint(4, high=10), tu.rand(4))
|
||||||
|
|
||||||
|
@ -59,17 +61,20 @@ class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.add(a, b, alpha=2.3)
|
return torch.add(a, b, alpha=2.3)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(
|
||||||
module_factory=lambda: TypePromotionSameCategoryZeroRankWiderModule())
|
module_factory=lambda: TypePromotionSameCategoryZeroRankWiderModule()
|
||||||
|
)
|
||||||
def TypePromotionSameCategoryZeroRankWider_basic(module, tu: TestUtils):
|
def TypePromotionSameCategoryZeroRankWider_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4), tu.rand().type(torch.float64))
|
module.forward(tu.rand(4), tu.rand().type(torch.float64))
|
||||||
|
|
||||||
|
@ -79,17 +84,18 @@ class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.int64, True),
|
([-1], torch.int64, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.add(a, b, alpha=2)
|
return torch.add(a, b, alpha=2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: TypePromotionZeroRankHigherCategoryModule())
|
||||||
module_factory=lambda: TypePromotionZeroRankHigherCategoryModule())
|
|
||||||
def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils):
|
def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(4, high=10), tu.rand())
|
module.forward(tu.randint(4, high=10), tu.rand())
|
||||||
|
|
||||||
|
@ -99,11 +105,13 @@ class TypePromotionAlphaWiderModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1], torch.float32, True),
|
([-1], torch.float32, True),
|
||||||
([], torch.float32, True),
|
([], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
return torch.add(a, b, alpha=2.3)
|
return torch.add(a, b, alpha=2.3)
|
||||||
|
|
||||||
|
|
|
@ -22,10 +22,12 @@ class ResNet18Module(torch.nn.Module):
|
||||||
self.train(False)
|
self.train(False)
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, 3, -1, -1], torch.float32, True),
|
([-1, 3, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, img):
|
def forward(self, img):
|
||||||
return self.resnet.forward(img)
|
return self.resnet.forward(img)
|
||||||
|
|
||||||
|
@ -44,10 +46,12 @@ class ResNet18StaticModule(torch.nn.Module):
|
||||||
self.train(False)
|
self.train(False)
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([1, 3, 224, 224], torch.float32, True),
|
([1, 3, 224, 224], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, img):
|
def forward(self, img):
|
||||||
return self.resnet.forward(img)
|
return self.resnet.forward(img)
|
||||||
|
|
||||||
|
@ -62,11 +66,13 @@ class IouOfModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
([-1, -1], torch.float32, True),
|
([-1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, bbox1, bbox2):
|
def forward(self, bbox1, bbox2):
|
||||||
area1 = (bbox1[:, 2] - bbox1[:, 0]) * (bbox1[:, 3] - bbox1[:, 1])
|
area1 = (bbox1[:, 2] - bbox1[:, 0]) * (bbox1[:, 3] - bbox1[:, 1])
|
||||||
area2 = (bbox2[:, 2] - bbox2[:, 0]) * (bbox2[:, 3] - bbox2[:, 1])
|
area2 = (bbox2[:, 2] - bbox2[:, 0]) * (bbox2[:, 3] - bbox2[:, 1])
|
||||||
|
@ -94,10 +100,12 @@ class MobilenetV3Module(torch.nn.Module):
|
||||||
self.train(False)
|
self.train(False)
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args(
|
||||||
|
[
|
||||||
None,
|
None,
|
||||||
([-1, 3, -1, -1], torch.float32, True),
|
([-1, 3, -1, -1], torch.float32, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def forward(self, img):
|
def forward(self, img):
|
||||||
return self.mobilenetv3.forward(img)
|
return self.mobilenetv3.forward(img)
|
||||||
|
|
||||||
|
|
|
@ -13,12 +13,12 @@ from torch_mlir.ir import Module
|
||||||
# A type shared between the result of `TosaBackend.compile` and the
|
# A type shared between the result of `TosaBackend.compile` and the
|
||||||
# input to `TosaBackend.load`. Each backend will likely have a
|
# input to `TosaBackend.load`. Each backend will likely have a
|
||||||
# different definition of this type.
|
# different definition of this type.
|
||||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
CompiledArtifact = TypeVar("CompiledArtifact")
|
||||||
|
|
||||||
# A wrapper around a backend-specific loaded program representation
|
# A wrapper around a backend-specific loaded program representation
|
||||||
# that uniformly translates the `x.method(...)` interface expected of
|
# that uniformly translates the `x.method(...)` interface expected of
|
||||||
# Torch modules into appropriate lower-level operations.
|
# Torch modules into appropriate lower-level operations.
|
||||||
Invoker = TypeVar('Invoker')
|
Invoker = TypeVar("Invoker")
|
||||||
|
|
||||||
|
|
||||||
class TosaBackend(abc.ABC):
|
class TosaBackend(abc.ABC):
|
||||||
|
@ -27,6 +27,7 @@ class TosaBackend(abc.ABC):
|
||||||
Backends are recommended to raise meaningful exceptions in case of error,
|
Backends are recommended to raise meaningful exceptions in case of error,
|
||||||
ideally with easy reproduction instructions.
|
ideally with easy reproduction instructions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def compile(self, module: Module) -> CompiledArtifact:
|
def compile(self, module: Module) -> CompiledArtifact:
|
||||||
"""Compile the provided MLIR module into a compiled artifact.
|
"""Compile the provided MLIR module into a compiled artifact.
|
||||||
|
|
|
@ -7,7 +7,9 @@ from torch_mlir.ir import *
|
||||||
from torch_mlir.passmanager import *
|
from torch_mlir.passmanager import *
|
||||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
||||||
|
RefBackendLinalgOnTensorsBackend,
|
||||||
|
)
|
||||||
|
|
||||||
from .abc import TosaBackend
|
from .abc import TosaBackend
|
||||||
|
|
||||||
|
@ -17,7 +19,8 @@ __all__ = [
|
||||||
|
|
||||||
# The pipeline of func.func passes that lower the TOSA backend contract to the
|
# The pipeline of func.func passes that lower the TOSA backend contract to the
|
||||||
# Linalg-on-Tensors backend contract accepted by RefBackend.
|
# Linalg-on-Tensors backend contract accepted by RefBackend.
|
||||||
TOSA_TO_LINALG_FUNC_PIPELINE = ",".join([
|
TOSA_TO_LINALG_FUNC_PIPELINE = ",".join(
|
||||||
|
[
|
||||||
# TOSA legalization may emit tosa.const() ops. These are legalized
|
# TOSA legalization may emit tosa.const() ops. These are legalized
|
||||||
# by tosa-to-arith to arith.constants. This mechanical transformation
|
# by tosa-to-arith to arith.constants. This mechanical transformation
|
||||||
# must be done prior to TOSA-to-LinAlg so that the latter does not fail.
|
# must be done prior to TOSA-to-LinAlg so that the latter does not fail.
|
||||||
|
@ -33,7 +36,8 @@ TOSA_TO_LINALG_FUNC_PIPELINE = ",".join([
|
||||||
"tosa-to-linalg",
|
"tosa-to-linalg",
|
||||||
"tosa-to-tensor",
|
"tosa-to-tensor",
|
||||||
"tosa-to-arith",
|
"tosa-to-arith",
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LinalgOnTensorsTosaBackend(TosaBackend):
|
class LinalgOnTensorsTosaBackend(TosaBackend):
|
||||||
|
@ -60,7 +64,8 @@ class LinalgOnTensorsTosaBackend(TosaBackend):
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
imported_module,
|
imported_module,
|
||||||
f"builtin.module(func.func({TOSA_TO_LINALG_FUNC_PIPELINE}))",
|
f"builtin.module(func.func({TOSA_TO_LINALG_FUNC_PIPELINE}))",
|
||||||
"Lowering TOSA backend contract to Linalg-on-Tensors backend contract")
|
"Lowering TOSA backend contract to Linalg-on-Tensors backend contract",
|
||||||
|
)
|
||||||
|
|
||||||
return self.refbackend.compile(imported_module)
|
return self.refbackend.compile(imported_module)
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
from torch_mlir.torchscript import TensorPlaceholder
|
from torch_mlir.torchscript import TensorPlaceholder
|
||||||
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
|
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
|
||||||
|
|
||||||
|
|
||||||
def convert_annotations_to_placeholders(forward_method):
|
def convert_annotations_to_placeholders(forward_method):
|
||||||
"""Converts the annotations on a forward method into tensor placeholders.
|
"""Converts the annotations on a forward method into tensor placeholders.
|
||||||
|
|
||||||
|
@ -17,6 +18,7 @@ def convert_annotations_to_placeholders(forward_method):
|
||||||
for annotation in annotations[1:]:
|
for annotation in annotations[1:]:
|
||||||
if not annotation[2]:
|
if not annotation[2]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can only compile inputs annotated as having value semantics.")
|
"Can only compile inputs annotated as having value semantics."
|
||||||
|
)
|
||||||
placeholders.append(TensorPlaceholder(annotation[0], annotation[1]))
|
placeholders.append(TensorPlaceholder(annotation[0], annotation[1]))
|
||||||
return placeholders
|
return placeholders
|
||||||
|
|
|
@ -19,45 +19,52 @@ from lit.llvm.subst import FindTool
|
||||||
# Configuration file for the 'lit' test runner.
|
# Configuration file for the 'lit' test runner.
|
||||||
|
|
||||||
# name: The name of this test suite.
|
# name: The name of this test suite.
|
||||||
config.name = 'TORCH_MLIR_PT1'
|
config.name = "TORCH_MLIR_PT1"
|
||||||
|
|
||||||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||||
|
|
||||||
# suffixes: A list of file extensions to treat as test files.
|
# suffixes: A list of file extensions to treat as test files.
|
||||||
config.suffixes = ['.mlir', '.py']
|
config.suffixes = [".mlir", ".py"]
|
||||||
|
|
||||||
# test_source_root: The root path where tests are located.
|
# test_source_root: The root path where tests are located.
|
||||||
config.test_source_root = os.path.dirname(__file__)
|
config.test_source_root = os.path.dirname(__file__)
|
||||||
|
|
||||||
# test_exec_root: The root path where tests should be run.
|
# test_exec_root: The root path where tests should be run.
|
||||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test")
|
||||||
|
|
||||||
config.substitutions.append(('%PATH%', config.environment['PATH']))
|
config.substitutions.append(("%PATH%", config.environment["PATH"]))
|
||||||
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
|
config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
|
||||||
|
|
||||||
llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
|
llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
|
||||||
|
|
||||||
#llvm_config.use_default_substitutions()
|
# llvm_config.use_default_substitutions()
|
||||||
|
|
||||||
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
|
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
|
||||||
# subdirectories contain auxiliary inputs for various tests in their parent
|
# subdirectories contain auxiliary inputs for various tests in their parent
|
||||||
# directories.
|
# directories.
|
||||||
config.excludes = [
|
config.excludes = [
|
||||||
'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
|
"Inputs",
|
||||||
'lit.cfg.py', 'lit.site.cfg.py'
|
"Examples",
|
||||||
|
"CMakeLists.txt",
|
||||||
|
"README.txt",
|
||||||
|
"LICENSE.txt",
|
||||||
|
"lit.cfg.py",
|
||||||
|
"lit.site.cfg.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
# test_source_root: The root path where tests are located.
|
# test_source_root: The root path where tests are located.
|
||||||
config.test_source_root = os.path.dirname(__file__)
|
config.test_source_root = os.path.dirname(__file__)
|
||||||
|
|
||||||
# test_exec_root: The root path where tests should be run.
|
# test_exec_root: The root path where tests should be run.
|
||||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test")
|
||||||
config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')
|
config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, "bin")
|
||||||
|
|
||||||
# Tweak the PATH to include the tools dir.
|
# Tweak the PATH to include the tools dir.
|
||||||
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
|
||||||
# Tweak the PATH to include the binary build dir, in order to pick up CAPI tests during out-of-tree.
|
# Tweak the PATH to include the binary build dir, in order to pick up CAPI tests during out-of-tree.
|
||||||
llvm_config.with_environment('PATH', os.path.join(config.llvm_build_dir, 'bin'), append_path=True)
|
llvm_config.with_environment(
|
||||||
|
"PATH", os.path.join(config.llvm_build_dir, "bin"), append_path=True
|
||||||
|
)
|
||||||
|
|
||||||
# On Windows the path to python could contains spaces in which case it needs to
|
# On Windows the path to python could contains spaces in which case it needs to
|
||||||
# be provided in quotes. This is the equivalent of how %python is setup in
|
# be provided in quotes. This is the equivalent of how %python is setup in
|
||||||
|
@ -65,16 +72,23 @@ llvm_config.with_environment('PATH', os.path.join(config.llvm_build_dir, 'bin'),
|
||||||
if "Windows" in config.host_os:
|
if "Windows" in config.host_os:
|
||||||
config.python_executable = '"%s"' % (config.python_executable)
|
config.python_executable = '"%s"' % (config.python_executable)
|
||||||
|
|
||||||
tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir, config.torch_mlir_obj_root]
|
tool_dirs = [
|
||||||
|
config.standalone_tools_dir,
|
||||||
|
config.llvm_tools_dir,
|
||||||
|
config.torch_mlir_obj_root,
|
||||||
|
]
|
||||||
tools = [
|
tools = [
|
||||||
'torch-mlir-opt',
|
"torch-mlir-opt",
|
||||||
ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
|
ToolSubst("%PYTHON", config.python_executable, unresolved="ignore"),
|
||||||
]
|
]
|
||||||
|
|
||||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||||
|
|
||||||
if config.enable_bindings_python:
|
if config.enable_bindings_python:
|
||||||
llvm_config.with_environment('PYTHONPATH', [
|
llvm_config.with_environment(
|
||||||
os.path.join(config.torch_mlir_python_packages_dir, 'torch_mlir'),
|
"PYTHONPATH",
|
||||||
|
[
|
||||||
|
os.path.join(config.torch_mlir_python_packages_dir, "torch_mlir"),
|
||||||
],
|
],
|
||||||
append_path=True)
|
append_path=True,
|
||||||
|
)
|
||||||
|
|
|
@ -20,18 +20,26 @@ goofy_lib = torch.library.Library("goofy", "DEF")
|
||||||
goofy_lib.define("identity(Tensor t) -> Tensor")
|
goofy_lib.define("identity(Tensor t) -> Tensor")
|
||||||
goofy_lib.impl("identity", identity)
|
goofy_lib.impl("identity", identity)
|
||||||
|
|
||||||
|
|
||||||
def goofy〇identity〡shape(t: List[int]) -> List[int]:
|
def goofy〇identity〡shape(t: List[int]) -> List[int]:
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
|
||||||
def goofy〇identity〡dtype(t_rank_dtype: Tuple[int, int]) -> int:
|
def goofy〇identity〡dtype(t_rank_dtype: Tuple[int, int]) -> int:
|
||||||
t_rank, t_dtype = t_rank_dtype
|
t_rank, t_dtype = t_rank_dtype
|
||||||
return t_dtype
|
return t_dtype
|
||||||
|
|
||||||
|
|
||||||
def goofy〇identity〡has_value_semantics() -> None:
|
def goofy〇identity〡has_value_semantics() -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
extra_library = [
|
extra_library = [
|
||||||
goofy〇identity〡shape, goofy〇identity〡dtype, goofy〇identity〡has_value_semantics]
|
goofy〇identity〡shape,
|
||||||
|
goofy〇identity〡dtype,
|
||||||
|
goofy〇identity〡has_value_semantics,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class CustomOpExampleModule(torch.nn.Module):
|
class CustomOpExampleModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -52,6 +60,7 @@ class CustomOpExampleModule(torch.nn.Module):
|
||||||
mod = CustomOpExampleModule()
|
mod = CustomOpExampleModule()
|
||||||
mod.eval()
|
mod.eval()
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
mod = CustomOpExampleModule()
|
mod = CustomOpExampleModule()
|
||||||
mod.eval()
|
mod.eval()
|
||||||
|
@ -66,6 +75,7 @@ def run():
|
||||||
|
|
||||||
print(module)
|
print(module)
|
||||||
|
|
||||||
|
|
||||||
run()
|
run()
|
||||||
|
|
||||||
# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {
|
# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue