mirror of https://github.com/llvm/torch-mlir
Add a basic TOSA E2E backend.
We lower through linalg-on-tensors and use RefBackend to run it. This adds enough support for a "tanh" op. Adding more ops should be fairly mechanical now that things are wired up. Run with: ``` ./tools/torchscript_e2e_test.sh -c tosa ``` The backend structure is very similar to linalg-on-tensors based E2E backends and is a nice parallel (see `tosa_backend.py`). Actually, this forced a nice refactoring to the layering here. We removed `torchscript-module-to-linalg-on-tensors-backend-pipeline` and instead require separately running ``` torchscript-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline ``` This highlights the step that lowers to the "torch backend contract" of cleaned up `torch` dialect ops is a critical step in the lowering. Going forward, that is the key load-bearing contract of the torch-mlir project, not the linalg-on-tensors backend contract. Recommended review order: - `TorchToTosa.cpp` / `TorchToTosa/basic.mlir` - `python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py` and the new `utils.py` file there. - `python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py` and `abc.py` in that directory for the TOSA backend e2e interface. - other misc mechanical changespull/359/head
parent
df12cc0c37
commit
0c5c84d63d
|
@ -61,6 +61,7 @@ jobs:
|
|||
cd $GITHUB_WORKSPACE
|
||||
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
|
||||
python -m e2e_testing.torchscript.main --config=refbackend -v
|
||||
python -m e2e_testing.torchscript.main --config=tosa -v
|
||||
|
||||
# TODO: Only build packages in full Release mode.
|
||||
# On the other hand, having assertions on isn't too bad of an idea at this
|
||||
|
|
|
@ -15,12 +15,13 @@ from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
|
|||
|
||||
# Available test configs.
|
||||
from torch_mlir_e2e_test.torchscript.configs import (
|
||||
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
||||
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig
|
||||
)
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
||||
|
||||
from .xfail_sets import XFAIL_SETS, COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
|
||||
# Import tests to register them in the global registry.
|
||||
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
|
||||
|
@ -35,7 +36,7 @@ from . import elementwise
|
|||
from . import reduction
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'external']
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||
parser.add_argument('-c', '--config',
|
||||
choices=config_choices,
|
||||
|
@ -43,6 +44,7 @@ def _get_argparse():
|
|||
help=f'''
|
||||
Meaning of options:
|
||||
"refbackend": run through torch-mlir's RefBackend.
|
||||
"tosa": run through torch-mlir's default TOSA backend.
|
||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||
"external": use an external backend, specified by the `--external-backend` option.
|
||||
|
@ -77,16 +79,27 @@ for more information on building these artifacts.
|
|||
def main():
|
||||
args = _get_argparse().parse_args()
|
||||
|
||||
all_tests = list(GLOBAL_TEST_REGISTRY)
|
||||
if args.serialized_test_dir:
|
||||
for root, dirs, files in os.walk(args.serialized_test_dir):
|
||||
for filename in files:
|
||||
with open(os.path.join(root, filename), 'rb') as f:
|
||||
all_tests.append(pickle.load(f).as_test())
|
||||
all_test_unique_names = set(test.unique_name for test in all_tests)
|
||||
|
||||
# Find the selected config.
|
||||
if args.config == 'refbackend':
|
||||
config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
|
||||
xfail_set = XFAIL_SETS['refbackend']
|
||||
xfail_set = REFBACKEND_XFAIL_SET
|
||||
if args.config == 'tosa':
|
||||
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
|
||||
xfail_set = all_test_unique_names - TOSA_PASS_SET
|
||||
elif args.config == 'native_torch':
|
||||
config = NativeTorchTestConfig()
|
||||
xfail_set = XFAIL_SETS['native_torch']
|
||||
xfail_set = {}
|
||||
elif args.config == 'torchscript':
|
||||
config = TorchScriptTestConfig()
|
||||
xfail_set = XFAIL_SETS['torchscript']
|
||||
xfail_set = {}
|
||||
elif args.config == 'external':
|
||||
with open(args.external_config, 'r') as f:
|
||||
code = compile(f.read(), args.external_config, 'exec')
|
||||
|
@ -106,13 +119,6 @@ def main():
|
|||
)
|
||||
sys.exit(1)
|
||||
|
||||
all_tests = list(GLOBAL_TEST_REGISTRY)
|
||||
if args.serialized_test_dir:
|
||||
for root, dirs, files in os.walk(args.serialized_test_dir):
|
||||
for filename in files:
|
||||
with open(os.path.join(root, filename), 'rb') as f:
|
||||
all_tests.append(pickle.load(f).as_test())
|
||||
|
||||
# Find the selected tests, and emit a diagnostic if none are found.
|
||||
tests = [
|
||||
test for test in all_tests
|
||||
|
|
|
@ -10,17 +10,15 @@
|
|||
# (this includes down into lower parts of the stack, where a side table
|
||||
# might be used to keep more elaborate sets of testing configurations).
|
||||
|
||||
XFAIL_SETS = {}
|
||||
|
||||
# Lists of tests that fail to even reach the backends.
|
||||
# These represent further work needed in torch-mlir to lower them properly
|
||||
# to the backend contract.
|
||||
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
||||
'QuantizedMLP_basic',
|
||||
"QuantizedMLP_basic",
|
||||
}
|
||||
|
||||
XFAIL_SETS['refbackend'] = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
|
||||
XFAIL_SETS['torchscript'] = {}
|
||||
|
||||
XFAIL_SETS['native_torch'] = {}
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {"ElementwiseUnaryModule_basic"}
|
||||
|
|
|
@ -65,7 +65,7 @@ mlir_module.dump()
|
|||
|
||||
# Compile the torch MLIR and execute the compiled program
|
||||
with mlir_module.context:
|
||||
pm = PassManager.parse('torchscript-function-to-linalg-on-tensors-backend-pipeline')
|
||||
pm = PassManager.parse('torchscript-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
|
||||
pm.run(mlir_module)
|
||||
|
||||
print("BEFORE LINALG-ON-TENSORS BACKEND PIPELINE")
|
||||
|
|
|
@ -153,7 +153,7 @@
|
|||
" mb.import_module(scripted._c, class_annotator)\n",
|
||||
"\n",
|
||||
" ## Lower the MLIR from TorchScript to RefBackend, passing through linalg-on-tensors.\n",
|
||||
" pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline', mb.module.context)\n",
|
||||
" pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline', mb.module.context)\n",
|
||||
" pm.run(mb.module)\n",
|
||||
"\n",
|
||||
" ## Invoke RefBackend to compile to compiled artifact form.\n",
|
||||
|
|
|
@ -51,7 +51,7 @@ mlir_module.dump()
|
|||
print(mlir_module.operation.verify())
|
||||
|
||||
with mlir_module.context:
|
||||
pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline')
|
||||
pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
|
||||
pm.run(mlir_module)
|
||||
|
||||
print("\n\nLOWERED MLIR")
|
||||
|
|
|
@ -109,7 +109,7 @@ mb.import_module(recursivescriptmodule._c, class_annotator)
|
|||
|
||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||
with mb.module.context:
|
||||
pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline')
|
||||
pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
|
||||
pm.run(mb.module)
|
||||
|
||||
compiled = backend.compile(mb.module)
|
||||
|
|
|
@ -105,4 +105,13 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
|
|||
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
|
||||
}
|
||||
|
||||
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
|
||||
let summary = "Convert Torch ops to TOSA ops";
|
||||
let description = [{
|
||||
This pass assumes that TOSA ops are responsible for emitting error
|
||||
guards in case of shape mismatches.
|
||||
}];
|
||||
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
|
||||
}
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_PASSES
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTosaPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
|
@ -33,13 +33,13 @@ struct TorchLoweringPipelineOptions
|
|||
|
||||
/// Creates a pipeline that lowers the object graph IR that is produced by
|
||||
/// TorchScript import into the form expected by torch-verify-backend-contract.
|
||||
void createTorchScriptToTorchBackendPipeline(
|
||||
void createTorchScriptModuleToTorchBackendPipeline(
|
||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||
|
||||
/// Creates a pipeline that lowers a flat list of funcs and global slots
|
||||
/// with the torch and aten dialects and mutable arrays and converts it to
|
||||
/// the form required by torch-verify-backend-contract.
|
||||
void createGlobalizedModuleToTorchBackendPipeline(
|
||||
void createTorchFunctionToTorchBackendPipeline(
|
||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
||||
|
|
|
@ -19,17 +19,15 @@ namespace mlir {
|
|||
namespace torch {
|
||||
namespace TorchConversion {
|
||||
|
||||
/// Creates a pipeline that lowers the object graph IR that is given by a
|
||||
/// TorchScript jit.ScriptModule into the form expected by
|
||||
/// torch-verify-linalg-on-tensors-verify-backend-contract.
|
||||
void createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
|
||||
/// Creates a pipeline that lowers from the torch backend contract to the
|
||||
/// linalg-on-tensors backend contract.
|
||||
void createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm,
|
||||
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||
|
||||
/// Creates a pipeline that lowers the object graph IR that is given by a
|
||||
/// TorchScript jit.ScriptFunction into the form expected by
|
||||
/// torch-verify-linalg-on-tensors-verify-backend-contract.
|
||||
void createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
|
||||
/// Creates a pipeline that lowers from the torch backend contract to the
|
||||
/// TOSA backend contract.
|
||||
void createTorchBackendToTosaBackendPipeline(
|
||||
OpPassManager &pm,
|
||||
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||
|
||||
|
@ -44,6 +42,8 @@ createFinalizingBackendTypeConversionPass();
|
|||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createVerifyLinalgOnTensorsBackendContractPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
|
||||
|
||||
} // namespace TorchConversion
|
||||
|
||||
/// Registers all Torch transformation passes.
|
||||
|
|
|
@ -58,4 +58,9 @@ def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-
|
|||
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
|
||||
}
|
||||
|
||||
def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> {
|
||||
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
|
||||
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
|
||||
}
|
||||
|
||||
#endif // TORCHMLIR_TORCHCONVERSION_PASSES
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
add_subdirectory(TorchToLinalg)
|
||||
add_subdirectory(TorchToSCF)
|
||||
add_subdirectory(TorchToStd)
|
||||
add_subdirectory(TorchToTosa)
|
||||
|
||||
# TODO: Automate this with add_torch_mlir_conversion_library.
|
||||
#get_property(torch_mlir_conversion_libs GLOBAL PROPERTY TORCH_MLIR_CONVERSION_LIBS)
|
||||
|
@ -18,5 +19,6 @@ add_mlir_library(TorchMLIRConversionPasses
|
|||
TorchMLIRTorchToLinalg
|
||||
TorchMLIRTorchToSCF
|
||||
TorchMLIRTorchToStd
|
||||
TorchMLIRTorchToTosa
|
||||
#${torch_mlir_conversion_libs}
|
||||
)
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass registration
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
add_mlir_conversion_library(TorchMLIRTorchToTosa
|
||||
TorchToTosa.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa
|
||||
|
||||
DEPENDS
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRTosa
|
||||
TorchMLIRTorchDialect
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchToTosa)
|
|
@ -0,0 +1,76 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
class ConvertAtenTanhOp : public OpConversionPattern<AtenTanhOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenTanhOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
AtenTanhOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<tosa::TanhOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.self());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
class ConvertTorchToTosa
|
||||
: public ConvertTorchToTosaBase<ConvertTorchToTosa> {
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<tosa::TosaDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<tosa::TosaDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
target.addIllegalOp<AtenTanhOp>();
|
||||
patterns.add<ConvertAtenTanhOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::torch::createConvertTorchToTosaPass() {
|
||||
return std::make_unique<ConvertTorchToTosa>();
|
||||
}
|
|
@ -23,16 +23,16 @@ namespace {
|
|||
void mlir::torch::registerTorchPasses() {
|
||||
::registerPasses();
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torchscript-to-torch-backend-pipeline",
|
||||
"torchscript-module-to-torch-backend-pipeline",
|
||||
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
||||
mlir::torch::Torch::createTorchScriptToTorchBackendPipeline);
|
||||
mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline);
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torch-globalized-module-to-torch-backend-pipeline",
|
||||
"Pipeline lowering a globalized Torch program to Torch backend form.",
|
||||
mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline);
|
||||
"torch-function-to-torch-backend-pipeline",
|
||||
"Pipeline lowering a Torch function to Torch backend form.",
|
||||
mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
|
||||
}
|
||||
|
||||
void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
|
||||
void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
|
||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||
// When we import TorchScript IR, we import their entire "compilation unit",
|
||||
// which can contain numerous functions unrelated to the current program,
|
||||
|
@ -58,10 +58,10 @@ void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
|
|||
// TODO: Improve shape inference.
|
||||
pm.addPass(createInlinerPass());
|
||||
|
||||
createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
||||
createTorchFunctionToTorchBackendPipeline(pm, options);
|
||||
}
|
||||
|
||||
void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline(
|
||||
void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||
// General considerations: As a matter of bring-up, we are simultaneously
|
||||
// building out the frontend pipeline and also co-developing the backend
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::TorchConversion;
|
||||
|
||||
void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects(
|
||||
|
|
|
@ -3,6 +3,8 @@ add_mlir_library(TorchMLIRTorchConversionPasses
|
|||
Passes.cpp
|
||||
VerifyInvariantsBeforeBackendLowering.cpp
|
||||
VerifyLinalgOnTensorsBackendContract.cpp
|
||||
VerifyTosaBackendContract.cpp
|
||||
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms
|
||||
|
|
|
@ -16,11 +16,11 @@
|
|||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass registration
|
||||
|
@ -34,16 +34,18 @@ namespace {
|
|||
void mlir::torch::registerTorchConversionPasses() {
|
||||
::registerPasses();
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torchscript-module-to-linalg-on-tensors-backend-pipeline",
|
||||
"Pipeline lowering torch object graph representing a torch.jit.ScriptModule to linalg-on-tensors backend format.",
|
||||
TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline);
|
||||
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
||||
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
|
||||
"contract.",
|
||||
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torchscript-function-to-linalg-on-tensors-backend-pipeline",
|
||||
"Pipeline lowering a flat list of functions representing a torch.jit.ScriptFunction to linalg-on-tensors backend format.",
|
||||
TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline);
|
||||
"torch-backend-to-tosa-backend-pipeline",
|
||||
"Pipeline lowering torch backend contract to TOSA backend "
|
||||
"contract.",
|
||||
TorchConversion::createTorchBackendToTosaBackendPipeline);
|
||||
}
|
||||
|
||||
static void createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||
// Check some invariants to catch errors in a clear way.
|
||||
pm.addPass(
|
||||
|
@ -80,20 +82,29 @@ static void createTorchBackendToLinalgOnTensorsBackendPipeline(
|
|||
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
|
||||
}
|
||||
|
||||
void TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
|
||||
void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||
// Check some invariants to catch errors in a clear way.
|
||||
pm.addPass(
|
||||
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
|
||||
|
||||
// Conversion to the linalg-on-tensors backend contract starts from the Torch
|
||||
// backend contract.
|
||||
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
|
||||
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToTosaPass());
|
||||
|
||||
if (options.optimize) {
|
||||
// Clean up any non-canonical code introduced above..
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
void TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||
// Finish the type conversion from `torch` types to the types of the
|
||||
// TOSA backend contract.
|
||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||
pm.addNestedPass<FuncOp>(
|
||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||
|
||||
// Conversion to the linalg-on-tensors backend contract starts from the Torch
|
||||
// backend contract.
|
||||
Torch::createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
||||
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
|
||||
// Verify that we have lowered to the form that TOSA backends
|
||||
// expect. This fails compilation (signalPassFailure) if the IR is not in the
|
||||
// correct form.
|
||||
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::TorchConversion;
|
||||
|
||||
namespace {
|
||||
class VerifyTosaBackendContractPass
|
||||
: public VerifyTosaBackendContractBase<VerifyTosaBackendContractPass> {
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
auto module = getOperation();
|
||||
TypeConverter converter;
|
||||
converter.addConversion([](TensorType type) -> Type {
|
||||
if (BaseMemRefType::isValidElementType(type.getElementType()))
|
||||
return type;
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
|
||||
|
||||
ConversionTarget target(*context);
|
||||
|
||||
// Structural operations.
|
||||
target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
|
||||
// Basic scalar operations.
|
||||
target.addLegalDialect<tosa::TosaDialect>();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
||||
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
|
||||
// doesn't unnecessarily spew out the entire module.
|
||||
emitError(module.getLoc())
|
||||
<< "Module does not conform to the TOSA backend contract. "
|
||||
"See dialect conversion legality information above.";
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() {
|
||||
return std::make_unique<VerifyTosaBackendContractPass>();
|
||||
}
|
|
@ -6,3 +6,4 @@
|
|||
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
||||
from .native_torch import NativeTorchTestConfig
|
||||
from .torchscript import TorchScriptTestConfig
|
||||
from .tosa_backend import TosaBackendTestConfig
|
||||
|
|
|
@ -12,47 +12,15 @@ import tempfile
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
|
||||
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
|
||||
from torch_mlir.passmanager import PassManager
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||
from .utils import (
|
||||
recursively_convert_to_numpy,
|
||||
recursively_convert_from_numpy,
|
||||
convert_torchscript_module_to_torch_backend_contract_mlir,
|
||||
run_pipeline_with_repro_report
|
||||
)
|
||||
|
||||
def _recursively_convert_to_numpy(o: Any):
|
||||
if isinstance(o, torch.Tensor):
|
||||
return o.numpy()
|
||||
if isinstance(o, tuple):
|
||||
return tuple(_recursively_convert_to_numpy(x) for x in o)
|
||||
if isinstance(o, list):
|
||||
return [_recursively_convert_to_numpy(x) for x in o]
|
||||
if isinstance(o, dict):
|
||||
return {k: _recursively_convert_to_numpy(v) for k, v in o.items()}
|
||||
# No-op cases. Explicitly enumerated to avoid things sneaking through.
|
||||
if isinstance(o, str):
|
||||
return o
|
||||
if isinstance(o, float):
|
||||
return o
|
||||
if isinstance(o, int):
|
||||
return o
|
||||
raise Exception(f"Unexpected Python function input: {o}")
|
||||
|
||||
def _recursively_convert_from_numpy(o: Any):
|
||||
if isinstance(o, np.ndarray):
|
||||
return torch.from_numpy(o)
|
||||
if isinstance(o, tuple):
|
||||
return tuple(_recursively_convert_from_numpy(x) for x in o)
|
||||
if isinstance(o, list):
|
||||
return [_recursively_convert_from_numpy(x) for x in o]
|
||||
if isinstance(o, dict):
|
||||
return {k: _recursively_convert_from_numpy(v) for k, v in o.items()}
|
||||
# No-op cases. Explicitly enumerated to avoid things sneaking through.
|
||||
if isinstance(o, str):
|
||||
return o
|
||||
if isinstance(o, float):
|
||||
return o
|
||||
if isinstance(o, int):
|
||||
return o
|
||||
raise Exception(f"Unexpected Python function output: {o}")
|
||||
|
||||
class LinalgOnTensorsBackendTestConfig(TestConfig):
|
||||
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
|
||||
|
@ -65,73 +33,28 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
|
|||
self.backend = backend
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
mb = ModuleBuilder()
|
||||
scripted = torch.jit.script(program)
|
||||
class_annotator = ClassAnnotator()
|
||||
|
||||
extract_annotations(program, scripted, class_annotator)
|
||||
module = convert_torchscript_module_to_torch_backend_contract_mlir(
|
||||
program)
|
||||
|
||||
# TODO: Find a way to make each of these calls own its own
|
||||
# "debuggable error report" situation.
|
||||
try:
|
||||
sys.stderr = StringIO()
|
||||
# Import the TorchScript module to MLIR
|
||||
mb.import_module(scripted._c, class_annotator)
|
||||
except Exception as e:
|
||||
raise Exception(f"""
|
||||
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
||||
Exception:
|
||||
{e}
|
||||
Diagnostics:
|
||||
{sys.stderr.getvalue()}
|
||||
""") from None
|
||||
finally:
|
||||
sys.stderr = sys.__stderr__
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
||||
"Lower Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
program.__class__.__name__)
|
||||
|
||||
try:
|
||||
sys.stderr = StringIO()
|
||||
asm_for_error_report = mb.module.operation.get_asm(
|
||||
asm_for_error_report = module.operation.get_asm(
|
||||
large_elements_limit=10, enable_debug_info=True)
|
||||
pipeline_str = "torchscript-module-to-linalg-on-tensors-backend-pipeline"
|
||||
# Lower module in place to make it ready for compiler backends.
|
||||
with mb.module.context:
|
||||
pm = PassManager.parse(pipeline_str)
|
||||
pm.run(mb.module)
|
||||
return self.backend.compile(module)
|
||||
except Exception as e:
|
||||
# TODO: More robust.
|
||||
# - don't arbitrarily clutter up /tmp. When a test suite has many
|
||||
# tests, this can be a big disk cost (also, /tmp/ is frequently a
|
||||
# RAM fs, which increases worries about capacity).
|
||||
# - don't have colliding filenames (hard to do without cluttering
|
||||
# up /tmp)
|
||||
# - if we do have have colliding filenames, writes should at least
|
||||
# avoid being racy.
|
||||
filename = os.path.join(tempfile.gettempdir(),
|
||||
scripted.original_name + '.mlir')
|
||||
program.__class__.__name__ + ".mlir")
|
||||
with open(filename, 'w') as f:
|
||||
f.write(asm_for_error_report)
|
||||
raise Exception(f"""
|
||||
torch-mlir TorchScript Object Graph IR -> linalg-on-tensors backend IR lowering failed with the following diagnostics:
|
||||
{sys.stderr.getvalue()}
|
||||
|
||||
Error can be reproduced with:
|
||||
$ torch-mlir-opt -{pipeline_str} {filename}
|
||||
""") from None
|
||||
finally:
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
try:
|
||||
sys.stderr = StringIO()
|
||||
asm_for_error_report = mb.module.operation.get_asm(
|
||||
large_elements_limit=10, enable_debug_info=True)
|
||||
return self.backend.compile(mb.module)
|
||||
except Exception as e:
|
||||
filename = os.path.join(tempfile.gettempdir(),
|
||||
scripted.original_name + '.mlir')
|
||||
with open(filename, 'w') as f:
|
||||
f.write(asm_for_error_report)
|
||||
raise Exception(f"""
|
||||
torch-mlir linalg-on-tensors Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
|
||||
Linalg-on-Tensors Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
|
||||
## Exception:
|
||||
{e}
|
||||
|
||||
|
@ -148,9 +71,9 @@ torch-mlir linalg-on-tensors Backend lowering for {self.backend.__class__.__name
|
|||
backend_module = self.backend.load(artifact)
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
numpy_inputs = _recursively_convert_to_numpy(item.inputs)
|
||||
numpy_inputs = recursively_convert_to_numpy(item.inputs)
|
||||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||
output = _recursively_convert_from_numpy(outputs)
|
||||
output = recursively_convert_from_numpy(outputs)
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
from io import StringIO
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||
from .utils import (
|
||||
recursively_convert_to_numpy,
|
||||
recursively_convert_from_numpy,
|
||||
convert_torchscript_module_to_torch_backend_contract_mlir,
|
||||
run_pipeline_with_repro_report
|
||||
)
|
||||
|
||||
|
||||
class TosaBackendTestConfig(TestConfig):
|
||||
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
|
||||
|
||||
This class handles all the common lowering that torch-mlir does before
|
||||
reaching the linalg-on-tensors abstraction level.
|
||||
"""
|
||||
def __init__(self, backend: TosaBackend):
|
||||
super().__init__()
|
||||
self.backend = backend
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
|
||||
module = convert_torchscript_module_to_torch_backend_contract_mlir(
|
||||
program)
|
||||
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"torch-backend-to-tosa-backend-pipeline",
|
||||
"Lower Torch Backend IR -> TOSA Backend IR",
|
||||
program.__class__.__name__)
|
||||
|
||||
try:
|
||||
sys.stderr = StringIO()
|
||||
asm_for_error_report = module.operation.get_asm(
|
||||
large_elements_limit=10, enable_debug_info=True)
|
||||
return self.backend.compile(module)
|
||||
except Exception as e:
|
||||
filename = os.path.join(tempfile.gettempdir(),
|
||||
program.__class__.__name__ + ".mlir")
|
||||
with open(filename, 'w') as f:
|
||||
f.write(asm_for_error_report)
|
||||
raise Exception(f"""
|
||||
TOSA Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
|
||||
## Exception:
|
||||
{e}
|
||||
|
||||
## Stderr:
|
||||
{sys.stderr.getvalue()}
|
||||
|
||||
## Input IR has been saved in {filename}
|
||||
""") from None
|
||||
finally:
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
|
||||
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||
backend_module = self.backend.load(artifact)
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
numpy_inputs = recursively_convert_to_numpy(item.inputs)
|
||||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||
output = recursively_convert_from_numpy(outputs)
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=output))
|
||||
return result
|
|
@ -0,0 +1,129 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
from io import StringIO
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
|
||||
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
|
||||
from torch_mlir.passmanager import PassManager
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||
|
||||
def recursively_convert_to_numpy(o: Any):
|
||||
if isinstance(o, torch.Tensor):
|
||||
return o.numpy()
|
||||
if isinstance(o, tuple):
|
||||
return tuple(recursively_convert_to_numpy(x) for x in o)
|
||||
if isinstance(o, list):
|
||||
return [recursively_convert_to_numpy(x) for x in o]
|
||||
if isinstance(o, dict):
|
||||
return {k: recursively_convert_to_numpy(v) for k, v in o.items()}
|
||||
# No-op cases. Explicitly enumerated to avoid things sneaking through.
|
||||
if isinstance(o, str):
|
||||
return o
|
||||
if isinstance(o, float):
|
||||
return o
|
||||
if isinstance(o, int):
|
||||
return o
|
||||
raise Exception(f"Unexpected Python function input: {o}")
|
||||
|
||||
def recursively_convert_from_numpy(o: Any):
|
||||
if isinstance(o, np.ndarray):
|
||||
return torch.from_numpy(o)
|
||||
if isinstance(o, tuple):
|
||||
return tuple(recursively_convert_from_numpy(x) for x in o)
|
||||
if isinstance(o, list):
|
||||
return [recursively_convert_from_numpy(x) for x in o]
|
||||
if isinstance(o, dict):
|
||||
return {k: recursively_convert_from_numpy(v) for k, v in o.items()}
|
||||
# No-op cases. Explicitly enumerated to avoid things sneaking through.
|
||||
if isinstance(o, str):
|
||||
return o
|
||||
if isinstance(o, float):
|
||||
return o
|
||||
if isinstance(o, int):
|
||||
return o
|
||||
raise Exception(f"Unexpected Python function output: {o}")
|
||||
|
||||
|
||||
def run_pipeline_with_repro_report(module,
|
||||
pipeline: str,
|
||||
description: str,
|
||||
module_name: str):
|
||||
"""Runs `pipeline` on `module`, with a nice repro report if it fails."""
|
||||
try:
|
||||
sys.stderr = StringIO()
|
||||
asm_for_error_report = module.operation.get_asm(
|
||||
large_elements_limit=10, enable_debug_info=True)
|
||||
# Lower module in place to make it ready for compiler backends.
|
||||
with module.context:
|
||||
pm = PassManager.parse(pipeline)
|
||||
pm.run(module)
|
||||
except Exception as e:
|
||||
# TODO: More robust.
|
||||
# - don't arbitrarily clutter up /tmp. When a test suite has many
|
||||
# tests, this can be a big disk cost (also, /tmp/ is frequently a
|
||||
# RAM fs, which increases worries about capacity).
|
||||
# - don't have colliding filenames (hard to do without cluttering
|
||||
# up /tmp)
|
||||
# - if we do have have colliding filenames, writes should at least
|
||||
# avoid being racy.
|
||||
filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
|
||||
with open(filename, 'w') as f:
|
||||
f.write(asm_for_error_report)
|
||||
raise Exception(f"""
|
||||
{description} failed with the following diagnostics:
|
||||
{sys.stderr.getvalue()}
|
||||
|
||||
Error can be reproduced with:
|
||||
$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
|
||||
""") from None
|
||||
finally:
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
|
||||
def convert_torchscript_module_to_torch_backend_contract_mlir(program: torch.nn.Module):
|
||||
"""Perform common lowering from TorchScript to Torch MLIR
|
||||
|
||||
Returns an MLIR module that satisfies the Torch backend contract.
|
||||
"""
|
||||
mb = ModuleBuilder()
|
||||
scripted = torch.jit.script(program)
|
||||
class_annotator = ClassAnnotator()
|
||||
|
||||
extract_annotations(program, scripted, class_annotator)
|
||||
|
||||
|
||||
# TODO: Find a way to make each of these calls own its own
|
||||
# "debuggable error report" situation.
|
||||
try:
|
||||
sys.stderr = StringIO()
|
||||
# Import the TorchScript module to MLIR
|
||||
mb.import_module(scripted._c, class_annotator)
|
||||
except Exception as e:
|
||||
raise Exception(f"""
|
||||
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
||||
Exception:
|
||||
{e}
|
||||
Diagnostics:
|
||||
{sys.stderr.getvalue()}
|
||||
""") from None
|
||||
finally:
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
run_pipeline_with_repro_report(
|
||||
mb.module,
|
||||
"torchscript-module-to-torch-backend-pipeline",
|
||||
"Lowering TorchScript Object Graph IR -> Torch Backend IR",
|
||||
program.__class__.__name__)
|
||||
|
||||
return mb.module
|
|
@ -0,0 +1,46 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import abc
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir.ir import Module
|
||||
|
||||
# A type shared between the result of `TosaBackend.compile` and the
|
||||
# input to `TosaBackend.load`. Each backend will likely have a
|
||||
# different definition of this type.
|
||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||
|
||||
# A wrapper around a backend-specific loaded program representation
|
||||
# that uniformly translates the `x.method(...)` interface expected of
|
||||
# Torch modules into appropriate lower-level operations.
|
||||
Invoker = TypeVar('Invoker')
|
||||
|
||||
|
||||
class TosaBackend(abc.ABC):
|
||||
"""The interface to an TOSA backend.
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def compile(self, module: Module) -> CompiledArtifact:
|
||||
"""Compile the provided MLIR module into a compiled artifact.
|
||||
|
||||
The module adheres to the TOSA backend contract
|
||||
(see the VerifyTosaBackendContract pass).
|
||||
|
||||
The compiled artifact can be any type, but must be correctly
|
||||
interpreted by the `load` method.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, artifact: CompiledArtifact) -> Invoker:
|
||||
"""Load the compiled artifact into a uniformly invokable form.
|
||||
|
||||
The compiled artifact is the result of a previous call to `compile`.
|
||||
|
||||
See the description of `Invoker` for the requirements on the returned
|
||||
type.
|
||||
"""
|
|
@ -0,0 +1,48 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from torch_mlir.ir import *
|
||||
from torch_mlir.passmanager import *
|
||||
# Imported for side effects.
|
||||
import torch_mlir.all_passes_registration
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
|
||||
from .abc import TosaBackend
|
||||
|
||||
__all__ = [
|
||||
"LinalgOnTensorsTosaBackend",
|
||||
]
|
||||
|
||||
class LinalgOnTensorsTosaBackend(TosaBackend):
|
||||
"""Main entry-point for the linalg-on-tensors based TOSA backend.
|
||||
|
||||
This currently uses the linalg-on-tensors RefBackend for actual execution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.refbackend = RefBackendLinalgOnTensorsBackend()
|
||||
|
||||
def compile(self, imported_module: Module):
|
||||
"""Compiles an imported module that satisfied the TOSA backend contract.
|
||||
|
||||
Args:
|
||||
imported_module: The MLIR module consisting of funcs in the TOSA
|
||||
dialect.
|
||||
Returns:
|
||||
An opaque, backend specific compiled artifact object that can be
|
||||
passed to `load`.
|
||||
"""
|
||||
# TODO: Error/repro reporting.
|
||||
# We should store the program name as the symbol name of the MLIR
|
||||
# module so we don't have to have access to the original program for it.
|
||||
with imported_module.context:
|
||||
pm = PassManager.parse("builtin.func(tosa-to-linalg-on-tensors)")
|
||||
pm.run(imported_module)
|
||||
return self.refbackend.compile(imported_module)
|
||||
|
||||
def load(self, module):
|
||||
"""Loads a compiled artifact into the runtime."""
|
||||
return self.refbackend.load(module)
|
|
@ -0,0 +1,12 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.tanh$basic(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.tanh"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
|
@ -36,7 +36,7 @@ module {
|
|||
//
|
||||
// The reporting we inherit from dialect conversion is not precise.
|
||||
// For example, here we want it to explicitly call out that
|
||||
// `tensor<?x!numpy.any_dtype>` is the problem here, which suggests
|
||||
// `!torch.tensor` is the problem here, which suggests
|
||||
// that type inference didn't succeed, or insufficient type information
|
||||
// was available.
|
||||
//
|
||||
|
@ -46,8 +46,8 @@ module {
|
|||
|
||||
// expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
|
||||
module {
|
||||
func @disallowed(%arg0: tensor<?x!numpy.any_dtype>) -> tensor<?x!numpy.any_dtype> {
|
||||
func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
|
||||
// expected-error@+1 {{failed to legalize operation 'std.return'}}
|
||||
return %arg0 : tensor<?x!numpy.any_dtype>
|
||||
return %arg0 : !torch.tensor
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
// RUN: torch-mlir-opt -torch-verify-tosa-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
|
||||
|
||||
// CHECK: func @tanh
|
||||
func @tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%0 = "tosa.tanh"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Basic check of error reporting.
|
||||
|
||||
// expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
|
||||
module {
|
||||
func @disallowed() {
|
||||
// expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}}
|
||||
"unknown_dialect.unknown_op"() : () -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// TODO: Improve these errors to give more exact reporting.
|
||||
//
|
||||
// The reporting we inherit from dialect conversion is not precise.
|
||||
// For example, here we want it to explicitly call out that
|
||||
// `!torch.tensor` is the problem here, which suggests
|
||||
// that type inference didn't succeed, or insufficient type information
|
||||
// was available.
|
||||
//
|
||||
// Ultimately, the output of this pass needs to be conveyed to the user
|
||||
// in an understandable way, such as suggesting a particular place where
|
||||
// a shape annotation is needed.
|
||||
|
||||
// expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
|
||||
module {
|
||||
func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
|
||||
// expected-error@+1 {{failed to legalize operation 'std.return'}}
|
||||
return %arg0 : !torch.tensor
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue