Add a basic TOSA E2E backend.

We lower through linalg-on-tensors and use RefBackend to run it.
This adds enough support for a "tanh" op. Adding more ops should be
fairly mechanical now that things are wired up. Run with:
```
./tools/torchscript_e2e_test.sh -c tosa
```

The backend structure is very similar to linalg-on-tensors based E2E
backends and is a nice parallel (see `tosa_backend.py`). Actually, this
forced a nice refactoring to the layering here. We removed
`torchscript-module-to-linalg-on-tensors-backend-pipeline` and instead
require separately running
```
torchscript-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline
```
This highlights the step that lowers to the "torch backend contract"
of cleaned up `torch` dialect ops is a critical step in the lowering.
Going forward, that is the key load-bearing contract of the torch-mlir
project, not the linalg-on-tensors backend contract.

Recommended review order:
- `TorchToTosa.cpp` / `TorchToTosa/basic.mlir`
- `python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py` and
  the new `utils.py` file there.
- `python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py` and
  `abc.py` in that directory for the TOSA backend e2e interface.
- other misc mechanical changes
pull/359/head
Sean Silva 2021-10-08 02:07:03 +00:00
parent df12cc0c37
commit 0c5c84d63d
34 changed files with 660 additions and 167 deletions

View File

@ -61,6 +61,7 @@ jobs:
cd $GITHUB_WORKSPACE cd $GITHUB_WORKSPACE
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=refbackend -v python -m e2e_testing.torchscript.main --config=refbackend -v
python -m e2e_testing.torchscript.main --config=tosa -v
# TODO: Only build packages in full Release mode. # TODO: Only build packages in full Release mode.
# On the other hand, having assertions on isn't too bad of an idea at this # On the other hand, having assertions on isn't too bad of an idea at this

View File

