mirror of https://github.com/llvm/torch-mlir
Add a basic TOSA E2E backend.
We lower through linalg-on-tensors and use RefBackend to run it. This adds enough support for a "tanh" op. Adding more ops should be fairly mechanical now that things are wired up. Run with: ``` ./tools/torchscript_e2e_test.sh -c tosa ``` The backend structure is very similar to linalg-on-tensors based E2E backends and is a nice parallel (see `tosa_backend.py`). Actually, this forced a nice refactoring to the layering here. We removed `torchscript-module-to-linalg-on-tensors-backend-pipeline` and instead require separately running ``` torchscript-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline ``` This highlights the step that lowers to the "torch backend contract" of cleaned up `torch` dialect ops is a critical step in the lowering. Going forward, that is the key load-bearing contract of the torch-mlir project, not the linalg-on-tensors backend contract. Recommended review order: - `TorchToTosa.cpp` / `TorchToTosa/basic.mlir` - `python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py` and the new `utils.py` file there. - `python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py` and `abc.py` in that directory for the TOSA backend e2e interface. - other misc mechanical changespull/359/head
parent
df12cc0c37
commit
0c5c84d63d
|
@ -61,6 +61,7 @@ jobs:
|
||||||
cd $GITHUB_WORKSPACE
|
cd $GITHUB_WORKSPACE
|
||||||
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
|
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
|
||||||
python -m e2e_testing.torchscript.main --config=refbackend -v
|
python -m e2e_testing.torchscript.main --config=refbackend -v
|
||||||
|
python -m e2e_testing.torchscript.main --config=tosa -v
|
||||||
|
|
||||||
# TODO: Only build packages in full Release mode.
|
# TODO: Only build packages in full Release mode.
|
||||||
# On the other hand, having assertions on isn't too bad of an idea at this
|
# On the other hand, having assertions on isn't too bad of an idea at this
|
||||||
|
|
|
@ -15,12 +15,13 @@ from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
|
||||||
|
|
||||||
# Available test configs.
|
# Available test configs.
|
||||||
from torch_mlir_e2e_test.torchscript.configs import (
|
from torch_mlir_e2e_test.torchscript.configs import (
|
||||||
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||||
|
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
||||||
|
|
||||||
from .xfail_sets import XFAIL_SETS, COMMON_TORCH_MLIR_LOWERING_XFAILS
|
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||||
|
|
||||||
# Import tests to register them in the global registry.
|
# Import tests to register them in the global registry.
|
||||||
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
|
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
|
||||||
|
@ -35,7 +36,7 @@ from . import elementwise
|
||||||
from . import reduction
|
from . import reduction
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'external']
|
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||||
parser.add_argument('-c', '--config',
|
parser.add_argument('-c', '--config',
|
||||||
choices=config_choices,
|
choices=config_choices,
|
||||||
|
@ -43,6 +44,7 @@ def _get_argparse():
|
||||||
help=f'''
|
help=f'''
|
||||||
Meaning of options:
|
Meaning of options:
|
||||||
"refbackend": run through torch-mlir's RefBackend.
|
"refbackend": run through torch-mlir's RefBackend.
|
||||||
|
"tosa": run through torch-mlir's default TOSA backend.
|
||||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||||
"external": use an external backend, specified by the `--external-backend` option.
|
"external": use an external backend, specified by the `--external-backend` option.
|
||||||
|
@ -77,16 +79,27 @@ for more information on building these artifacts.
|
||||||
def main():
|
def main():
|
||||||
args = _get_argparse().parse_args()
|
args = _get_argparse().parse_args()
|
||||||
|
|
||||||
|
all_tests = list(GLOBAL_TEST_REGISTRY)
|
||||||
|
if args.serialized_test_dir:
|
||||||
|
for root, dirs, files in os.walk(args.serialized_test_dir):
|
||||||
|
for filename in files:
|
||||||
|
with open(os.path.join(root, filename), 'rb') as f:
|
||||||
|
all_tests.append(pickle.load(f).as_test())
|
||||||
|
all_test_unique_names = set(test.unique_name for test in all_tests)
|
||||||
|
|
||||||
# Find the selected config.
|
# Find the selected config.
|
||||||
if args.config == 'refbackend':
|
if args.config == 'refbackend':
|
||||||
config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
|
config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
|
||||||
xfail_set = XFAIL_SETS['refbackend']
|
xfail_set = REFBACKEND_XFAIL_SET
|
||||||
|
if args.config == 'tosa':
|
||||||
|
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
|
||||||
|
xfail_set = all_test_unique_names - TOSA_PASS_SET
|
||||||
elif args.config == 'native_torch':
|
elif args.config == 'native_torch':
|
||||||
config = NativeTorchTestConfig()
|
config = NativeTorchTestConfig()
|
||||||
xfail_set = XFAIL_SETS['native_torch']
|
xfail_set = {}
|
||||||
elif args.config == 'torchscript':
|
elif args.config == 'torchscript':
|
||||||
config = TorchScriptTestConfig()
|
config = TorchScriptTestConfig()
|
||||||
xfail_set = XFAIL_SETS['torchscript']
|
xfail_set = {}
|
||||||
elif args.config == 'external':
|
elif args.config == 'external':
|
||||||
with open(args.external_config, 'r') as f:
|
with open(args.external_config, 'r') as f:
|
||||||
code = compile(f.read(), args.external_config, 'exec')
|
code = compile(f.read(), args.external_config, 'exec')
|
||||||
|
@ -106,13 +119,6 @@ def main():
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
all_tests = list(GLOBAL_TEST_REGISTRY)
|
|
||||||
if args.serialized_test_dir:
|
|
||||||
for root, dirs, files in os.walk(args.serialized_test_dir):
|
|
||||||
for filename in files:
|
|
||||||
with open(os.path.join(root, filename), 'rb') as f:
|
|
||||||
all_tests.append(pickle.load(f).as_test())
|
|
||||||
|
|
||||||
# Find the selected tests, and emit a diagnostic if none are found.
|
# Find the selected tests, and emit a diagnostic if none are found.
|
||||||
tests = [
|
tests = [
|
||||||
test for test in all_tests
|
test for test in all_tests
|
||||||
|
|
|
@ -10,17 +10,15 @@
|
||||||
# (this includes down into lower parts of the stack, where a side table
|
# (this includes down into lower parts of the stack, where a side table
|
||||||
# might be used to keep more elaborate sets of testing configurations).
|
# might be used to keep more elaborate sets of testing configurations).
|
||||||
|
|
||||||
XFAIL_SETS = {}
|
|
||||||
|
|
||||||
# Lists of tests that fail to even reach the backends.
|
# Lists of tests that fail to even reach the backends.
|
||||||
# These represent further work needed in torch-mlir to lower them properly
|
# These represent further work needed in torch-mlir to lower them properly
|
||||||
# to the backend contract.
|
# to the backend contract.
|
||||||
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
||||||
'QuantizedMLP_basic',
|
"QuantizedMLP_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
XFAIL_SETS['refbackend'] = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||||
|
|
||||||
XFAIL_SETS['torchscript'] = {}
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
|
# and very few tests work yet.
|
||||||
XFAIL_SETS['native_torch'] = {}
|
TOSA_PASS_SET = {"ElementwiseUnaryModule_basic"}
|
||||||
|
|
|
@ -65,7 +65,7 @@ mlir_module.dump()
|
||||||
|
|
||||||
# Compile the torch MLIR and execute the compiled program
|
# Compile the torch MLIR and execute the compiled program
|
||||||
with mlir_module.context:
|
with mlir_module.context:
|
||||||
pm = PassManager.parse('torchscript-function-to-linalg-on-tensors-backend-pipeline')
|
pm = PassManager.parse('torchscript-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
|
||||||
pm.run(mlir_module)
|
pm.run(mlir_module)
|
||||||
|
|
||||||
print("BEFORE LINALG-ON-TENSORS BACKEND PIPELINE")
|
print("BEFORE LINALG-ON-TENSORS BACKEND PIPELINE")
|
||||||
|
|
|
@ -153,7 +153,7 @@
|
||||||
" mb.import_module(scripted._c, class_annotator)\n",
|
" mb.import_module(scripted._c, class_annotator)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" ## Lower the MLIR from TorchScript to RefBackend, passing through linalg-on-tensors.\n",
|
" ## Lower the MLIR from TorchScript to RefBackend, passing through linalg-on-tensors.\n",
|
||||||
" pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline', mb.module.context)\n",
|
" pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline', mb.module.context)\n",
|
||||||
" pm.run(mb.module)\n",
|
" pm.run(mb.module)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" ## Invoke RefBackend to compile to compiled artifact form.\n",
|
" ## Invoke RefBackend to compile to compiled artifact form.\n",
|
||||||
|
|
|
@ -51,7 +51,7 @@ mlir_module.dump()
|
||||||
print(mlir_module.operation.verify())
|
print(mlir_module.operation.verify())
|
||||||
|
|
||||||
with mlir_module.context:
|
with mlir_module.context:
|
||||||
pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline')
|
pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
|
||||||
pm.run(mlir_module)
|
pm.run(mlir_module)
|
||||||
|
|
||||||
print("\n\nLOWERED MLIR")
|
print("\n\nLOWERED MLIR")
|
||||||
|
|
|
@ -109,7 +109,7 @@ mb.import_module(recursivescriptmodule._c, class_annotator)
|
||||||
|
|
||||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||||
with mb.module.context:
|
with mb.module.context:
|
||||||
pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline')
|
pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
|
||||||
pm.run(mb.module)
|
pm.run(mb.module)
|
||||||
|
|
||||||
compiled = backend.compile(mb.module)
|
compiled = backend.compile(mb.module)
|
||||||
|
|
|
@ -105,4 +105,13 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
|
||||||
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
|
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
|
||||||
|
let summary = "Convert Torch ops to TOSA ops";
|
||||||
|
let description = [{
|
||||||
|
This pass assumes that TOSA ops are responsible for emitting error
|
||||||
|
guards in case of shape mismatches.
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TORCHMLIR_CONVERSION_PASSES
|
#endif // TORCHMLIR_CONVERSION_PASSES
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
||||||
|
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTosaPass();
|
||||||
|
}
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
|
@ -33,13 +33,13 @@ struct TorchLoweringPipelineOptions
|
||||||
|
|
||||||
/// Creates a pipeline that lowers the object graph IR that is produced by
|
/// Creates a pipeline that lowers the object graph IR that is produced by
|
||||||
/// TorchScript import into the form expected by torch-verify-backend-contract.
|
/// TorchScript import into the form expected by torch-verify-backend-contract.
|
||||||
void createTorchScriptToTorchBackendPipeline(
|
void createTorchScriptModuleToTorchBackendPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
/// Creates a pipeline that lowers a flat list of funcs and global slots
|
/// Creates a pipeline that lowers a flat list of funcs and global slots
|
||||||
/// with the torch and aten dialects and mutable arrays and converts it to
|
/// with the torch and aten dialects and mutable arrays and converts it to
|
||||||
/// the form required by torch-verify-backend-contract.
|
/// the form required by torch-verify-backend-contract.
|
||||||
void createGlobalizedModuleToTorchBackendPipeline(
|
void createTorchFunctionToTorchBackendPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
||||||
|
|
|
@ -19,17 +19,15 @@ namespace mlir {
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace TorchConversion {
|
namespace TorchConversion {
|
||||||
|
|
||||||
/// Creates a pipeline that lowers the object graph IR that is given by a
|
/// Creates a pipeline that lowers from the torch backend contract to the
|
||||||
/// TorchScript jit.ScriptModule into the form expected by
|
/// linalg-on-tensors backend contract.
|
||||||
/// torch-verify-linalg-on-tensors-verify-backend-contract.
|
void createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||||
void createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
|
|
||||||
OpPassManager &pm,
|
OpPassManager &pm,
|
||||||
const torch::Torch::TorchLoweringPipelineOptions &options);
|
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
/// Creates a pipeline that lowers the object graph IR that is given by a
|
/// Creates a pipeline that lowers from the torch backend contract to the
|
||||||
/// TorchScript jit.ScriptFunction into the form expected by
|
/// TOSA backend contract.
|
||||||
/// torch-verify-linalg-on-tensors-verify-backend-contract.
|
void createTorchBackendToTosaBackendPipeline(
|
||||||
void createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
|
|
||||||
OpPassManager &pm,
|
OpPassManager &pm,
|
||||||
const torch::Torch::TorchLoweringPipelineOptions &options);
|
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
|
@ -44,6 +42,8 @@ createFinalizingBackendTypeConversionPass();
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createVerifyLinalgOnTensorsBackendContractPass();
|
createVerifyLinalgOnTensorsBackendContractPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
|
||||||
|
|
||||||
} // namespace TorchConversion
|
} // namespace TorchConversion
|
||||||
|
|
||||||
/// Registers all Torch transformation passes.
|
/// Registers all Torch transformation passes.
|
||||||
|
|
|
@ -58,4 +58,9 @@ def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-
|
||||||
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
|
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> {
|
||||||
|
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
|
||||||
|
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TORCHMLIR_TORCHCONVERSION_PASSES
|
#endif // TORCHMLIR_TORCHCONVERSION_PASSES
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
add_subdirectory(TorchToLinalg)
|
add_subdirectory(TorchToLinalg)
|
||||||
add_subdirectory(TorchToSCF)
|
add_subdirectory(TorchToSCF)
|
||||||
add_subdirectory(TorchToStd)
|
add_subdirectory(TorchToStd)
|
||||||
|
add_subdirectory(TorchToTosa)
|
||||||
|
|
||||||
# TODO: Automate this with add_torch_mlir_conversion_library.
|
# TODO: Automate this with add_torch_mlir_conversion_library.
|
||||||
#get_property(torch_mlir_conversion_libs GLOBAL PROPERTY TORCH_MLIR_CONVERSION_LIBS)
|
#get_property(torch_mlir_conversion_libs GLOBAL PROPERTY TORCH_MLIR_CONVERSION_LIBS)
|
||||||
|
@ -18,5 +19,6 @@ add_mlir_library(TorchMLIRConversionPasses
|
||||||
TorchMLIRTorchToLinalg
|
TorchMLIRTorchToLinalg
|
||||||
TorchMLIRTorchToSCF
|
TorchMLIRTorchToSCF
|
||||||
TorchMLIRTorchToStd
|
TorchMLIRTorchToStd
|
||||||
|
TorchMLIRTorchToTosa
|
||||||
#${torch_mlir_conversion_libs}
|
#${torch_mlir_conversion_libs}
|
||||||
)
|
)
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||||
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
|
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pass registration
|
// Pass registration
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch;
|
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch;
|
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch;
|
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
add_mlir_conversion_library(TorchMLIRTorchToTosa
|
||||||
|
TorchToTosa.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
TorchMLIRConversionPassIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRPass
|
||||||
|
MLIRTosa
|
||||||
|
TorchMLIRTorchDialect
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_mlir_target_includes(TorchMLIRTorchToTosa)
|
|
@ -0,0 +1,76 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
|
#include "mlir/Dialect/Traits.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertAtenTanhOp : public OpConversionPattern<AtenTanhOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenTanhOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
AtenTanhOp::Adaptor adaptor(operands);
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::TanhOp>(
|
||||||
|
op, getTypeConverter()->convertType(op.getType()), adaptor.self());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// The pass
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertTorchToTosa
|
||||||
|
: public ConvertTorchToTosaBase<ConvertTorchToTosa> {
|
||||||
|
public:
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<tosa::TosaDialect>();
|
||||||
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
ConversionTarget target(*context);
|
||||||
|
target.addLegalDialect<tosa::TosaDialect>();
|
||||||
|
|
||||||
|
TypeConverter typeConverter;
|
||||||
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||||
|
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
target.addIllegalOp<AtenTanhOp>();
|
||||||
|
patterns.add<ConvertAtenTanhOp>(typeConverter, context);
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
|
std::move(patterns))))
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
mlir::torch::createConvertTorchToTosaPass() {
|
||||||
|
return std::make_unique<ConvertTorchToTosa>();
|
||||||
|
}
|
|
@ -23,16 +23,16 @@ namespace {
|
||||||
void mlir::torch::registerTorchPasses() {
|
void mlir::torch::registerTorchPasses() {
|
||||||
::registerPasses();
|
::registerPasses();
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torchscript-to-torch-backend-pipeline",
|
"torchscript-module-to-torch-backend-pipeline",
|
||||||
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
||||||
mlir::torch::Torch::createTorchScriptToTorchBackendPipeline);
|
mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline);
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torch-globalized-module-to-torch-backend-pipeline",
|
"torch-function-to-torch-backend-pipeline",
|
||||||
"Pipeline lowering a globalized Torch program to Torch backend form.",
|
"Pipeline lowering a Torch function to Torch backend form.",
|
||||||
mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline);
|
mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
|
void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||||
// When we import TorchScript IR, we import their entire "compilation unit",
|
// When we import TorchScript IR, we import their entire "compilation unit",
|
||||||
// which can contain numerous functions unrelated to the current program,
|
// which can contain numerous functions unrelated to the current program,
|
||||||
|
@ -58,10 +58,10 @@ void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
|
||||||
// TODO: Improve shape inference.
|
// TODO: Improve shape inference.
|
||||||
pm.addPass(createInlinerPass());
|
pm.addPass(createInlinerPass());
|
||||||
|
|
||||||
createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
createTorchFunctionToTorchBackendPipeline(pm, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline(
|
void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||||
// General considerations: As a matter of bring-up, we are simultaneously
|
// General considerations: As a matter of bring-up, we are simultaneously
|
||||||
// building out the frontend pipeline and also co-developing the backend
|
// building out the frontend pipeline and also co-developing the backend
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch;
|
|
||||||
using namespace mlir::torch::TorchConversion;
|
using namespace mlir::torch::TorchConversion;
|
||||||
|
|
||||||
void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects(
|
void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects(
|
||||||
|
|
|
@ -3,6 +3,8 @@ add_mlir_library(TorchMLIRTorchConversionPasses
|
||||||
Passes.cpp
|
Passes.cpp
|
||||||
VerifyInvariantsBeforeBackendLowering.cpp
|
VerifyInvariantsBeforeBackendLowering.cpp
|
||||||
VerifyLinalgOnTensorsBackendContract.cpp
|
VerifyLinalgOnTensorsBackendContract.cpp
|
||||||
|
VerifyTosaBackendContract.cpp
|
||||||
|
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms
|
||||||
|
|
|
@ -16,11 +16,11 @@
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||||
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
|
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pass registration
|
// Pass registration
|
||||||
|
@ -34,16 +34,18 @@ namespace {
|
||||||
void mlir::torch::registerTorchConversionPasses() {
|
void mlir::torch::registerTorchConversionPasses() {
|
||||||
::registerPasses();
|
::registerPasses();
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torchscript-module-to-linalg-on-tensors-backend-pipeline",
|
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
||||||
"Pipeline lowering torch object graph representing a torch.jit.ScriptModule to linalg-on-tensors backend format.",
|
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
|
||||||
TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline);
|
"contract.",
|
||||||
|
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torchscript-function-to-linalg-on-tensors-backend-pipeline",
|
"torch-backend-to-tosa-backend-pipeline",
|
||||||
"Pipeline lowering a flat list of functions representing a torch.jit.ScriptFunction to linalg-on-tensors backend format.",
|
"Pipeline lowering torch backend contract to TOSA backend "
|
||||||
TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline);
|
"contract.",
|
||||||
|
TorchConversion::createTorchBackendToTosaBackendPipeline);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void createTorchBackendToLinalgOnTensorsBackendPipeline(
|
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||||
// Check some invariants to catch errors in a clear way.
|
// Check some invariants to catch errors in a clear way.
|
||||||
pm.addPass(
|
pm.addPass(
|
||||||
|
@ -80,20 +82,29 @@ static void createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||||
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
|
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
void TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
|
void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
||||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||||
|
// Check some invariants to catch errors in a clear way.
|
||||||
|
pm.addPass(
|
||||||
|
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
|
||||||
|
|
||||||
// Conversion to the linalg-on-tensors backend contract starts from the Torch
|
pm.addNestedPass<FuncOp>(createConvertTorchToTosaPass());
|
||||||
// backend contract.
|
|
||||||
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
|
if (options.optimize) {
|
||||||
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
|
// Clean up any non-canonical code introduced above..
|
||||||
}
|
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||||
|
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||||
void TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
|
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
}
|
||||||
|
|
||||||
// Conversion to the linalg-on-tensors backend contract starts from the Torch
|
// Finish the type conversion from `torch` types to the types of the
|
||||||
// backend contract.
|
// TOSA backend contract.
|
||||||
Torch::createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||||
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
|
pm.addNestedPass<FuncOp>(
|
||||||
|
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||||
|
|
||||||
|
// Verify that we have lowered to the form that TOSA backends
|
||||||
|
// expect. This fails compilation (signalPassFailure) if the IR is not in the
|
||||||
|
// correct form.
|
||||||
|
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::TorchConversion;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class VerifyTosaBackendContractPass
|
||||||
|
: public VerifyTosaBackendContractBase<VerifyTosaBackendContractPass> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
auto module = getOperation();
|
||||||
|
TypeConverter converter;
|
||||||
|
converter.addConversion([](TensorType type) -> Type {
|
||||||
|
if (BaseMemRefType::isValidElementType(type.getElementType()))
|
||||||
|
return type;
|
||||||
|
return nullptr;
|
||||||
|
});
|
||||||
|
|
||||||
|
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
|
||||||
|
|
||||||
|
ConversionTarget target(*context);
|
||||||
|
|
||||||
|
// Structural operations.
|
||||||
|
target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
|
||||||
|
// Basic scalar operations.
|
||||||
|
target.addLegalDialect<tosa::TosaDialect>();
|
||||||
|
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
||||||
|
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
|
||||||
|
// doesn't unnecessarily spew out the entire module.
|
||||||
|
emitError(module.getLoc())
|
||||||
|
<< "Module does not conform to the TOSA backend contract. "
|
||||||
|
"See dialect conversion legality information above.";
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() {
|
||||||
|
return std::make_unique<VerifyTosaBackendContractPass>();
|
||||||
|
}
|
|
@ -6,3 +6,4 @@
|
||||||
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
||||||
from .native_torch import NativeTorchTestConfig
|
from .native_torch import NativeTorchTestConfig
|
||||||
from .torchscript import TorchScriptTestConfig
|
from .torchscript import TorchScriptTestConfig
|
||||||
|
from .tosa_backend import TosaBackendTestConfig
|
||||||
|
|
|
@ -12,47 +12,15 @@ import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
|
|
||||||
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
|
|
||||||
from torch_mlir.passmanager import PassManager
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
|
||||||
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||||
|
from .utils import (
|
||||||
|
recursively_convert_to_numpy,
|
||||||
|
recursively_convert_from_numpy,
|
||||||
|
convert_torchscript_module_to_torch_backend_contract_mlir,
|
||||||
|
run_pipeline_with_repro_report
|
||||||
|
)
|
||||||
|
|
||||||
def _recursively_convert_to_numpy(o: Any):
|
|
||||||
if isinstance(o, torch.Tensor):
|
|
||||||
return o.numpy()
|
|
||||||
if isinstance(o, tuple):
|
|
||||||
return tuple(_recursively_convert_to_numpy(x) for x in o)
|
|
||||||
if isinstance(o, list):
|
|
||||||
return [_recursively_convert_to_numpy(x) for x in o]
|
|
||||||
if isinstance(o, dict):
|
|
||||||
return {k: _recursively_convert_to_numpy(v) for k, v in o.items()}
|
|
||||||
# No-op cases. Explicitly enumerated to avoid things sneaking through.
|
|
||||||
if isinstance(o, str):
|
|
||||||
return o
|
|
||||||
if isinstance(o, float):
|
|
||||||
return o
|
|
||||||
if isinstance(o, int):
|
|
||||||
return o
|
|
||||||
raise Exception(f"Unexpected Python function input: {o}")
|
|
||||||
|
|
||||||
def _recursively_convert_from_numpy(o: Any):
|
|
||||||
if isinstance(o, np.ndarray):
|
|
||||||
return torch.from_numpy(o)
|
|
||||||
if isinstance(o, tuple):
|
|
||||||
return tuple(_recursively_convert_from_numpy(x) for x in o)
|
|
||||||
if isinstance(o, list):
|
|
||||||
return [_recursively_convert_from_numpy(x) for x in o]
|
|
||||||
if isinstance(o, dict):
|
|
||||||
return {k: _recursively_convert_from_numpy(v) for k, v in o.items()}
|
|
||||||
# No-op cases. Explicitly enumerated to avoid things sneaking through.
|
|
||||||
if isinstance(o, str):
|
|
||||||
return o
|
|
||||||
if isinstance(o, float):
|
|
||||||
return o
|
|
||||||
if isinstance(o, int):
|
|
||||||
return o
|
|
||||||
raise Exception(f"Unexpected Python function output: {o}")
|
|
||||||
|
|
||||||
class LinalgOnTensorsBackendTestConfig(TestConfig):
|
class LinalgOnTensorsBackendTestConfig(TestConfig):
|
||||||
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
|
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
|
||||||
|
@ -65,73 +33,28 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
|
|
||||||
def compile(self, program: torch.nn.Module) -> Any:
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
mb = ModuleBuilder()
|
|
||||||
scripted = torch.jit.script(program)
|
|
||||||
class_annotator = ClassAnnotator()
|
|
||||||
|
|
||||||
extract_annotations(program, scripted, class_annotator)
|
module = convert_torchscript_module_to_torch_backend_contract_mlir(
|
||||||
|
program)
|
||||||
|
|
||||||
# TODO: Find a way to make each of these calls own its own
|
run_pipeline_with_repro_report(
|
||||||
# "debuggable error report" situation.
|
module,
|
||||||
try:
|
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
||||||
sys.stderr = StringIO()
|
"Lower Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||||
# Import the TorchScript module to MLIR
|
program.__class__.__name__)
|
||||||
mb.import_module(scripted._c, class_annotator)
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"""
|
|
||||||
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
|
||||||
Exception:
|
|
||||||
{e}
|
|
||||||
Diagnostics:
|
|
||||||
{sys.stderr.getvalue()}
|
|
||||||
""") from None
|
|
||||||
finally:
|
|
||||||
sys.stderr = sys.__stderr__
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sys.stderr = StringIO()
|
sys.stderr = StringIO()
|
||||||
asm_for_error_report = mb.module.operation.get_asm(
|
asm_for_error_report = module.operation.get_asm(
|
||||||
large_elements_limit=10, enable_debug_info=True)
|
large_elements_limit=10, enable_debug_info=True)
|
||||||
pipeline_str = "torchscript-module-to-linalg-on-tensors-backend-pipeline"
|
return self.backend.compile(module)
|
||||||
# Lower module in place to make it ready for compiler backends.
|
|
||||||
with mb.module.context:
|
|
||||||
pm = PassManager.parse(pipeline_str)
|
|
||||||
pm.run(mb.module)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# TODO: More robust.
|
|
||||||
# - don't arbitrarily clutter up /tmp. When a test suite has many
|
|
||||||
# tests, this can be a big disk cost (also, /tmp/ is frequently a
|
|
||||||
# RAM fs, which increases worries about capacity).
|
|
||||||
# - don't have colliding filenames (hard to do without cluttering
|
|
||||||
# up /tmp)
|
|
||||||
# - if we do have have colliding filenames, writes should at least
|
|
||||||
# avoid being racy.
|
|
||||||
filename = os.path.join(tempfile.gettempdir(),
|
filename = os.path.join(tempfile.gettempdir(),
|
||||||
scripted.original_name + '.mlir')
|
program.__class__.__name__ + ".mlir")
|
||||||
with open(filename, 'w') as f:
|
with open(filename, 'w') as f:
|
||||||
f.write(asm_for_error_report)
|
f.write(asm_for_error_report)
|
||||||
raise Exception(f"""
|
raise Exception(f"""
|
||||||
torch-mlir TorchScript Object Graph IR -> linalg-on-tensors backend IR lowering failed with the following diagnostics:
|
Linalg-on-Tensors Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
|
||||||
{sys.stderr.getvalue()}
|
|
||||||
|
|
||||||
Error can be reproduced with:
|
|
||||||
$ torch-mlir-opt -{pipeline_str} {filename}
|
|
||||||
""") from None
|
|
||||||
finally:
|
|
||||||
sys.stderr = sys.__stderr__
|
|
||||||
|
|
||||||
try:
|
|
||||||
sys.stderr = StringIO()
|
|
||||||
asm_for_error_report = mb.module.operation.get_asm(
|
|
||||||
large_elements_limit=10, enable_debug_info=True)
|
|
||||||
return self.backend.compile(mb.module)
|
|
||||||
except Exception as e:
|
|
||||||
filename = os.path.join(tempfile.gettempdir(),
|
|
||||||
scripted.original_name + '.mlir')
|
|
||||||
with open(filename, 'w') as f:
|
|
||||||
f.write(asm_for_error_report)
|
|
||||||
raise Exception(f"""
|
|
||||||
torch-mlir linalg-on-tensors Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
|
|
||||||
## Exception:
|
## Exception:
|
||||||
{e}
|
{e}
|
||||||
|
|
||||||
|
@ -148,9 +71,9 @@ torch-mlir linalg-on-tensors Backend lowering for {self.backend.__class__.__name
|
||||||
backend_module = self.backend.load(artifact)
|
backend_module = self.backend.load(artifact)
|
||||||
result: Trace = []
|
result: Trace = []
|
||||||
for item in trace:
|
for item in trace:
|
||||||
numpy_inputs = _recursively_convert_to_numpy(item.inputs)
|
numpy_inputs = recursively_convert_to_numpy(item.inputs)
|
||||||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||||
output = _recursively_convert_from_numpy(outputs)
|
output = recursively_convert_from_numpy(outputs)
|
||||||
result.append(
|
result.append(
|
||||||
TraceItem(symbol=item.symbol,
|
TraceItem(symbol=item.symbol,
|
||||||
inputs=item.inputs,
|
inputs=item.inputs,
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
from io import StringIO
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||||
|
from .utils import (
|
||||||
|
recursively_convert_to_numpy,
|
||||||
|
recursively_convert_from_numpy,
|
||||||
|
convert_torchscript_module_to_torch_backend_contract_mlir,
|
||||||
|
run_pipeline_with_repro_report
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TosaBackendTestConfig(TestConfig):
|
||||||
|
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
|
||||||
|
|
||||||
|
This class handles all the common lowering that torch-mlir does before
|
||||||
|
reaching the linalg-on-tensors abstraction level.
|
||||||
|
"""
|
||||||
|
def __init__(self, backend: TosaBackend):
|
||||||
|
super().__init__()
|
||||||
|
self.backend = backend
|
||||||
|
|
||||||
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
|
|
||||||
|
module = convert_torchscript_module_to_torch_backend_contract_mlir(
|
||||||
|
program)
|
||||||
|
|
||||||
|
run_pipeline_with_repro_report(
|
||||||
|
module,
|
||||||
|
"torch-backend-to-tosa-backend-pipeline",
|
||||||
|
"Lower Torch Backend IR -> TOSA Backend IR",
|
||||||
|
program.__class__.__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
sys.stderr = StringIO()
|
||||||
|
asm_for_error_report = module.operation.get_asm(
|
||||||
|
large_elements_limit=10, enable_debug_info=True)
|
||||||
|
return self.backend.compile(module)
|
||||||
|
except Exception as e:
|
||||||
|
filename = os.path.join(tempfile.gettempdir(),
|
||||||
|
program.__class__.__name__ + ".mlir")
|
||||||
|
with open(filename, 'w') as f:
|
||||||
|
f.write(asm_for_error_report)
|
||||||
|
raise Exception(f"""
|
||||||
|
TOSA Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
|
||||||
|
## Exception:
|
||||||
|
{e}
|
||||||
|
|
||||||
|
## Stderr:
|
||||||
|
{sys.stderr.getvalue()}
|
||||||
|
|
||||||
|
## Input IR has been saved in {filename}
|
||||||
|
""") from None
|
||||||
|
finally:
|
||||||
|
sys.stderr = sys.__stderr__
|
||||||
|
|
||||||
|
|
||||||
|
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||||
|
backend_module = self.backend.load(artifact)
|
||||||
|
result: Trace = []
|
||||||
|
for item in trace:
|
||||||
|
numpy_inputs = recursively_convert_to_numpy(item.inputs)
|
||||||
|
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||||
|
output = recursively_convert_from_numpy(outputs)
|
||||||
|
result.append(
|
||||||
|
TraceItem(symbol=item.symbol,
|
||||||
|
inputs=item.inputs,
|
||||||
|
output=output))
|
||||||
|
return result
|
|
@ -0,0 +1,129 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
from io import StringIO
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
|
||||||
|
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
|
||||||
|
from torch_mlir.passmanager import PassManager
|
||||||
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||||
|
|
||||||
|
def recursively_convert_to_numpy(o: Any):
|
||||||
|
if isinstance(o, torch.Tensor):
|
||||||
|
return o.numpy()
|
||||||
|
if isinstance(o, tuple):
|
||||||
|
return tuple(recursively_convert_to_numpy(x) for x in o)
|
||||||
|
if isinstance(o, list):
|
||||||
|
return [recursively_convert_to_numpy(x) for x in o]
|
||||||
|
if isinstance(o, dict):
|
||||||
|
return {k: recursively_convert_to_numpy(v) for k, v in o.items()}
|
||||||
|
# No-op cases. Explicitly enumerated to avoid things sneaking through.
|
||||||
|
if isinstance(o, str):
|
||||||
|
return o
|
||||||
|
if isinstance(o, float):
|
||||||
|
return o
|
||||||
|
if isinstance(o, int):
|
||||||
|
return o
|
||||||
|
raise Exception(f"Unexpected Python function input: {o}")
|
||||||
|
|
||||||
|
def recursively_convert_from_numpy(o: Any):
|
||||||
|
if isinstance(o, np.ndarray):
|
||||||
|
return torch.from_numpy(o)
|
||||||
|
if isinstance(o, tuple):
|
||||||
|
return tuple(recursively_convert_from_numpy(x) for x in o)
|
||||||
|
if isinstance(o, list):
|
||||||
|
return [recursively_convert_from_numpy(x) for x in o]
|
||||||
|
if isinstance(o, dict):
|
||||||
|
return {k: recursively_convert_from_numpy(v) for k, v in o.items()}
|
||||||
|
# No-op cases. Explicitly enumerated to avoid things sneaking through.
|
||||||
|
if isinstance(o, str):
|
||||||
|
return o
|
||||||
|
if isinstance(o, float):
|
||||||
|
return o
|
||||||
|
if isinstance(o, int):
|
||||||
|
return o
|
||||||
|
raise Exception(f"Unexpected Python function output: {o}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_pipeline_with_repro_report(module,
|
||||||
|
pipeline: str,
|
||||||
|
description: str,
|
||||||
|
module_name: str):
|
||||||
|
"""Runs `pipeline` on `module`, with a nice repro report if it fails."""
|
||||||
|
try:
|
||||||
|
sys.stderr = StringIO()
|
||||||
|
asm_for_error_report = module.operation.get_asm(
|
||||||
|
large_elements_limit=10, enable_debug_info=True)
|
||||||
|
# Lower module in place to make it ready for compiler backends.
|
||||||
|
with module.context:
|
||||||
|
pm = PassManager.parse(pipeline)
|
||||||
|
pm.run(module)
|
||||||
|
except Exception as e:
|
||||||
|
# TODO: More robust.
|
||||||
|
# - don't arbitrarily clutter up /tmp. When a test suite has many
|
||||||
|
# tests, this can be a big disk cost (also, /tmp/ is frequently a
|
||||||
|
# RAM fs, which increases worries about capacity).
|
||||||
|
# - don't have colliding filenames (hard to do without cluttering
|
||||||
|
# up /tmp)
|
||||||
|
# - if we do have have colliding filenames, writes should at least
|
||||||
|
# avoid being racy.
|
||||||
|
filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
|
||||||
|
with open(filename, 'w') as f:
|
||||||
|
f.write(asm_for_error_report)
|
||||||
|
raise Exception(f"""
|
||||||
|
{description} failed with the following diagnostics:
|
||||||
|
{sys.stderr.getvalue()}
|
||||||
|
|
||||||
|
Error can be reproduced with:
|
||||||
|
$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
|
||||||
|
""") from None
|
||||||
|
finally:
|
||||||
|
sys.stderr = sys.__stderr__
|
||||||
|
|
||||||
|
|
||||||
|
def convert_torchscript_module_to_torch_backend_contract_mlir(program: torch.nn.Module):
|
||||||
|
"""Perform common lowering from TorchScript to Torch MLIR
|
||||||
|
|
||||||
|
Returns an MLIR module that satisfies the Torch backend contract.
|
||||||
|
"""
|
||||||
|
mb = ModuleBuilder()
|
||||||
|
scripted = torch.jit.script(program)
|
||||||
|
class_annotator = ClassAnnotator()
|
||||||
|
|
||||||
|
extract_annotations(program, scripted, class_annotator)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Find a way to make each of these calls own its own
|
||||||
|
# "debuggable error report" situation.
|
||||||
|
try:
|
||||||
|
sys.stderr = StringIO()
|
||||||
|
# Import the TorchScript module to MLIR
|
||||||
|
mb.import_module(scripted._c, class_annotator)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"""
|
||||||
|
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
||||||
|
Exception:
|
||||||
|
{e}
|
||||||
|
Diagnostics:
|
||||||
|
{sys.stderr.getvalue()}
|
||||||
|
""") from None
|
||||||
|
finally:
|
||||||
|
sys.stderr = sys.__stderr__
|
||||||
|
|
||||||
|
run_pipeline_with_repro_report(
|
||||||
|
mb.module,
|
||||||
|
"torchscript-module-to-torch-backend-pipeline",
|
||||||
|
"Lowering TorchScript Object Graph IR -> Torch Backend IR",
|
||||||
|
program.__class__.__name__)
|
||||||
|
|
||||||
|
return mb.module
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir.ir import Module
|
||||||
|
|
||||||
|
# A type shared between the result of `TosaBackend.compile` and the
|
||||||
|
# input to `TosaBackend.load`. Each backend will likely have a
|
||||||
|
# different definition of this type.
|
||||||
|
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||||
|
|
||||||
|
# A wrapper around a backend-specific loaded program representation
|
||||||
|
# that uniformly translates the `x.method(...)` interface expected of
|
||||||
|
# Torch modules into appropriate lower-level operations.
|
||||||
|
Invoker = TypeVar('Invoker')
|
||||||
|
|
||||||
|
|
||||||
|
class TosaBackend(abc.ABC):
|
||||||
|
"""The interface to an TOSA backend.
|
||||||
|
"""
|
||||||
|
@abc.abstractmethod
|
||||||
|
def compile(self, module: Module) -> CompiledArtifact:
|
||||||
|
"""Compile the provided MLIR module into a compiled artifact.
|
||||||
|
|
||||||
|
The module adheres to the TOSA backend contract
|
||||||
|
(see the VerifyTosaBackendContract pass).
|
||||||
|
|
||||||
|
The compiled artifact can be any type, but must be correctly
|
||||||
|
interpreted by the `load` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load(self, artifact: CompiledArtifact) -> Invoker:
|
||||||
|
"""Load the compiled artifact into a uniformly invokable form.
|
||||||
|
|
||||||
|
The compiled artifact is the result of a previous call to `compile`.
|
||||||
|
|
||||||
|
See the description of `Invoker` for the requirements on the returned
|
||||||
|
type.
|
||||||
|
"""
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
from torch_mlir.ir import *
|
||||||
|
from torch_mlir.passmanager import *
|
||||||
|
# Imported for side effects.
|
||||||
|
import torch_mlir.all_passes_registration
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||||
|
|
||||||
|
from .abc import TosaBackend
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LinalgOnTensorsTosaBackend",
|
||||||
|
]
|
||||||
|
|
||||||
|
class LinalgOnTensorsTosaBackend(TosaBackend):
|
||||||
|
"""Main entry-point for the linalg-on-tensors based TOSA backend.
|
||||||
|
|
||||||
|
This currently uses the linalg-on-tensors RefBackend for actual execution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.refbackend = RefBackendLinalgOnTensorsBackend()
|
||||||
|
|
||||||
|
def compile(self, imported_module: Module):
|
||||||
|
"""Compiles an imported module that satisfied the TOSA backend contract.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
imported_module: The MLIR module consisting of funcs in the TOSA
|
||||||
|
dialect.
|
||||||
|
Returns:
|
||||||
|
An opaque, backend specific compiled artifact object that can be
|
||||||
|
passed to `load`.
|
||||||
|
"""
|
||||||
|
# TODO: Error/repro reporting.
|
||||||
|
# We should store the program name as the symbol name of the MLIR
|
||||||
|
# module so we don't have to have access to the original program for it.
|
||||||
|
with imported_module.context:
|
||||||
|
pm = PassManager.parse("builtin.func(tosa-to-linalg-on-tensors)")
|
||||||
|
pm.run(imported_module)
|
||||||
|
return self.refbackend.compile(imported_module)
|
||||||
|
|
||||||
|
def load(self, module):
|
||||||
|
"""Loads a compiled artifact into the runtime."""
|
||||||
|
return self.refbackend.load(module)
|
|
@ -0,0 +1,12 @@
|
||||||
|
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.tanh$basic(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.tanh"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||||
|
func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
|
@ -36,7 +36,7 @@ module {
|
||||||
//
|
//
|
||||||
// The reporting we inherit from dialect conversion is not precise.
|
// The reporting we inherit from dialect conversion is not precise.
|
||||||
// For example, here we want it to explicitly call out that
|
// For example, here we want it to explicitly call out that
|
||||||
// `tensor<?x!numpy.any_dtype>` is the problem here, which suggests
|
// `!torch.tensor` is the problem here, which suggests
|
||||||
// that type inference didn't succeed, or insufficient type information
|
// that type inference didn't succeed, or insufficient type information
|
||||||
// was available.
|
// was available.
|
||||||
//
|
//
|
||||||
|
@ -46,8 +46,8 @@ module {
|
||||||
|
|
||||||
// expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
|
// expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
|
||||||
module {
|
module {
|
||||||
func @disallowed(%arg0: tensor<?x!numpy.any_dtype>) -> tensor<?x!numpy.any_dtype> {
|
func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
|
||||||
// expected-error@+1 {{failed to legalize operation 'std.return'}}
|
// expected-error@+1 {{failed to legalize operation 'std.return'}}
|
||||||
return %arg0 : tensor<?x!numpy.any_dtype>
|
return %arg0 : !torch.tensor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
// RUN: torch-mlir-opt -torch-verify-tosa-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: func @tanh
|
||||||
|
func @tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
|
%0 = "tosa.tanh"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
return %0 : tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Basic check of error reporting.
|
||||||
|
|
||||||
|
// expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
|
||||||
|
module {
|
||||||
|
func @disallowed() {
|
||||||
|
// expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}}
|
||||||
|
"unknown_dialect.unknown_op"() : () -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// TODO: Improve these errors to give more exact reporting.
|
||||||
|
//
|
||||||
|
// The reporting we inherit from dialect conversion is not precise.
|
||||||
|
// For example, here we want it to explicitly call out that
|
||||||
|
// `!torch.tensor` is the problem here, which suggests
|
||||||
|
// that type inference didn't succeed, or insufficient type information
|
||||||
|
// was available.
|
||||||
|
//
|
||||||
|
// Ultimately, the output of this pass needs to be conveyed to the user
|
||||||
|
// in an understandable way, such as suggesting a particular place where
|
||||||
|
// a shape annotation is needed.
|
||||||
|
|
||||||
|
// expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
|
||||||
|
module {
|
||||||
|
func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
|
||||||
|
// expected-error@+1 {{failed to legalize operation 'std.return'}}
|
||||||
|
return %arg0 : !torch.tensor
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue