mirror of https://github.com/llvm/torch-mlir
Port the bulk of the remaining code to torch-mlir
This leaves no real code outside torch-mlir. This also renames the "npcomp backend contract" to "linalg on tensors backend contract" as the name of the abstraction layer that RefBackend (IREE too) accepts.pull/330/head
parent
aa10ec66a7
commit
404bd74ddf
|
@ -203,8 +203,7 @@ add_custom_target(check-npcomp)
|
|||
add_custom_target(check-npcomp-all)
|
||||
add_dependencies(check-npcomp-all
|
||||
check-npcomp
|
||||
check-npcomp-python
|
||||
check-torch-mlir
|
||||
check-torch-mlir-all
|
||||
)
|
||||
|
||||
add_subdirectory(lib)
|
||||
|
|
|
@ -66,6 +66,12 @@ add_subdirectory(include)
|
|||
add_subdirectory(lib)
|
||||
add_subdirectory(tools)
|
||||
|
||||
add_custom_target(check-torch-mlir-all)
|
||||
add_dependencies(check-torch-mlir-all
|
||||
check-torch-mlir
|
||||
check-torch-mlir-python
|
||||
)
|
||||
|
||||
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||
# If parent projects want to configure where to place the python packages,
|
||||
# respect that.
|
||||
|
|
|
@ -23,8 +23,10 @@ cmake -GNinja -B"$build_dir" "$llvm_project_dir/llvm" \
|
|||
-DLLVM_EXTERNAL_PROJECTS=torch-mlir \
|
||||
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$project_dir" \
|
||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DLLVM_TARGETS_TO_BUILD=host
|
||||
|
||||
#-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
#
|
||||
|
||||
cd "$build_dir"
|
||||
ninja tools/torch-mlir/all check-torch-mlir
|
||||
ninja tools/torch-mlir/all check-torch-mlir-all
|
||||
|
|
|
@ -27,9 +27,9 @@ from fairseq.sequence_generator import SequenceGenerator
|
|||
from fairseq.tasks.fairseq_task import LegacyFairseqTask
|
||||
from fairseq import utils
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import TestUtils
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case
|
||||
from npcomp_torchscript.annotations import annotate_args, export
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
DEFAULT_TEST_VOCAB_SIZE = 100
|
||||
|
|
@ -14,7 +14,7 @@ fi
|
|||
venv_dir=$1
|
||||
serialized_test_dir=$2
|
||||
here="$(realpath $(dirname $0))"
|
||||
npcomp_src_root="$here/../../"
|
||||
torch_mlir_src_root="$here/../../"
|
||||
|
||||
mkdir -p $venv_dir
|
||||
mkdir -p $serialized_test_dir
|
||||
|
@ -22,7 +22,7 @@ python3 -m venv $venv_dir
|
|||
source $venv_dir/bin/activate
|
||||
python3 -m pip install fairseq fvcore sacremoses subword-nmt
|
||||
|
||||
cd "$npcomp_src_root"
|
||||
cd "$torch_mlir_src_root"
|
||||
export PYTHONPATH=${PYTHONPATH-}
|
||||
source "$npcomp_src_root/.env"
|
||||
source "$torch_mlir_src_root/.env"
|
||||
python3 -m build_tools.torchscript_e2e_heavydep_tests.main --output_dir=$serialized_test_dir
|
|
@ -8,9 +8,9 @@ import pickle
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
||||
from npcomp_torchscript.e2e_test.framework import SerializableTest, generate_golden_trace
|
||||
from npcomp_torchscript.annotations import extract_serializable_annotations
|
||||
from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate_golden_trace
|
||||
from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
|
||||
|
||||
from . import basic_mt
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
# generate the .env file with default options.
|
||||
#
|
||||
# For arbitrary build/install directories, set the env variables:
|
||||
# - NPCOMP_BUILD_DIR
|
||||
# - TORCH_MLIR_BUILD_DIR
|
||||
|
||||
portable_realpath() {
|
||||
# Create the directory if needed so that the `cd` doesn't fail.
|
||||
|
@ -10,12 +10,12 @@ portable_realpath() {
|
|||
}
|
||||
|
||||
td="$(portable_realpath $(dirname $0)/..)"
|
||||
build_dir="$(portable_realpath "${NPCOMP_BUILD_DIR:-$td/build}")"
|
||||
python_packages_dir="$build_dir/python_packages"
|
||||
build_dir="$(portable_realpath "${TORCH_MLIR_BUILD_DIR:-$td/build}")"
|
||||
python_packages_dir="$build_dir/tools/torch-mlir/python_packages"
|
||||
|
||||
write_env_file() {
|
||||
echo "Updating $build_dir/.env file"
|
||||
echo "PYTHONPATH=\"$(portable_realpath "$python_packages_dir/npcomp_core"):$(portable_realpath "$python_packages_dir/torch_mlir")\"" > "$build_dir/.env"
|
||||
echo "PYTHONPATH=\"$(portable_realpath "$python_packages_dir/torch_mlir")\"" > "$build_dir/.env"
|
||||
if ! cp "$build_dir/.env" "$td/.env"; then
|
||||
echo "WARNING: Failed to write $td/.env"
|
||||
fi
|
|
@ -4,9 +4,9 @@
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import TestUtils
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case
|
||||
from npcomp_torchscript.annotations import annotate_args, export
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
|
@ -4,9 +4,9 @@
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import TestUtils
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case
|
||||
from npcomp_torchscript.annotations import annotate_args, export
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
class BatchNorm1DModule(torch.nn.Module):
|
|
@ -3,9 +3,9 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import torch
|
||||
from npcomp_torchscript.e2e_test.framework import TestUtils
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case
|
||||
from npcomp_torchscript.annotations import annotate_args, export
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
|
@ -4,9 +4,9 @@
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import TestUtils
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case
|
||||
from npcomp_torchscript.annotations import annotate_args, export
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# TODO: Support scalar !torch.int/!torch.float variants. Add support to
|
||||
# ReduceOpVariants to implement them in terms of the tensor-only variants +
|
|
@ -8,16 +8,16 @@ import pickle
|
|||
import re
|
||||
import sys
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import run_tests
|
||||
from npcomp_torchscript.e2e_test.reporting import report_results
|
||||
from npcomp_torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.torchscript.framework import run_tests
|
||||
from torch_mlir_e2e_test.torchscript.reporting import report_results
|
||||
from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
|
||||
|
||||
# Available test configs.
|
||||
from npcomp_torchscript_e2e_test_configs import (
|
||||
NpcompBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
||||
from torch_mlir_e2e_test.torchscript.configs import (
|
||||
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
||||
)
|
||||
|
||||
from npcomp.compiler.pytorch.backend.refbackend import RefBackendNpcompBackend
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
|
||||
from .xfail_sets import XFAIL_SETS
|
||||
|
||||
|
@ -43,7 +43,7 @@ def _get_argparse():
|
|||
default='refbackend',
|
||||
help=f'''
|
||||
Meaning of options:
|
||||
"refbackend": run through npcomp's RefBackend.
|
||||
"refbackend": run through torch-mlir's RefBackend.
|
||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||
''')
|
||||
|
@ -58,7 +58,7 @@ Regular expression specifying which tests to include in this run.
|
|||
The directory containing serialized pre-built tests.
|
||||
Right now, these are additional tests which require heavy Python dependencies
|
||||
to generate (or cannot even be generated with the version of PyTorch used by
|
||||
npcomp).
|
||||
torch-mlir).
|
||||
See `build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh`
|
||||
for more information on building these artifacts.
|
||||
''')
|
||||
|
@ -69,7 +69,7 @@ def main():
|
|||
|
||||
# Find the selected config.
|
||||
if args.config == 'refbackend':
|
||||
config = NpcompBackendTestConfig(RefBackendNpcompBackend())
|
||||
config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
|
||||
elif args.config == 'native_torch':
|
||||
config = NativeTorchTestConfig()
|
||||
elif args.config == 'torchscript':
|
|
@ -5,9 +5,9 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import TestUtils
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case
|
||||
from npcomp_torchscript.annotations import annotate_args, export
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
|
@ -5,9 +5,9 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import TestUtils
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case
|
||||
from npcomp_torchscript.annotations import annotate_args, export
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
|
@ -4,9 +4,9 @@
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import TestUtils
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case
|
||||
from npcomp_torchscript.annotations import annotate_args, export
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
|
@ -5,9 +5,9 @@
|
|||
import torch
|
||||
import torchvision.models as models
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import TestUtils
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case
|
||||
from npcomp_torchscript.annotations import annotate_args, export
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
|
@ -12,13 +12,13 @@
|
|||
XFAIL_SETS = {}
|
||||
|
||||
# Lists of tests that fail to even reach the backends.
|
||||
# These represent further work needed in npcomp to lower them properly
|
||||
# These represent further work needed in torch-mlir to lower them properly
|
||||
# to the backend contract.
|
||||
_common_npcomp_lowering_xfails = {
|
||||
_common_torch_mlir_lowering_xfails = {
|
||||
'QuantizedMLP_basic',
|
||||
}
|
||||
|
||||
XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails | {
|
||||
XFAIL_SETS['refbackend'] = _common_torch_mlir_lowering_xfails | {
|
||||
# The first test in the e2e test batch would fail with SystemError: null
|
||||
# argument to internal routine. Might be some issue with refbackend.
|
||||
'MmModule_basic',
|
|
@ -17,12 +17,12 @@ def TorchConversion_Dialect : Dialect {
|
|||
let cppNamespace = "::mlir::torch::TorchConversion";
|
||||
let description = [{
|
||||
This dialect contains ops and transforms for converting from the Torch
|
||||
backend contract to the npcomp backend contract.
|
||||
backend contract to the linalg-on-tensors backend contract.
|
||||
|
||||
This mainly consists of converting ops and types from `torch` dialect
|
||||
to the mix of dialects of the npcomp backend contract, such as tensor
|
||||
ops being converted linalg-on-tensors and !torch.float being converted to
|
||||
`f64`.
|
||||
to the mix of dialects of the linalg-on-tensors backend contract, such as
|
||||
tensor ops being converted linalg-on-tensors and `!torch.vtensor` being
|
||||
converted to the builtin `tensor` type.
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -20,8 +20,8 @@ namespace TorchConversion {
|
|||
void getBackendTypeConversionDependentDialects(DialectRegistry ®istry);
|
||||
|
||||
/// Set up the provided ConversionTarget and TypeConverter for converting
|
||||
/// from `torch` dialect types to the types along the npcomp backend boundary
|
||||
/// (which currently consist only of builtin types).
|
||||
/// from `torch` dialect types to the types along the linalg-on-tensors backend
|
||||
/// boundary (which currently consist only of builtin types).
|
||||
void setupBackendTypeConversion(ConversionTarget &target,
|
||||
TypeConverter &typeConverter);
|
||||
} // namespace TorchConversion
|
||||
|
|
|
@ -19,8 +19,9 @@ namespace torch {
|
|||
namespace TorchConversion {
|
||||
|
||||
/// Creates a pipeline that lowers the object graph IR that is produced by
|
||||
/// TorchScript import into the form expected by npcomp-verify-backend-contract.
|
||||
void createTorchScriptToNpcompBackendPipeline(
|
||||
/// TorchScript import into the form expected by
|
||||
/// torch-verify-linalg-on-tensors-verify-backend-contract.
|
||||
void createTorchScriptToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm,
|
||||
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ def VerifyInvariantsBeforeBackendLowering
|
|||
"mlir::torch::TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()";
|
||||
let description = [{
|
||||
This pass checks any invariants needed by the process of lowering the
|
||||
`torch` dialect to the npcomp backend contract.
|
||||
`torch` dialect to the linalg-on-tensors backend contract.
|
||||
|
||||
The most important invariant is that all tensors should be ranked and have
|
||||
a known dtype. It is useful to catch this early because it usually
|
||||
|
@ -52,7 +52,7 @@ def FinalizingBackendTypeConversion
|
|||
}];
|
||||
}
|
||||
|
||||
def VerifyLinalgOnTensorsBackendContract : Pass<"npcomp-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
|
||||
def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
|
||||
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
|
||||
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg
|
|||
TorchToLinalg.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TorchToLinalg
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToLinalg
|
||||
|
||||
DEPENDS
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
|
|
@ -2,7 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchToSCF
|
|||
TorchToSCF.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TorchToSCF
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToSCF
|
||||
|
||||
DEPENDS
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
|
|
@ -2,7 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStd
|
|||
TorchToStd.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TorchToStd
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStd
|
||||
|
||||
DEPENDS
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
|
|
@ -3,7 +3,7 @@ add_mlir_dialect_library(TorchMLIRTorchConversionDialect
|
|||
TorchConversionOps.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/TorchConversion
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion
|
||||
|
||||
DEPENDS
|
||||
MLIRTorchConversionOpsIncGen
|
||||
|
|
|
@ -5,7 +5,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
|
|||
VerifyLinalgOnTensorsBackendContract.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/TorchConversion/Transforms
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms
|
||||
|
||||
DEPENDS
|
||||
TorchMLIRTorchConversionPassIncGen
|
||||
|
|
|
@ -33,16 +33,16 @@ namespace {
|
|||
void mlir::torch::registerTorchConversionPasses() {
|
||||
::registerPasses();
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torchscript-to-npcomp-backend-pipeline",
|
||||
"Pipeline lowering torch object graph to npcomp backend format.",
|
||||
mlir::torch::TorchConversion::createTorchScriptToNpcompBackendPipeline);
|
||||
"torchscript-to-linalg-on-tensors-backend-pipeline",
|
||||
"Pipeline lowering torch object graph to linalg-on-tensors backend format.",
|
||||
mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipeline);
|
||||
}
|
||||
|
||||
void mlir::torch::TorchConversion::createTorchScriptToNpcompBackendPipeline(
|
||||
void mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||
|
||||
// Conversion to the npcomp backend contract starts from the Torch backend
|
||||
// contract.
|
||||
// Conversion to the linalg-on-tensors backend contract starts from the Torch
|
||||
// backend contract.
|
||||
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
|
||||
|
||||
// Check some invariants to catch errors in a clear way.
|
||||
|
@ -68,8 +68,8 @@ void mlir::torch::TorchConversion::createTorchScriptToNpcompBackendPipeline(
|
|||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
// Finish the type conversion from `torch` types to the types of the npcomp
|
||||
// backend contract.
|
||||
// Finish the type conversion from `torch` types to the types of the
|
||||
// linalg-on-tensors backend contract.
|
||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||
pm.addNestedPass<FuncOp>(
|
||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||
|
|
|
@ -72,8 +72,8 @@ class VerifyLinalgOnTensorsBackendContractPass
|
|||
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
|
||||
// doesn't unnecessarily spew out the entire module.
|
||||
emitError(module.getLoc())
|
||||
<< "Module does not conform to npcomp's backend contract. See "
|
||||
"dialect conversion legality information above.";
|
||||
<< "Module does not conform to the linalg-on-tensors backend contract. "
|
||||
"See dialect conversion legality information above.";
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ include(AddMLIRPython)
|
|||
# argument.
|
||||
set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir")
|
||||
|
||||
|
||||
# We vendor our own MLIR instance in the `torch_mlir` namespace.
|
||||
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.")
|
||||
|
||||
|
@ -49,6 +50,7 @@ declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
|
|||
|
||||
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
|
||||
add_subdirectory(torch_mlir/dialects/torch/importer/jit_ir)
|
||||
add_subdirectory(torch_mlir_e2e_test)
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
|
@ -93,4 +95,8 @@ add_mlir_python_modules(TorchMLIRPythonModules
|
|||
# Then it would "just work".
|
||||
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
|
||||
add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporter)
|
||||
# Build the E2E Tests (which depend on the JIT IR importer now).
|
||||
add_dependencies(TorchMLIRPythonModules TorchMLIRE2ETestPythonModules)
|
||||
endif()
|
||||
|
||||
add_subdirectory(test)
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
configure_lit_site_cfg(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
|
||||
MAIN_CONFIG
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
|
||||
)
|
||||
|
||||
set(TEST_DEPENDS
|
||||
FileCheck count not
|
||||
torch-mlir-opt
|
||||
TorchMLIRPythonModules
|
||||
)
|
||||
|
||||
add_lit_testsuite(check-torch-mlir-python "Running the torch-mlir Python regression tests"
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${TEST_DEPENDS}
|
||||
)
|
||||
set_target_properties(check-torch-mlir-python PROPERTIES FOLDER "Tests")
|
||||
|
||||
add_lit_testsuites(TORCH_MLIR_PYTHON ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TEST_DEPENDS})
|
||||
add_dependencies(check-torch-mlir-all check-torch-mlir-python)
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.annotations import annotate_args, export
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator
|
||||
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
|
||||
|
|
@ -18,14 +18,14 @@ from lit.llvm.subst import FindTool
|
|||
# Configuration file for the 'lit' test runner.
|
||||
|
||||
# name: The name of this test suite.
|
||||
config.name = 'NPCOMP_PYTHON'
|
||||
config.name = 'TORCH_MLIR_PYTHON'
|
||||
|
||||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||
if 'TEST_SRC_PATH' in os.environ:
|
||||
config.environment['TEST_SRC_PATH'] = os.environ['TEST_SRC_PATH']
|
||||
|
||||
# path to our python operation library
|
||||
config.environment['TEST_BUILD_PATH'] = os.path.join(config.npcomp_obj_root)
|
||||
config.environment['TEST_BUILD_PATH'] = os.path.join(config.torch_mlir_obj_root)
|
||||
|
||||
# suffixes: A list of file extensions to treat as test files.
|
||||
config.suffixes = ['.py']
|
||||
|
@ -34,7 +34,7 @@ config.suffixes = ['.py']
|
|||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
||||
# test_exec_root: The root path where tests should be run.
|
||||
config.test_exec_root = os.path.join(config.npcomp_obj_root, 'test')
|
||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
||||
|
||||
config.substitutions.append(('%PATH%', config.environment['PATH']))
|
||||
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
|
||||
|
@ -54,22 +54,20 @@ config.excludes = ['lit.cfg.py', 'Inputs', 'Examples', 'CMakeLists.txt', 'README
|
|||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
||||
# test_exec_root: The root path where tests should be run.
|
||||
config.test_exec_root = os.path.join(config.npcomp_obj_root, 'test')
|
||||
config.npcomp_tools_dir = os.path.join(config.npcomp_obj_root, 'bin')
|
||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
||||
config.torch_mlir_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')
|
||||
|
||||
# Tweak the PATH to include the tools dir.
|
||||
npcomp_python_dir = "python" if config.npcomp_built_standalone else "tools/npcomp/python"
|
||||
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
||||
llvm_config.with_environment('PYTHONPATH', [
|
||||
os.path.join(config.npcomp_python_packages_dir, 'npcomp_core'),
|
||||
os.path.join(config.torch_mlir_python_packages_dir, 'torch_mlir'),
|
||||
],
|
||||
append_path=True)
|
||||
|
||||
|
||||
tool_dirs = [config.npcomp_tools_dir, config.llvm_tools_dir]
|
||||
tool_dirs = [config.torch_mlir_tools_dir, config.llvm_tools_dir]
|
||||
tools = [
|
||||
'npcomp-opt',
|
||||
'torch-mlir-opt',
|
||||
]
|
||||
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
|
@ -35,10 +35,8 @@ config.host_ldflags = '@HOST_LDFLAGS@'
|
|||
config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@"
|
||||
config.llvm_host_triple = '@LLVM_HOST_TRIPLE@'
|
||||
config.host_arch = "@HOST_ARCH@"
|
||||
config.npcomp_src_root = "@CMAKE_SOURCE_DIR@"
|
||||
config.npcomp_obj_root = "@CMAKE_BINARY_DIR@"
|
||||
config.npcomp_built_standalone = bool("@NPCOMP_BUILT_STANDALONE@")
|
||||
config.npcomp_python_packages_dir = "@MLIR_NPCOMP_PYTHON_PACKAGES_DIR@"
|
||||
config.torch_mlir_src_root = "@CMAKE_SOURCE_DIR@"
|
||||
config.torch_mlir_obj_root = "@CMAKE_BINARY_DIR@"
|
||||
config.torch_mlir_python_packages_dir = "@TORCH_MLIR_PYTHON_PACKAGES_DIR@"
|
||||
|
||||
|
|
@ -6,10 +6,10 @@
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import run_tests, TestUtils
|
||||
from npcomp_torchscript.e2e_test.reporting import report_results
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from npcomp_torchscript_e2e_test_configs import TorchScriptTestConfig
|
||||
from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.reporting import report_results
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig
|
||||
|
||||
|
||||
class MmModule(torch.nn.Module):
|
|
@ -6,10 +6,10 @@
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import run_tests, TestUtils
|
||||
from npcomp_torchscript.e2e_test.reporting import report_results
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from npcomp_torchscript_e2e_test_configs import TorchScriptTestConfig
|
||||
from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.reporting import report_results
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig
|
||||
|
||||
|
||||
class MmModule(torch.nn.Module):
|
|
@ -8,10 +8,10 @@ from typing import List, Tuple, Dict
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import run_tests, TestUtils
|
||||
from npcomp_torchscript.e2e_test.reporting import report_results
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from npcomp_torchscript_e2e_test_configs import TorchScriptTestConfig
|
||||
from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.reporting import report_results
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig
|
||||
|
||||
# CHECK: FAIL - "ErroneousModule_basic"
|
||||
|
|
@ -8,10 +8,10 @@ from typing import List, Tuple, Dict
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import run_tests, TestUtils
|
||||
from npcomp_torchscript.e2e_test.reporting import report_results
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from npcomp_torchscript_e2e_test_configs import TorchScriptTestConfig
|
||||
from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.reporting import report_results
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig
|
||||
|
||||
|
||||
class NonTensorValuesModule(torch.nn.Module):
|
|
@ -6,10 +6,10 @@
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import run_tests, TestUtils
|
||||
from npcomp_torchscript.e2e_test.reporting import report_results
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from npcomp_torchscript_e2e_test_configs import TorchScriptTestConfig
|
||||
from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.reporting import report_results
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig
|
||||
|
||||
|
||||
class MmModule(torch.nn.Module):
|
|
@ -6,10 +6,10 @@
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import run_tests, TestUtils
|
||||
from npcomp_torchscript.e2e_test.reporting import report_results
|
||||
from npcomp_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from npcomp_torchscript_e2e_test_configs import TorchScriptTestConfig
|
||||
from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.reporting import report_results
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig
|
||||
|
||||
class Submodule2(torch.nn.Module):
|
||||
def __init__(self):
|
|
@ -17,7 +17,7 @@ from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator
|
|||
# to be expressed conveniently and gives clearer error reports when
|
||||
# the annotations aren't acceptable.
|
||||
|
||||
# This module is kept separate from npcomp_torchscript.annotations so that
|
||||
# This module is kept separate from torch_mlir_e2e_test.torchscript.annotations so that
|
||||
# we can use that module from code without C++ dependencies, which prevent us
|
||||
# from interfacing the test framework across environments.
|
||||
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
declare_mlir_python_sources(TorchMLIRE2ETestPythonSources)
|
||||
|
||||
declare_mlir_python_sources(TorchMLIRE2ETestPythonSources.Core
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRE2ETestPythonSources
|
||||
SOURCES_GLOB
|
||||
*.py
|
||||
)
|
||||
|
||||
add_mlir_python_modules(TorchMLIRE2ETestPythonModules
|
||||
ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir_e2e_test"
|
||||
INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir_e2e_test"
|
||||
DECLARED_SOURCES TorchMLIRE2ETestPythonSources
|
||||
)
|
|
@ -7,11 +7,11 @@ from typing import TypeVar
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp.ir import Module
|
||||
from torch_mlir.ir import Module
|
||||
|
||||
# A type shared between the result of `NpcompBackend.compile` and the input
|
||||
# to `NpcompBackend.load`. Each backend will likely have a different definition
|
||||
# of this type.
|
||||
# A type shared between the result of `LinalgOnTensorsBackend.compile` and the
|
||||
# input to `LinalgOnTensorsBackend.load`. Each backend will likely have a
|
||||
# different definition of this type.
|
||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||
|
||||
# A wrapper around a backend-specific loaded program representation
|
||||
|
@ -20,14 +20,14 @@ CompiledArtifact = TypeVar('CompiledArtifact')
|
|||
Invoker = TypeVar('Invoker')
|
||||
|
||||
|
||||
class NpcompBackend(abc.ABC):
|
||||
"""The interface to an npcomp backend.
|
||||
class LinalgOnTensorsBackend(abc.ABC):
|
||||
"""The interface to an linalg-on-tensors backend.
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def compile(self, module: Module) -> CompiledArtifact:
|
||||
"""Compile the provided MLIR module into a compiled artifact.
|
||||
|
||||
The module adheres to the npcomp backend contract
|
||||
The module adheres to the linalg-on-tensors backend contract
|
||||
(see the VerifyLinalgOnTensorsBackendContract pass).
|
||||
|
||||
The compiled artifact can be any type, but must be correctly
|
|
@ -13,10 +13,10 @@ from torch_mlir.runtime import *
|
|||
import torch_mlir.all_passes_registration
|
||||
import torch_mlir.dialects.torch
|
||||
|
||||
from .abc import NpcompBackend
|
||||
from .abc import LinalgOnTensorsBackend
|
||||
|
||||
__all__ = [
|
||||
"RefBackendNpcompBackend",
|
||||
"RefBackendLinalgOnTensorsBackend",
|
||||
]
|
||||
|
||||
|
||||
|
@ -87,7 +87,7 @@ LOWERING_PIPELINE = ",".join([
|
|||
])
|
||||
|
||||
|
||||
class RefBackendNpcompBackend(NpcompBackend):
|
||||
class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
|
||||
"""Main entry-point for the backend."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -107,13 +107,10 @@ class RefBackendNpcompBackend(NpcompBackend):
|
|||
for IREE, it is a serialized VM flatbuffer) but the contract is that
|
||||
it is operated on by methods on this class.
|
||||
"""
|
||||
# Go through a string because we are briding two separate CAPI's.
|
||||
# TODO: Remove after npcomp's mlir is deleted in favor of torch_mlir.
|
||||
with Context() as ctx:
|
||||
module = Module.parse(str(imported_module))
|
||||
with imported_module.context:
|
||||
pm = PassManager.parse(LOWERING_PIPELINE)
|
||||
pm.run(module)
|
||||
return module
|
||||
pm.run(imported_module)
|
||||
return imported_module
|
||||
|
||||
def load(self, module) -> RefBackendInvoker:
|
||||
"""Loads a compiled artifact into the runtime."""
|
|
@ -0,0 +1,8 @@
|
|||
## Declare the sources of the Python module.
|
||||
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.TorchScriptE2ETest
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
SOURCES_GLOB
|
||||
dialects/torch/e2e_test/torchscript/*.py
|
||||
)
|
|
@ -26,7 +26,7 @@ TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME = '_torch_mlir_arg_annotations'
|
|||
|
||||
|
||||
def export(fn):
|
||||
"""Decorator that tells the npcomp compiler that a method is exported.
|
||||
"""Decorator that tells the torch-mlir compiler that a method is exported.
|
||||
|
||||
By default, no methods are exported, which is very important for
|
||||
the compiler, because otherwise most Torch programs consist of a sea
|
||||
|
@ -49,7 +49,7 @@ ArgAnnotation = Tuple[List[int], torch.dtype]
|
|||
# TODO: Replace with py3 extended argument annotations when available.
|
||||
# See https://www.python.org/dev/peps/pep-0593/
|
||||
def annotate_args(annotations: List[Optional[ArgAnnotation]]):
|
||||
"""Decorator that tells the npcomp compiler information about arguments.
|
||||
"""Decorator that tells the torch-mlir compiler information about arguments.
|
||||
|
||||
The `annotations` should be a list of the same length as the number of
|
||||
argument to the method (including `self`). Each list entry is either:
|
|
@ -2,6 +2,6 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .npcomp_backend import NpcompBackendTestConfig
|
||||
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
||||
from .native_torch import NativeTorchTestConfig
|
||||
from .torchscript import TorchScriptTestConfig
|
|
@ -13,11 +13,9 @@ import torch
|
|||
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
|
||||
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
|
||||
import npcomp
|
||||
from npcomp.passmanager import PassManager
|
||||
from npcomp.compiler.pytorch.backend import refbackend
|
||||
from npcomp.compiler.pytorch.backend.abc import NpcompBackend
|
||||
from npcomp_torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
from torch_mlir.passmanager import PassManager
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||
|
||||
def _recursively_convert_to_numpy(o: Any):
|
||||
if isinstance(o, torch.Tensor):
|
||||
|
@ -55,13 +53,13 @@ def _recursively_convert_from_numpy(o: Any):
|
|||
return o
|
||||
raise Exception(f"Unexpected Python function output: {o}")
|
||||
|
||||
class NpcompBackendTestConfig(TestConfig):
|
||||
"""Base class for TestConfig's that are implemented with npcomp.
|
||||
class LinalgOnTensorsBackendTestConfig(TestConfig):
|
||||
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
|
||||
|
||||
This class handles all the common lowering that npcomp does before reaching
|
||||
its backends.
|
||||
This class handles all the common lowering that torch-mlir does before
|
||||
reaching the linalg-on-tensors abstraction level.
|
||||
"""
|
||||
def __init__(self, backend: NpcompBackend):
|
||||
def __init__(self, backend: LinalgOnTensorsBackend):
|
||||
super().__init__()
|
||||
self.backend = backend
|
||||
|
||||
|
@ -80,7 +78,7 @@ class NpcompBackendTestConfig(TestConfig):
|
|||
mb.import_module(scripted._c, class_annotator)
|
||||
except Exception as e:
|
||||
raise Exception(f"""
|
||||
PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with:
|
||||
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
||||
Exception:
|
||||
{e}
|
||||
Diagnostics:
|
||||
|
@ -89,22 +87,15 @@ Diagnostics:
|
|||
finally:
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
# The torch-mlir python code is built against its own aggregate CAPI.
|
||||
# The npcomp python module is built against our own.
|
||||
# So we need to transport it across those as a string.
|
||||
with npcomp.ir.Context() as ctx:
|
||||
npcomp.register_all_dialects(ctx)
|
||||
module = npcomp.ir.Module.parse(str(mb.module))
|
||||
|
||||
try:
|
||||
sys.stderr = StringIO()
|
||||
asm_for_error_report = module.operation.get_asm(
|
||||
asm_for_error_report = mb.module.operation.get_asm(
|
||||
large_elements_limit=10, enable_debug_info=True)
|
||||
pipeline_str = "torchscript-to-npcomp-backend-pipeline"
|
||||
pipeline_str = "torchscript-to-linalg-on-tensors-backend-pipeline"
|
||||
# Lower module in place to make it ready for compiler backends.
|
||||
with module.context:
|
||||
with mb.module.context:
|
||||
pm = PassManager.parse(pipeline_str)
|
||||
pm.run(module)
|
||||
pm.run(mb.module)
|
||||
except Exception as e:
|
||||
# TODO: More robust.
|
||||
# - don't arbitrarily clutter up /tmp. When a test suite has many
|
||||
|
@ -119,27 +110,27 @@ Diagnostics:
|
|||
with open(filename, 'w') as f:
|
||||
f.write(asm_for_error_report)
|
||||
raise Exception(f"""
|
||||
NPCOMP TorchScript Object Graph IR -> NPCOMP Backend IR lowering failed with the following diagnostics:
|
||||
torch-mlir TorchScript Object Graph IR -> linalg-on-tensors backend IR lowering failed with the following diagnostics:
|
||||
{sys.stderr.getvalue()}
|
||||
|
||||
Error can be reproduced with:
|
||||
$ npcomp-opt -{pipeline_str} {filename}
|
||||
$ torch-mlir-opt -{pipeline_str} {filename}
|
||||
""") from None
|
||||
finally:
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
try:
|
||||
sys.stderr = StringIO()
|
||||
asm_for_error_report = module.operation.get_asm(
|
||||
asm_for_error_report = mb.module.operation.get_asm(
|
||||
large_elements_limit=10, enable_debug_info=True)
|
||||
return self.backend.compile(module)
|
||||
return self.backend.compile(mb.module)
|
||||
except Exception as e:
|
||||
filename = os.path.join(tempfile.gettempdir(),
|
||||
scripted.original_name + '.mlir')
|
||||
with open(filename, 'w') as f:
|
||||
f.write(asm_for_error_report)
|
||||
raise Exception(f"""
|
||||
NPCOMP Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
|
||||
torch-mlir linalg-on-tensors Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
|
||||
## Exception:
|
||||
{e}
|
||||
|
|
@ -7,7 +7,7 @@ from typing import Any
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||
|
||||
|
||||
class NativeTorchTestConfig(TestConfig):
|
|
@ -7,7 +7,7 @@ from typing import Any
|
|||
|
||||
import torch
|
||||
|
||||
from npcomp_torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||
|
||||
|
||||
class TorchScriptTestConfig(TestConfig):
|
|
@ -27,7 +27,7 @@ import pickle
|
|||
|
||||
import torch
|
||||
|
||||
from ..annotations import apply_serializable_annotations
|
||||
from .annotations import apply_serializable_annotations
|
||||
|
||||
|
||||
TorchScriptValue = Union[int, float, List['TorchScriptValue'],
|
||||
|
@ -81,7 +81,7 @@ class TestConfig(abc.ABC):
|
|||
users to suit their own needs. We provide a few configs out of the box
|
||||
in the `configs` submodule of this package, but those are intended
|
||||
to be for basic inspiration and enough for our own testing.
|
||||
Backends to npcomp will likely have more elaborate TestConfig's, such
|
||||
Backends to torch-mlir will likely have more elaborate TestConfig's, such
|
||||
as `compile` being "compile for such-and-such DSP with these vectorization
|
||||
cost model flags" and `run` being "connect to Android phone with
|
||||
device ID 1234 and upload a program to run on it's DSP core, and also set
|
||||
|
@ -95,7 +95,7 @@ class TestConfig(abc.ABC):
|
|||
wild and wonderful set of possible configurations that we cannot predict.
|
||||
"""
|
||||
# This is not a frontend-lowered module, to allow various testing at the PyTorch level.
|
||||
# We can have a helper class NpcompBackendTestConfig which does that.
|
||||
# We can have a helper class LinalgOnTensorsBackendTestConfig which does that.
|
||||
@abc.abstractmethod
|
||||
def compile(self, program: torch.nn.Module) -> CompiledArtifact:
|
||||
"""Compile the provided torch.nn.Module into a compiled artifact"""
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: torch-mlir-opt -npcomp-verify-linalg-on-tensors-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-verify-linalg-on-tensors-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
|
||||
|
||||
// CHECK: func @mm
|
||||
func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
|
@ -21,7 +21,7 @@ func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
|||
|
||||
// Basic check of error reporting.
|
||||
|
||||
// expected-error@+1 {{Module does not conform to npcomp's backend contract.}}
|
||||
// expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
|
||||
module {
|
||||
func @disallowed() {
|
||||
// expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}}
|
||||
|
@ -44,7 +44,7 @@ module {
|
|||
// in an understandable way, such as suggesting a particular place where
|
||||
// a shape annotation is needed.
|
||||
|
||||
// expected-error@+1 {{Module does not conform to npcomp's backend contract.}}
|
||||
// expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
|
||||
module {
|
||||
func @disallowed(%arg0: tensor<?x!numpy.any_dtype>) -> tensor<?x!numpy.any_dtype> {
|
||||
// expected-error@+1 {{failed to legalize operation 'std.return'}}
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See LICENSE.pytorch for license information.
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
# TODO: Fix ODR violation on non-static cl::opt in LLVM
|
||||
# `cl::opt<FunctionSummary::ForceSummaryHotnessType, true>`.
|
||||
# This causes double free on global dtors on exiting the program.
|
||||
# The FileCheck still passes though.
|
||||
# RUN: (%PYTHON %s || true) | FileCheck %s
|
||||
|
||||
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- npcomp-lsp-server.cpp - MLIR Language Server -------------*- C++ -*-===//
|
||||
//===- torch-mlir-lsp-server.cpp - MLIR Language Server ---------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
|
|
@ -96,7 +96,7 @@ declare_mlir_python_sources(NPCOMPTorchSupportPythonSources
|
|||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_support"
|
||||
SOURCES_GLOB
|
||||
npcomp_torchscript/*.py
|
||||
npcomp_torchscript_e2e_test_configs/*.py
|
||||
torch_mlir_e2e_test.torchscript.configs/*.py
|
||||
)
|
||||
|
||||
add_mlir_python_modules(NPCOMPTorchSupportPythonModules
|
||||
|
@ -114,5 +114,3 @@ add_dependencies(NPCOMPTorchSupportPythonModules
|
|||
################################################################################
|
||||
# Recurse into the tests.
|
||||
################################################################################
|
||||
|
||||
add_subdirectory(test)
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._mlir_libs._npcomp import register_all_dialects
|
|
@ -1,28 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import os
|
||||
import string
|
||||
import sys
|
||||
|
||||
__all__ = ["debug", "debug_enabled", "enable"]
|
||||
|
||||
_ENABLED = "NPCOMP_DEBUG" in os.environ
|
||||
_formatter = string.Formatter()
|
||||
|
||||
|
||||
def enable():
|
||||
global _ENABLED
|
||||
_ENABLED = True
|
||||
|
||||
|
||||
def debug_enabled():
|
||||
return _ENABLED
|
||||
|
||||
|
||||
def debug(format_string, *args, **kwargs):
|
||||
if not _ENABLED:
|
||||
return
|
||||
formatted = _formatter.vformat(format_string, args, kwargs)
|
||||
print("DEBUG:", formatted, file=sys.stderr)
|
|
@ -1,22 +0,0 @@
|
|||
configure_lit_site_cfg(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
|
||||
MAIN_CONFIG
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
|
||||
)
|
||||
|
||||
set(TEST_DEPENDS
|
||||
FileCheck count not
|
||||
npcomp-opt
|
||||
NPCOMPTorchSupportPythonModules
|
||||
TorchMLIRPythonModules
|
||||
)
|
||||
|
||||
add_lit_testsuite(check-npcomp-python "Running the npcomp-python regression tests"
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${TEST_DEPENDS}
|
||||
)
|
||||
set_target_properties(check-npcomp-python PROPERTIES FOLDER "Tests")
|
||||
|
||||
add_lit_testsuites(NPCOMP_PYTHON ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TEST_DEPENDS})
|
||||
add_dependencies(check-npcomp-all check-npcomp-python)
|
Loading…
Reference in New Issue