@ -15,12 +15,13 @@ from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
# Available test configs. # Available test configs.
from torch_mlir_e2e_test.torchscript.configs import ( from torch_mlir_e2e_test.torchscript.configs import (
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig
) )
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
from .xfail_sets import XFAIL_SETS, COMMON_TORCH_MLIR_LOWERING_XFAILS from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS
# Import tests to register them in the global registry. # Import tests to register them in the global registry.
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking # Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
@ -35,7 +36,7 @@ from . import elementwise
from . import reduction from . import reduction
def _get_argparse(): def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'external'] config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config', parser.add_argument('-c', '--config',
choices=config_choices, choices=config_choices,
@ -43,6 +44,7 @@ def _get_argparse():
help=f''' help=f'''
Meaning of options: Meaning of options:
"refbackend": run through torch-mlir's RefBackend. "refbackend": run through torch-mlir's RefBackend.
"tosa": run through torch-mlir's default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
"external": use an external backend, specified by the `--external-backend` option. "external": use an external backend, specified by the `--external-backend` option.
@ -77,16 +79,27 @@ for more information on building these artifacts.
def main(): def main():
args = _get_argparse().parse_args() args = _get_argparse().parse_args()
all_tests = list(GLOBAL_TEST_REGISTRY)
if args.serialized_test_dir:
for root, dirs, files in os.walk(args.serialized_test_dir):
for filename in files:
with open(os.path.join(root, filename), 'rb') as f:
all_tests.append(pickle.load(f).as_test())
all_test_unique_names = set(test.unique_name for test in all_tests)
# Find the selected config. # Find the selected config.
if args.config == 'refbackend': if args.config == 'refbackend':
config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend()) config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
xfail_set = XFAIL_SETS['refbackend'] xfail_set = REFBACKEND_XFAIL_SET
if args.config == 'tosa':
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET
elif args.config == 'native_torch': elif args.config == 'native_torch':
config = NativeTorchTestConfig() config = NativeTorchTestConfig()
xfail_set = XFAIL_SETS['native_torch'] xfail_set = {}
elif args.config == 'torchscript': elif args.config == 'torchscript':
config = TorchScriptTestConfig() config = TorchScriptTestConfig()
xfail_set = XFAIL_SETS['torchscript'] xfail_set = {}
elif args.config == 'external': elif args.config == 'external':
with open(args.external_config, 'r') as f: with open(args.external_config, 'r') as f:
code = compile(f.read(), args.external_config, 'exec') code = compile(f.read(), args.external_config, 'exec')
@ -106,13 +119,6 @@ def main():
) )
sys.exit(1) sys.exit(1)
all_tests = list(GLOBAL_TEST_REGISTRY)
if args.serialized_test_dir:
for root, dirs, files in os.walk(args.serialized_test_dir):
for filename in files:
with open(os.path.join(root, filename), 'rb') as f:
all_tests.append(pickle.load(f).as_test())
# Find the selected tests, and emit a diagnostic if none are found. # Find the selected tests, and emit a diagnostic if none are found.
tests = [ tests = [
test for test in all_tests test for test in all_tests

View File

@ -10,17 +10,15 @@
# (this includes down into lower parts of the stack, where a side table # (this includes down into lower parts of the stack, where a side table
# might be used to keep more elaborate sets of testing configurations). # might be used to keep more elaborate sets of testing configurations).
XFAIL_SETS = {}
# Lists of tests that fail to even reach the backends. # Lists of tests that fail to even reach the backends.
# These represent further work needed in torch-mlir to lower them properly # These represent further work needed in torch-mlir to lower them properly
# to the backend contract. # to the backend contract.
COMMON_TORCH_MLIR_LOWERING_XFAILS = { COMMON_TORCH_MLIR_LOWERING_XFAILS = {
'QuantizedMLP_basic', "QuantizedMLP_basic",
} }
XFAIL_SETS['refbackend'] = COMMON_TORCH_MLIR_LOWERING_XFAILS REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
XFAIL_SETS['torchscript'] = {} # Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
XFAIL_SETS['native_torch'] = {} TOSA_PASS_SET = {"ElementwiseUnaryModule_basic"}

View File

@ -65,7 +65,7 @@ mlir_module.dump()
# Compile the torch MLIR and execute the compiled program # Compile the torch MLIR and execute the compiled program
with mlir_module.context: with mlir_module.context:
pm = PassManager.parse('torchscript-function-to-linalg-on-tensors-backend-pipeline') pm = PassManager.parse('torchscript-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
pm.run(mlir_module) pm.run(mlir_module)
print("BEFORE LINALG-ON-TENSORS BACKEND PIPELINE") print("BEFORE LINALG-ON-TENSORS BACKEND PIPELINE")

View File

@ -153,7 +153,7 @@
" mb.import_module(scripted._c, class_annotator)\n", " mb.import_module(scripted._c, class_annotator)\n",
"\n", "\n",
" ## Lower the MLIR from TorchScript to RefBackend, passing through linalg-on-tensors.\n", " ## Lower the MLIR from TorchScript to RefBackend, passing through linalg-on-tensors.\n",
" pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline', mb.module.context)\n", " pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline', mb.module.context)\n",
" pm.run(mb.module)\n", " pm.run(mb.module)\n",
"\n", "\n",
" ## Invoke RefBackend to compile to compiled artifact form.\n", " ## Invoke RefBackend to compile to compiled artifact form.\n",

View File

@ -51,7 +51,7 @@ mlir_module.dump()
print(mlir_module.operation.verify()) print(mlir_module.operation.verify())
with mlir_module.context: with mlir_module.context:
pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline') pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
pm.run(mlir_module) pm.run(mlir_module)
print("\n\nLOWERED MLIR") print("\n\nLOWERED MLIR")

View File

@ -109,7 +109,7 @@ mb.import_module(recursivescriptmodule._c, class_annotator)
backend = refbackend.RefBackendLinalgOnTensorsBackend() backend = refbackend.RefBackendLinalgOnTensorsBackend()
with mb.module.context: with mb.module.context:
pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline') pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
pm.run(mb.module) pm.run(mb.module)
compiled = backend.compile(mb.module) compiled = backend.compile(mb.module)

View File

@ -105,4 +105,13 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
} }
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
let summary = "Convert Torch ops to TOSA ops";
let description = [{
This pass assumes that TOSA ops are responsible for emitting error
guards in case of shape mismatches.
}];
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
}
#endif // TORCHMLIR_CONVERSION_PASSES #endif // TORCHMLIR_CONVERSION_PASSES

View File

@ -0,0 +1,22 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTosaPass();
}
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H

View File

@ -33,13 +33,13 @@ struct TorchLoweringPipelineOptions
/// Creates a pipeline that lowers the object graph IR that is produced by /// Creates a pipeline that lowers the object graph IR that is produced by
/// TorchScript import into the form expected by torch-verify-backend-contract. /// TorchScript import into the form expected by torch-verify-backend-contract.
void createTorchScriptToTorchBackendPipeline( void createTorchScriptModuleToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options); OpPassManager &pm, const TorchLoweringPipelineOptions &options);
/// Creates a pipeline that lowers a flat list of funcs and global slots /// Creates a pipeline that lowers a flat list of funcs and global slots
/// with the torch and aten dialects and mutable arrays and converts it to /// with the torch and aten dialects and mutable arrays and converts it to
/// the form required by torch-verify-backend-contract. /// the form required by torch-verify-backend-contract.
void createGlobalizedModuleToTorchBackendPipeline( void createTorchFunctionToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options); OpPassManager &pm, const TorchLoweringPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass(); std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();

View File

@ -19,17 +19,15 @@ namespace mlir {
namespace torch { namespace torch {
namespace TorchConversion { namespace TorchConversion {
/// Creates a pipeline that lowers the object graph IR that is given by a /// Creates a pipeline that lowers from the torch backend contract to the
/// TorchScript jit.ScriptModule into the form expected by /// linalg-on-tensors backend contract.
/// torch-verify-linalg-on-tensors-verify-backend-contract. void createTorchBackendToLinalgOnTensorsBackendPipeline(
void createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
OpPassManager &pm, OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options); const torch::Torch::TorchLoweringPipelineOptions &options);
/// Creates a pipeline that lowers the object graph IR that is given by a /// Creates a pipeline that lowers from the torch backend contract to the
/// TorchScript jit.ScriptFunction into the form expected by /// TOSA backend contract.
/// torch-verify-linalg-on-tensors-verify-backend-contract. void createTorchBackendToTosaBackendPipeline(
void createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
OpPassManager &pm, OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options); const torch::Torch::TorchLoweringPipelineOptions &options);
@ -44,6 +42,8 @@ createFinalizingBackendTypeConversionPass();
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
createVerifyLinalgOnTensorsBackendContractPass(); createVerifyLinalgOnTensorsBackendContractPass();
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
} // namespace TorchConversion } // namespace TorchConversion
/// Registers all Torch transformation passes. /// Registers all Torch transformation passes.

View File

@ -58,4 +58,9 @@ def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()"; let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
} }
def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
}
#endif // TORCHMLIR_TORCHCONVERSION_PASSES #endif // TORCHMLIR_TORCHCONVERSION_PASSES

View File

@ -1,6 +1,7 @@
add_subdirectory(TorchToLinalg) add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF) add_subdirectory(TorchToSCF)
add_subdirectory(TorchToStd) add_subdirectory(TorchToStd)
add_subdirectory(TorchToTosa)
# TODO: Automate this with add_torch_mlir_conversion_library. # TODO: Automate this with add_torch_mlir_conversion_library.
#get_property(torch_mlir_conversion_libs GLOBAL PROPERTY TORCH_MLIR_CONVERSION_LIBS) #get_property(torch_mlir_conversion_libs GLOBAL PROPERTY TORCH_MLIR_CONVERSION_LIBS)
@ -18,5 +19,6 @@ add_mlir_library(TorchMLIRConversionPasses
TorchMLIRTorchToLinalg TorchMLIRTorchToLinalg
TorchMLIRTorchToSCF TorchMLIRTorchToSCF
TorchMLIRTorchToStd TorchMLIRTorchToStd
TorchMLIRTorchToTosa
#${torch_mlir_conversion_libs} #${torch_mlir_conversion_libs}
) )

View File

@ -12,6 +12,7 @@
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h" #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pass registration // Pass registration

View File

@ -22,7 +22,6 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------

View File

@ -19,7 +19,6 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
namespace { namespace {

View File

@ -20,7 +20,6 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------

View File

@ -0,0 +1,20 @@
add_mlir_conversion_library(TorchMLIRTorchToTosa
TorchToTosa.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRTosa
TorchMLIRTorchDialect
)
torch_mlir_target_includes(TorchMLIRTorchToTosa)

View File

@ -0,0 +1,76 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class ConvertAtenTanhOp : public OpConversionPattern<AtenTanhOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenTanhOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
AtenTanhOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<tosa::TanhOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.self());
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
namespace {
class ConvertTorchToTosa
: public ConvertTorchToTosaBase<ConvertTorchToTosa> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tosa::TosaDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<tosa::TosaDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context);
target.addIllegalOp<AtenTanhOp>();
patterns.add<ConvertAtenTanhOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::createConvertTorchToTosaPass() {
return std::make_unique<ConvertTorchToTosa>();
}

View File

@ -23,16 +23,16 @@ namespace {
void mlir::torch::registerTorchPasses() { void mlir::torch::registerTorchPasses() {
::registerPasses(); ::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-to-torch-backend-pipeline", "torchscript-module-to-torch-backend-pipeline",
"Pipeline lowering TorchScript object graph IR to Torch backend form.", "Pipeline lowering TorchScript object graph IR to Torch backend form.",
mlir::torch::Torch::createTorchScriptToTorchBackendPipeline); mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-globalized-module-to-torch-backend-pipeline", "torch-function-to-torch-backend-pipeline",
"Pipeline lowering a globalized Torch program to Torch backend form.", "Pipeline lowering a Torch function to Torch backend form.",
mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline); mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
} }
void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline( void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) { OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// When we import TorchScript IR, we import their entire "compilation unit", // When we import TorchScript IR, we import their entire "compilation unit",
// which can contain numerous functions unrelated to the current program, // which can contain numerous functions unrelated to the current program,
@ -58,10 +58,10 @@ void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
// TODO: Improve shape inference. // TODO: Improve shape inference.
pm.addPass(createInlinerPass()); pm.addPass(createInlinerPass());
createGlobalizedModuleToTorchBackendPipeline(pm, options); createTorchFunctionToTorchBackendPipeline(pm, options);
} }
void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline( void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) { OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// General considerations: As a matter of bring-up, we are simultaneously // General considerations: As a matter of bring-up, we are simultaneously
// building out the frontend pipeline and also co-developing the backend // building out the frontend pipeline and also co-developing the backend

View File

@ -21,7 +21,6 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch;
using namespace mlir::torch::TorchConversion; using namespace mlir::torch::TorchConversion;
void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects( void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects(

View File

@ -3,6 +3,8 @@ add_mlir_library(TorchMLIRTorchConversionPasses
Passes.cpp Passes.cpp
VerifyInvariantsBeforeBackendLowering.cpp VerifyInvariantsBeforeBackendLowering.cpp
VerifyLinalgOnTensorsBackendContract.cpp VerifyLinalgOnTensorsBackendContract.cpp
VerifyTosaBackendContract.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms

View File

@ -16,11 +16,11 @@
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h" #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pass registration // Pass registration
@ -34,16 +34,18 @@ namespace {
void mlir::torch::registerTorchConversionPasses() { void mlir::torch::registerTorchConversionPasses() {
::registerPasses(); ::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-module-to-linalg-on-tensors-backend-pipeline", "torch-backend-to-linalg-on-tensors-backend-pipeline",
"Pipeline lowering torch object graph representing a torch.jit.ScriptModule to linalg-on-tensors backend format.", "Pipeline lowering torch backend contract to linalg-on-tensors backend "
TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline); "contract.",
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-function-to-linalg-on-tensors-backend-pipeline", "torch-backend-to-tosa-backend-pipeline",
"Pipeline lowering a flat list of functions representing a torch.jit.ScriptFunction to linalg-on-tensors backend format.", "Pipeline lowering torch backend contract to TOSA backend "
TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline); "contract.",
TorchConversion::createTorchBackendToTosaBackendPipeline);
} }
static void createTorchBackendToLinalgOnTensorsBackendPipeline( void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Check some invariants to catch errors in a clear way. // Check some invariants to catch errors in a clear way.
pm.addPass( pm.addPass(
@ -80,20 +82,29 @@ static void createTorchBackendToLinalgOnTensorsBackendPipeline(
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()); pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
} }
void TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline( void TorchConversion::createTorchBackendToTosaBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Check some invariants to catch errors in a clear way.
pm.addPass(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
// Conversion to the linalg-on-tensors backend contract starts from the Torch pm.addNestedPass<FuncOp>(createConvertTorchToTosaPass());
// backend contract.
Torch::createTorchScriptToTorchBackendPipeline(pm, options); if (options.optimize) {
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options); // Clean up any non-canonical code introduced above..
} pm.addNestedPass<FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
void TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline( pm.addNestedPass<FuncOp>(createCSEPass());
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { }
// Conversion to the linalg-on-tensors backend contract starts from the Torch // Finish the type conversion from `torch` types to the types of the
// backend contract. // TOSA backend contract.
Torch::createGlobalizedModuleToTorchBackendPipeline(pm, options); pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options); pm.addNestedPass<FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that TOSA backends
// expect. This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
} }

View File

@ -0,0 +1,62 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;
namespace {
class VerifyTosaBackendContractPass
: public VerifyTosaBackendContractBase<VerifyTosaBackendContractPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto module = getOperation();
TypeConverter converter;
converter.addConversion([](TensorType type) -> Type {
if (BaseMemRefType::isValidElementType(type.getElementType()))
return type;
return nullptr;
});
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
ConversionTarget target(*context);
// Structural operations.
target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
// Basic scalar operations.
target.addLegalDialect<tosa::TosaDialect>();
RewritePatternSet patterns(context);
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
// doesn't unnecessarily spew out the entire module.
emitError(module.getLoc())
<< "Module does not conform to the TOSA backend contract. "
"See dialect conversion legality information above.";
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() {
return std::make_unique<VerifyTosaBackendContractPass>();
}

View File

@ -6,3 +6,4 @@
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
from .native_torch import NativeTorchTestConfig from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig from .torchscript import TorchScriptTestConfig
from .tosa_backend import TosaBackendTestConfig

View File

@ -12,47 +12,15 @@ import tempfile
import numpy as np import numpy as np
import torch import torch
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
from torch_mlir.passmanager import PassManager
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
from .utils import (
recursively_convert_to_numpy,
recursively_convert_from_numpy,
convert_torchscript_module_to_torch_backend_contract_mlir,
run_pipeline_with_repro_report
)
def _recursively_convert_to_numpy(o: Any):
if isinstance(o, torch.Tensor):
return o.numpy()
if isinstance(o, tuple):
return tuple(_recursively_convert_to_numpy(x) for x in o)
if isinstance(o, list):
return [_recursively_convert_to_numpy(x) for x in o]
if isinstance(o, dict):
return {k: _recursively_convert_to_numpy(v) for k, v in o.items()}
# No-op cases. Explicitly enumerated to avoid things sneaking through.
if isinstance(o, str):
return o
if isinstance(o, float):
return o
if isinstance(o, int):
return o
raise Exception(f"Unexpected Python function input: {o}")
def _recursively_convert_from_numpy(o: Any):
if isinstance(o, np.ndarray):
return torch.from_numpy(o)
if isinstance(o, tuple):
return tuple(_recursively_convert_from_numpy(x) for x in o)
if isinstance(o, list):
return [_recursively_convert_from_numpy(x) for x in o]
if isinstance(o, dict):
return {k: _recursively_convert_from_numpy(v) for k, v in o.items()}
# No-op cases. Explicitly enumerated to avoid things sneaking through.
if isinstance(o, str):
return o
if isinstance(o, float):
return o
if isinstance(o, int):
return o
raise Exception(f"Unexpected Python function output: {o}")
class LinalgOnTensorsBackendTestConfig(TestConfig): class LinalgOnTensorsBackendTestConfig(TestConfig):
"""Base class for TestConfig's that are implemented with linalg-on-tensors. """Base class for TestConfig's that are implemented with linalg-on-tensors.
@ -65,73 +33,28 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
self.backend = backend self.backend = backend
def compile(self, program: torch.nn.Module) -> Any: def compile(self, program: torch.nn.Module) -> Any:
mb = ModuleBuilder()
scripted = torch.jit.script(program)
class_annotator = ClassAnnotator()
extract_annotations(program, scripted, class_annotator) module = convert_torchscript_module_to_torch_backend_contract_mlir(
program)
# TODO: Find a way to make each of these calls own its own run_pipeline_with_repro_report(
# "debuggable error report" situation. module,
try: "torch-backend-to-linalg-on-tensors-backend-pipeline",
sys.stderr = StringIO() "Lower Torch Backend IR -> Linalg-on-Tensors Backend IR",
# Import the TorchScript module to MLIR program.__class__.__name__)
mb.import_module(scripted._c, class_annotator)
except Exception as e:
raise Exception(f"""
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
Exception:
{e}
Diagnostics:
{sys.stderr.getvalue()}
""") from None
finally:
sys.stderr = sys.__stderr__
try: try:
sys.stderr = StringIO() sys.stderr = StringIO()
asm_for_error_report = mb.module.operation.get_asm( asm_for_error_report = module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True) large_elements_limit=10, enable_debug_info=True)
pipeline_str = "torchscript-module-to-linalg-on-tensors-backend-pipeline" return self.backend.compile(module)
# Lower module in place to make it ready for compiler backends.
with mb.module.context:
pm = PassManager.parse(pipeline_str)
pm.run(mb.module)
except Exception as e: except Exception as e:
# TODO: More robust.
# - don't arbitrarily clutter up /tmp. When a test suite has many
# tests, this can be a big disk cost (also, /tmp/ is frequently a
# RAM fs, which increases worries about capacity).
# - don't have colliding filenames (hard to do without cluttering
# up /tmp)
# - if we do have have colliding filenames, writes should at least
# avoid being racy.
filename = os.path.join(tempfile.gettempdir(), filename = os.path.join(tempfile.gettempdir(),
scripted.original_name + '.mlir') program.__class__.__name__ + ".mlir")
with open(filename, 'w') as f: with open(filename, 'w') as f:
f.write(asm_for_error_report) f.write(asm_for_error_report)
raise Exception(f""" raise Exception(f"""
torch-mlir TorchScript Object Graph IR -> linalg-on-tensors backend IR lowering failed with the following diagnostics: Linalg-on-Tensors Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
{sys.stderr.getvalue()}
Error can be reproduced with:
$ torch-mlir-opt -{pipeline_str} {filename}
""") from None
finally:
sys.stderr = sys.__stderr__
try:
sys.stderr = StringIO()
asm_for_error_report = mb.module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True)
return self.backend.compile(mb.module)
except Exception as e:
filename = os.path.join(tempfile.gettempdir(),
scripted.original_name + '.mlir')
with open(filename, 'w') as f:
f.write(asm_for_error_report)
raise Exception(f"""
torch-mlir linalg-on-tensors Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
## Exception: ## Exception:
{e} {e}
@ -148,9 +71,9 @@ torch-mlir linalg-on-tensors Backend lowering for {self.backend.__class__.__name
backend_module = self.backend.load(artifact) backend_module = self.backend.load(artifact)
result: Trace = [] result: Trace = []
for item in trace: for item in trace:
numpy_inputs = _recursively_convert_to_numpy(item.inputs) numpy_inputs = recursively_convert_to_numpy(item.inputs)
outputs = getattr(backend_module, item.symbol)(*numpy_inputs) outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
output = _recursively_convert_from_numpy(outputs) output = recursively_convert_from_numpy(outputs)
result.append( result.append(
TraceItem(symbol=item.symbol, TraceItem(symbol=item.symbol,
inputs=item.inputs, inputs=item.inputs,

View File

@ -0,0 +1,81 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import sys
from typing import Any
from io import StringIO
import os
import tempfile
import numpy as np
import torch
from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
from .utils import (
recursively_convert_to_numpy,
recursively_convert_from_numpy,
convert_torchscript_module_to_torch_backend_contract_mlir,
run_pipeline_with_repro_report
)
class TosaBackendTestConfig(TestConfig):
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
This class handles all the common lowering that torch-mlir does before
reaching the linalg-on-tensors abstraction level.
"""
def __init__(self, backend: TosaBackend):
super().__init__()
self.backend = backend
def compile(self, program: torch.nn.Module) -> Any:
module = convert_torchscript_module_to_torch_backend_contract_mlir(
program)
run_pipeline_with_repro_report(
module,
"torch-backend-to-tosa-backend-pipeline",
"Lower Torch Backend IR -> TOSA Backend IR",
program.__class__.__name__)
try:
sys.stderr = StringIO()
asm_for_error_report = module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True)
return self.backend.compile(module)
except Exception as e:
filename = os.path.join(tempfile.gettempdir(),
program.__class__.__name__ + ".mlir")
with open(filename, 'w') as f:
f.write(asm_for_error_report)
raise Exception(f"""
TOSA Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
## Exception:
{e}
## Stderr:
{sys.stderr.getvalue()}
## Input IR has been saved in {filename}
""") from None
finally:
sys.stderr = sys.__stderr__
def run(self, artifact: Any, trace: Trace) -> Trace:
backend_module = self.backend.load(artifact)
result: Trace = []
for item in trace:
numpy_inputs = recursively_convert_to_numpy(item.inputs)
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
output = recursively_convert_from_numpy(outputs)
result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
output=output))
return result

View File

@ -0,0 +1,129 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import sys
from typing import Any
from io import StringIO
import os
import tempfile
import numpy as np
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
from torch_mlir.passmanager import PassManager
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
def recursively_convert_to_numpy(o: Any):
if isinstance(o, torch.Tensor):
return o.numpy()
if isinstance(o, tuple):
return tuple(recursively_convert_to_numpy(x) for x in o)
if isinstance(o, list):
return [recursively_convert_to_numpy(x) for x in o]
if isinstance(o, dict):
return {k: recursively_convert_to_numpy(v) for k, v in o.items()}
# No-op cases. Explicitly enumerated to avoid things sneaking through.
if isinstance(o, str):
return o
if isinstance(o, float):
return o
if isinstance(o, int):
return o
raise Exception(f"Unexpected Python function input: {o}")
def recursively_convert_from_numpy(o: Any):
if isinstance(o, np.ndarray):
return torch.from_numpy(o)
if isinstance(o, tuple):
return tuple(recursively_convert_from_numpy(x) for x in o)
if isinstance(o, list):
return [recursively_convert_from_numpy(x) for x in o]
if isinstance(o, dict):
return {k: recursively_convert_from_numpy(v) for k, v in o.items()}
# No-op cases. Explicitly enumerated to avoid things sneaking through.
if isinstance(o, str):
return o
if isinstance(o, float):
return o
if isinstance(o, int):
return o
raise Exception(f"Unexpected Python function output: {o}")
def run_pipeline_with_repro_report(module,
pipeline: str,
description: str,
module_name: str):
"""Runs `pipeline` on `module`, with a nice repro report if it fails."""
try:
sys.stderr = StringIO()
asm_for_error_report = module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True)
# Lower module in place to make it ready for compiler backends.
with module.context:
pm = PassManager.parse(pipeline)
pm.run(module)
except Exception as e:
# TODO: More robust.
# - don't arbitrarily clutter up /tmp. When a test suite has many
# tests, this can be a big disk cost (also, /tmp/ is frequently a
# RAM fs, which increases worries about capacity).
# - don't have colliding filenames (hard to do without cluttering
# up /tmp)
# - if we do have have colliding filenames, writes should at least
# avoid being racy.
filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
with open(filename, 'w') as f:
f.write(asm_for_error_report)
raise Exception(f"""
{description} failed with the following diagnostics:
{sys.stderr.getvalue()}
Error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
""") from None
finally:
sys.stderr = sys.__stderr__
def convert_torchscript_module_to_torch_backend_contract_mlir(program: torch.nn.Module):
"""Perform common lowering from TorchScript to Torch MLIR
Returns an MLIR module that satisfies the Torch backend contract.
"""
mb = ModuleBuilder()
scripted = torch.jit.script(program)
class_annotator = ClassAnnotator()
extract_annotations(program, scripted, class_annotator)
# TODO: Find a way to make each of these calls own its own
# "debuggable error report" situation.
try:
sys.stderr = StringIO()
# Import the TorchScript module to MLIR
mb.import_module(scripted._c, class_annotator)
except Exception as e:
raise Exception(f"""
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
Exception:
{e}
Diagnostics:
{sys.stderr.getvalue()}
""") from None
finally:
sys.stderr = sys.__stderr__
run_pipeline_with_repro_report(
mb.module,
"torchscript-module-to-torch-backend-pipeline",
"Lowering TorchScript Object Graph IR -> Torch Backend IR",
program.__class__.__name__)
return mb.module

View File

@ -0,0 +1,46 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import abc
from typing import TypeVar
import torch
from torch_mlir.ir import Module
# A type shared between the result of `TosaBackend.compile` and the
# input to `TosaBackend.load`. Each backend will likely have a
# different definition of this type.
CompiledArtifact = TypeVar('CompiledArtifact')
# A wrapper around a backend-specific loaded program representation
# that uniformly translates the `x.method(...)` interface expected of
# Torch modules into appropriate lower-level operations.
Invoker = TypeVar('Invoker')
class TosaBackend(abc.ABC):
"""The interface to an TOSA backend.
"""
@abc.abstractmethod
def compile(self, module: Module) -> CompiledArtifact:
"""Compile the provided MLIR module into a compiled artifact.
The module adheres to the TOSA backend contract
(see the VerifyTosaBackendContract pass).
The compiled artifact can be any type, but must be correctly
interpreted by the `load` method.
"""
@abc.abstractmethod
def load(self, artifact: CompiledArtifact) -> Invoker:
"""Load the compiled artifact into a uniformly invokable form.
The compiled artifact is the result of a previous call to `compile`.
See the description of `Invoker` for the requirements on the returned
type.
"""

View File

@ -0,0 +1,48 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from torch_mlir.ir import *
from torch_mlir.passmanager import *
# Imported for side effects.
import torch_mlir.all_passes_registration
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from .abc import TosaBackend
__all__ = [
"LinalgOnTensorsTosaBackend",
]
class LinalgOnTensorsTosaBackend(TosaBackend):
"""Main entry-point for the linalg-on-tensors based TOSA backend.
This currently uses the linalg-on-tensors RefBackend for actual execution.
"""
def __init__(self):
super().__init__()
self.refbackend = RefBackendLinalgOnTensorsBackend()
def compile(self, imported_module: Module):
"""Compiles an imported module that satisfied the TOSA backend contract.
Args:
imported_module: The MLIR module consisting of funcs in the TOSA
dialect.
Returns:
An opaque, backend specific compiled artifact object that can be
passed to `load`.
"""
# TODO: Error/repro reporting.
# We should store the program name as the symbol name of the MLIR
# module so we don't have to have access to the original program for it.
with imported_module.context:
pm = PassManager.parse("builtin.func(tosa-to-linalg-on-tensors)")
pm.run(imported_module)
return self.refbackend.compile(imported_module)
def load(self, module):
"""Loads a compiled artifact into the runtime."""
return self.refbackend.load(module)

View File

@ -0,0 +1,12 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @torch.aten.tanh$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.tanh"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}

View File

@ -36,7 +36,7 @@ module {
// //
// The reporting we inherit from dialect conversion is not precise. // The reporting we inherit from dialect conversion is not precise.
// For example, here we want it to explicitly call out that // For example, here we want it to explicitly call out that
// `tensor<?x!numpy.any_dtype>` is the problem here, which suggests // `!torch.tensor` is the problem here, which suggests
// that type inference didn't succeed, or insufficient type information // that type inference didn't succeed, or insufficient type information
// was available. // was available.
// //
@ -46,8 +46,8 @@ module {
// expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}} // expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
module { module {
func @disallowed(%arg0: tensor<?x!numpy.any_dtype>) -> tensor<?x!numpy.any_dtype> { func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
// expected-error@+1 {{failed to legalize operation 'std.return'}} // expected-error@+1 {{failed to legalize operation 'std.return'}}
return %arg0 : tensor<?x!numpy.any_dtype> return %arg0 : !torch.tensor
} }
} }

View File

@ -0,0 +1,42 @@
// RUN: torch-mlir-opt -torch-verify-tosa-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
// CHECK: func @tanh
func @tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = "tosa.tanh"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
// Basic check of error reporting.
// expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
module {
func @disallowed() {
// expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}}
"unknown_dialect.unknown_op"() : () -> ()
return
}
}
// -----
// TODO: Improve these errors to give more exact reporting.
//
// The reporting we inherit from dialect conversion is not precise.
// For example, here we want it to explicitly call out that
// `!torch.tensor` is the problem here, which suggests
// that type inference didn't succeed, or insufficient type information
// was available.
//
// Ultimately, the output of this pass needs to be conveyed to the user
// in an understandable way, such as suggesting a particular place where
// a shape annotation is needed.
// expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
module {
func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
// expected-error@+1 {{failed to legalize operation 'std.return'}}
return %arg0 : !torch.tensor
}
}