diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index fa0bf286b..4fd228318 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -61,6 +61,7 @@ jobs: cd $GITHUB_WORKSPACE export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" python -m e2e_testing.torchscript.main --config=refbackend -v + python -m e2e_testing.torchscript.main --config=tosa -v # TODO: Only build packages in full Release mode. # On the other hand, having assertions on isn't too bad of an idea at this diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py index be7fbfcdd..0291e3ac2 100644 --- a/e2e_testing/torchscript/main.py +++ b/e2e_testing/torchscript/main.py @@ -15,12 +15,13 @@ from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY # Available test configs. from torch_mlir_e2e_test.torchscript.configs import ( - LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig + LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend +from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from .xfail_sets import XFAIL_SETS, COMMON_TORCH_MLIR_LOWERING_XFAILS +from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS # Import tests to register them in the global registry. # Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking @@ -35,7 +36,7 @@ from . import elementwise from . import reduction def _get_argparse(): - config_choices = ['native_torch', 'torchscript', 'refbackend', 'external'] + config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external'] parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser.add_argument('-c', '--config', choices=config_choices, @@ -43,6 +44,7 @@ def _get_argparse(): help=f''' Meaning of options: "refbackend": run through torch-mlir's RefBackend. +"tosa": run through torch-mlir's default TOSA backend. "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "external": use an external backend, specified by the `--external-backend` option. @@ -77,16 +79,27 @@ for more information on building these artifacts. def main(): args = _get_argparse().parse_args() + all_tests = list(GLOBAL_TEST_REGISTRY) + if args.serialized_test_dir: + for root, dirs, files in os.walk(args.serialized_test_dir): + for filename in files: + with open(os.path.join(root, filename), 'rb') as f: + all_tests.append(pickle.load(f).as_test()) + all_test_unique_names = set(test.unique_name for test in all_tests) + # Find the selected config. if args.config == 'refbackend': config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend()) - xfail_set = XFAIL_SETS['refbackend'] + xfail_set = REFBACKEND_XFAIL_SET + if args.config == 'tosa': + config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) + xfail_set = all_test_unique_names - TOSA_PASS_SET elif args.config == 'native_torch': config = NativeTorchTestConfig() - xfail_set = XFAIL_SETS['native_torch'] + xfail_set = {} elif args.config == 'torchscript': config = TorchScriptTestConfig() - xfail_set = XFAIL_SETS['torchscript'] + xfail_set = {} elif args.config == 'external': with open(args.external_config, 'r') as f: code = compile(f.read(), args.external_config, 'exec') @@ -106,13 +119,6 @@ def main(): ) sys.exit(1) - all_tests = list(GLOBAL_TEST_REGISTRY) - if args.serialized_test_dir: - for root, dirs, files in os.walk(args.serialized_test_dir): - for filename in files: - with open(os.path.join(root, filename), 'rb') as f: - all_tests.append(pickle.load(f).as_test()) - # Find the selected tests, and emit a diagnostic if none are found. tests = [ test for test in all_tests diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 5f8625f5f..c6bc24f5c 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -10,17 +10,15 @@ # (this includes down into lower parts of the stack, where a side table # might be used to keep more elaborate sets of testing configurations). -XFAIL_SETS = {} - # Lists of tests that fail to even reach the backends. # These represent further work needed in torch-mlir to lower them properly # to the backend contract. COMMON_TORCH_MLIR_LOWERING_XFAILS = { - 'QuantizedMLP_basic', + "QuantizedMLP_basic", } -XFAIL_SETS['refbackend'] = COMMON_TORCH_MLIR_LOWERING_XFAILS +REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS -XFAIL_SETS['torchscript'] = {} - -XFAIL_SETS['native_torch'] = {} +# Write the TOSA set as a "passing" set as it is very early in development +# and very few tests work yet. +TOSA_PASS_SET = {"ElementwiseUnaryModule_basic"} diff --git a/examples/lazytensor_tanh.py b/examples/lazytensor_tanh.py index c0b10a976..9210e3938 100644 --- a/examples/lazytensor_tanh.py +++ b/examples/lazytensor_tanh.py @@ -65,7 +65,7 @@ mlir_module.dump() # Compile the torch MLIR and execute the compiled program with mlir_module.context: - pm = PassManager.parse('torchscript-function-to-linalg-on-tensors-backend-pipeline') + pm = PassManager.parse('torchscript-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline') pm.run(mlir_module) print("BEFORE LINALG-ON-TENSORS BACKEND PIPELINE") diff --git a/examples/resnet_inference.ipynb b/examples/resnet_inference.ipynb index c68294851..f6ca6bf12 100644 --- a/examples/resnet_inference.ipynb +++ b/examples/resnet_inference.ipynb @@ -153,7 +153,7 @@ " mb.import_module(scripted._c, class_annotator)\n", "\n", " ## Lower the MLIR from TorchScript to RefBackend, passing through linalg-on-tensors.\n", - " pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline', mb.module.context)\n", + " pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline', mb.module.context)\n", " pm.run(mb.module)\n", "\n", " ## Invoke RefBackend to compile to compiled artifact form.\n", diff --git a/examples/torchfx_add_tanh_sigmoid.py b/examples/torchfx_add_tanh_sigmoid.py index 9c950c284..e15832713 100644 --- a/examples/torchfx_add_tanh_sigmoid.py +++ b/examples/torchfx_add_tanh_sigmoid.py @@ -51,7 +51,7 @@ mlir_module.dump() print(mlir_module.operation.verify()) with mlir_module.context: - pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline') + pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline') pm.run(mlir_module) print("\n\nLOWERED MLIR") diff --git a/examples/torchscript_resnet18_e2e.py b/examples/torchscript_resnet18_e2e.py index 38bef1368..8752be057 100644 --- a/examples/torchscript_resnet18_e2e.py +++ b/examples/torchscript_resnet18_e2e.py @@ -109,7 +109,7 @@ mb.import_module(recursivescriptmodule._c, class_annotator) backend = refbackend.RefBackendLinalgOnTensorsBackend() with mb.module.context: - pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline') + pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline') pm.run(mb.module) compiled = backend.compile(mb.module) diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 61d3ceb03..77499a33a 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -105,4 +105,13 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> { let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; } +def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> { + let summary = "Convert Torch ops to TOSA ops"; + let description = [{ + This pass assumes that TOSA ops are responsible for emitting error + guards in case of shape mismatches. + }]; + let constructor = "mlir::torch::createConvertTorchToTosaPass()"; +} + #endif // TORCHMLIR_CONVERSION_PASSES diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h new file mode 100644 index 000000000..41a53a696 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -0,0 +1,22 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H +#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { +std::unique_ptr> createConvertTorchToTosaPass(); +} +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 606a8701c..34fd64170 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -33,13 +33,13 @@ struct TorchLoweringPipelineOptions /// Creates a pipeline that lowers the object graph IR that is produced by /// TorchScript import into the form expected by torch-verify-backend-contract. -void createTorchScriptToTorchBackendPipeline( +void createTorchScriptModuleToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options); /// Creates a pipeline that lowers a flat list of funcs and global slots /// with the torch and aten dialects and mutable arrays and converts it to /// the form required by torch-verify-backend-contract. -void createGlobalizedModuleToTorchBackendPipeline( +void createTorchFunctionToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options); std::unique_ptr> createAdjustCallingConventionsPass(); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index 14c93ca21..14e589836 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -19,17 +19,15 @@ namespace mlir { namespace torch { namespace TorchConversion { -/// Creates a pipeline that lowers the object graph IR that is given by a -/// TorchScript jit.ScriptModule into the form expected by -/// torch-verify-linalg-on-tensors-verify-backend-contract. -void createTorchScriptModuleToLinalgOnTensorsBackendPipeline( +/// Creates a pipeline that lowers from the torch backend contract to the +/// linalg-on-tensors backend contract. +void createTorchBackendToLinalgOnTensorsBackendPipeline( OpPassManager &pm, const torch::Torch::TorchLoweringPipelineOptions &options); -/// Creates a pipeline that lowers the object graph IR that is given by a -/// TorchScript jit.ScriptFunction into the form expected by -/// torch-verify-linalg-on-tensors-verify-backend-contract. -void createTorchScriptFunctionToLinalgOnTensorsBackendPipeline( +/// Creates a pipeline that lowers from the torch backend contract to the +/// TOSA backend contract. +void createTorchBackendToTosaBackendPipeline( OpPassManager &pm, const torch::Torch::TorchLoweringPipelineOptions &options); @@ -44,6 +42,8 @@ createFinalizingBackendTypeConversionPass(); std::unique_ptr> createVerifyLinalgOnTensorsBackendContractPass(); +std::unique_ptr> createVerifyTosaBackendContractPass(); + } // namespace TorchConversion /// Registers all Torch transformation passes. diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 189d3a337..8b7df27bc 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -58,4 +58,9 @@ def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors- let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()"; } +def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> { + let summary = "Verifies conformity to the linalg-on-tensors backend contract"; + let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()"; +} + #endif // TORCHMLIR_TORCHCONVERSION_PASSES diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 9445701b6..9f9c6c653 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToStd) +add_subdirectory(TorchToTosa) # TODO: Automate this with add_torch_mlir_conversion_library. #get_property(torch_mlir_conversion_libs GLOBAL PROPERTY TORCH_MLIR_CONVERSION_LIBS) @@ -18,5 +19,6 @@ add_mlir_library(TorchMLIRConversionPasses TorchMLIRTorchToLinalg TorchMLIRTorchToSCF TorchMLIRTorchToStd + TorchMLIRTorchToTosa #${torch_mlir_conversion_libs} ) diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index c286fd007..f52de62d6 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -12,6 +12,7 @@ #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h" +#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" //===----------------------------------------------------------------------===// // Pass registration diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 31a8ff75c..62e697889 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -22,7 +22,6 @@ using namespace mlir; using namespace mlir::torch; -using namespace mlir::torch; using namespace mlir::torch::Torch; // ----------------------------------------------------------------------------- diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp index 4496c1502..71aad51a4 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp @@ -19,7 +19,6 @@ using namespace mlir; using namespace mlir::torch; -using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp index 271c0f101..8e3c67288 100644 --- a/lib/Conversion/TorchToStd/TorchToStd.cpp +++ b/lib/Conversion/TorchToStd/TorchToStd.cpp @@ -20,7 +20,6 @@ using namespace mlir; using namespace mlir::torch; -using namespace mlir::torch; using namespace mlir::torch::Torch; // ----------------------------------------------------------------------------- diff --git a/lib/Conversion/TorchToTosa/CMakeLists.txt b/lib/Conversion/TorchToTosa/CMakeLists.txt new file mode 100644 index 000000000..660b8c36d --- /dev/null +++ b/lib/Conversion/TorchToTosa/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_conversion_library(TorchMLIRTorchToTosa + TorchToTosa.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa + + DEPENDS + TorchMLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTosa + TorchMLIRTorchDialect +) + +torch_mlir_target_includes(TorchMLIRTorchToTosa) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp new file mode 100644 index 000000000..6c4f2fdb5 --- /dev/null +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" + +#include "../PassDetail.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { +class ConvertAtenTanhOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenTanhOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + AtenTanhOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), adaptor.self()); + return success(); + } +}; +} // namespace + +// ----------------------------------------------------------------------------- +// The pass +// ----------------------------------------------------------------------------- + +namespace { +class ConvertTorchToTosa + : public ConvertTorchToTosaBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + RewritePatternSet patterns(context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::createConvertTorchToTosaPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index c111dc459..a7891cee6 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -23,16 +23,16 @@ namespace { void mlir::torch::registerTorchPasses() { ::registerPasses(); mlir::PassPipelineRegistration( - "torchscript-to-torch-backend-pipeline", + "torchscript-module-to-torch-backend-pipeline", "Pipeline lowering TorchScript object graph IR to Torch backend form.", - mlir::torch::Torch::createTorchScriptToTorchBackendPipeline); + mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline); mlir::PassPipelineRegistration( - "torch-globalized-module-to-torch-backend-pipeline", - "Pipeline lowering a globalized Torch program to Torch backend form.", - mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline); + "torch-function-to-torch-backend-pipeline", + "Pipeline lowering a Torch function to Torch backend form.", + mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline); } -void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline( +void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options) { // When we import TorchScript IR, we import their entire "compilation unit", // which can contain numerous functions unrelated to the current program, @@ -58,10 +58,10 @@ void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline( // TODO: Improve shape inference. pm.addPass(createInlinerPass()); - createGlobalizedModuleToTorchBackendPipeline(pm, options); + createTorchFunctionToTorchBackendPipeline(pm, options); } -void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline( +void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options) { // General considerations: As a matter of bring-up, we are simultaneously // building out the frontend pipeline and also co-developing the backend diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index c5c6cf067..fe0164b08 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -21,7 +21,6 @@ using namespace mlir; using namespace mlir::torch; -using namespace mlir::torch; using namespace mlir::torch::TorchConversion; void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects( diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index 2b31bd4bb..cb43655a1 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -3,6 +3,8 @@ add_mlir_library(TorchMLIRTorchConversionPasses Passes.cpp VerifyInvariantsBeforeBackendLowering.cpp VerifyLinalgOnTensorsBackendContract.cpp + VerifyTosaBackendContract.cpp + ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index af8d7bea7..89607c93e 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -16,11 +16,11 @@ #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h" +#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; -using namespace mlir::torch; //===----------------------------------------------------------------------===// // Pass registration @@ -34,16 +34,18 @@ namespace { void mlir::torch::registerTorchConversionPasses() { ::registerPasses(); mlir::PassPipelineRegistration( - "torchscript-module-to-linalg-on-tensors-backend-pipeline", - "Pipeline lowering torch object graph representing a torch.jit.ScriptModule to linalg-on-tensors backend format.", - TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline); + "torch-backend-to-linalg-on-tensors-backend-pipeline", + "Pipeline lowering torch backend contract to linalg-on-tensors backend " + "contract.", + TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline); mlir::PassPipelineRegistration( - "torchscript-function-to-linalg-on-tensors-backend-pipeline", - "Pipeline lowering a flat list of functions representing a torch.jit.ScriptFunction to linalg-on-tensors backend format.", - TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline); + "torch-backend-to-tosa-backend-pipeline", + "Pipeline lowering torch backend contract to TOSA backend " + "contract.", + TorchConversion::createTorchBackendToTosaBackendPipeline); } -static void createTorchBackendToLinalgOnTensorsBackendPipeline( +void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { // Check some invariants to catch errors in a clear way. pm.addPass( @@ -80,20 +82,29 @@ static void createTorchBackendToLinalgOnTensorsBackendPipeline( pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()); } -void TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline( +void TorchConversion::createTorchBackendToTosaBackendPipeline( OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { + // Check some invariants to catch errors in a clear way. + pm.addPass( + TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); - // Conversion to the linalg-on-tensors backend contract starts from the Torch - // backend contract. - Torch::createTorchScriptToTorchBackendPipeline(pm, options); - createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options); -} - -void TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline( - OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { - - // Conversion to the linalg-on-tensors backend contract starts from the Torch - // backend contract. - Torch::createGlobalizedModuleToTorchBackendPipeline(pm, options); - createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options); + pm.addNestedPass(createConvertTorchToTosaPass()); + + if (options.optimize) { + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + } + + // Finish the type conversion from `torch` types to the types of the + // TOSA backend contract. + pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass( + TorchConversion::createFinalizingBackendTypeConversionPass()); + + // Verify that we have lowered to the form that TOSA backends + // expect. This fails compilation (signalPassFailure) if the IR is not in the + // correct form. + pm.addPass(TorchConversion::createVerifyTosaBackendContractPass()); } diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp new file mode 100644 index 000000000..75b03bbea --- /dev/null +++ b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::TorchConversion; + +namespace { +class VerifyTosaBackendContractPass + : public VerifyTosaBackendContractBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto module = getOperation(); + TypeConverter converter; + converter.addConversion([](TensorType type) -> Type { + if (BaseMemRefType::isValidElementType(type.getElementType())) + return type; + return nullptr; + }); + + auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); }; + + ConversionTarget target(*context); + + // Structural operations. + target.addDynamicallyLegalOp(opHasLegalTypes); + // Basic scalar operations. + target.addLegalDialect(); + + RewritePatternSet patterns(context); + if (failed(applyFullConversion(module, target, std::move(patterns)))) { + // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics + // doesn't unnecessarily spew out the entire module. + emitError(module.getLoc()) + << "Module does not conform to the TOSA backend contract. " + "See dialect conversion legality information above."; + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() { + return std::make_unique(); +} diff --git a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py index 0bfababd9..b201dc923 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py +++ b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py @@ -6,3 +6,4 @@ from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .native_torch import NativeTorchTestConfig from .torchscript import TorchScriptTestConfig +from .tosa_backend import TosaBackendTestConfig diff --git a/python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py b/python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py index 95b6083a3..cc3375f47 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py +++ b/python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py @@ -12,47 +12,15 @@ import tempfile import numpy as np import torch -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder -from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations -from torch_mlir.passmanager import PassManager from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem +from .utils import ( + recursively_convert_to_numpy, + recursively_convert_from_numpy, + convert_torchscript_module_to_torch_backend_contract_mlir, + run_pipeline_with_repro_report +) -def _recursively_convert_to_numpy(o: Any): - if isinstance(o, torch.Tensor): - return o.numpy() - if isinstance(o, tuple): - return tuple(_recursively_convert_to_numpy(x) for x in o) - if isinstance(o, list): - return [_recursively_convert_to_numpy(x) for x in o] - if isinstance(o, dict): - return {k: _recursively_convert_to_numpy(v) for k, v in o.items()} - # No-op cases. Explicitly enumerated to avoid things sneaking through. - if isinstance(o, str): - return o - if isinstance(o, float): - return o - if isinstance(o, int): - return o - raise Exception(f"Unexpected Python function input: {o}") - -def _recursively_convert_from_numpy(o: Any): - if isinstance(o, np.ndarray): - return torch.from_numpy(o) - if isinstance(o, tuple): - return tuple(_recursively_convert_from_numpy(x) for x in o) - if isinstance(o, list): - return [_recursively_convert_from_numpy(x) for x in o] - if isinstance(o, dict): - return {k: _recursively_convert_from_numpy(v) for k, v in o.items()} - # No-op cases. Explicitly enumerated to avoid things sneaking through. - if isinstance(o, str): - return o - if isinstance(o, float): - return o - if isinstance(o, int): - return o - raise Exception(f"Unexpected Python function output: {o}") class LinalgOnTensorsBackendTestConfig(TestConfig): """Base class for TestConfig's that are implemented with linalg-on-tensors. @@ -65,73 +33,28 @@ class LinalgOnTensorsBackendTestConfig(TestConfig): self.backend = backend def compile(self, program: torch.nn.Module) -> Any: - mb = ModuleBuilder() - scripted = torch.jit.script(program) - class_annotator = ClassAnnotator() - extract_annotations(program, scripted, class_annotator) + module = convert_torchscript_module_to_torch_backend_contract_mlir( + program) - # TODO: Find a way to make each of these calls own its own - # "debuggable error report" situation. - try: - sys.stderr = StringIO() - # Import the TorchScript module to MLIR - mb.import_module(scripted._c, class_annotator) - except Exception as e: - raise Exception(f""" -PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: -Exception: -{e} -Diagnostics: -{sys.stderr.getvalue()} -""") from None - finally: - sys.stderr = sys.__stderr__ + run_pipeline_with_repro_report( + module, + "torch-backend-to-linalg-on-tensors-backend-pipeline", + "Lower Torch Backend IR -> Linalg-on-Tensors Backend IR", + program.__class__.__name__) try: sys.stderr = StringIO() - asm_for_error_report = mb.module.operation.get_asm( + asm_for_error_report = module.operation.get_asm( large_elements_limit=10, enable_debug_info=True) - pipeline_str = "torchscript-module-to-linalg-on-tensors-backend-pipeline" - # Lower module in place to make it ready for compiler backends. - with mb.module.context: - pm = PassManager.parse(pipeline_str) - pm.run(mb.module) + return self.backend.compile(module) except Exception as e: - # TODO: More robust. - # - don't arbitrarily clutter up /tmp. When a test suite has many - # tests, this can be a big disk cost (also, /tmp/ is frequently a - # RAM fs, which increases worries about capacity). - # - don't have colliding filenames (hard to do without cluttering - # up /tmp) - # - if we do have have colliding filenames, writes should at least - # avoid being racy. filename = os.path.join(tempfile.gettempdir(), - scripted.original_name + '.mlir') + program.__class__.__name__ + ".mlir") with open(filename, 'w') as f: f.write(asm_for_error_report) raise Exception(f""" -torch-mlir TorchScript Object Graph IR -> linalg-on-tensors backend IR lowering failed with the following diagnostics: -{sys.stderr.getvalue()} - -Error can be reproduced with: -$ torch-mlir-opt -{pipeline_str} {filename} -""") from None - finally: - sys.stderr = sys.__stderr__ - - try: - sys.stderr = StringIO() - asm_for_error_report = mb.module.operation.get_asm( - large_elements_limit=10, enable_debug_info=True) - return self.backend.compile(mb.module) - except Exception as e: - filename = os.path.join(tempfile.gettempdir(), - scripted.original_name + '.mlir') - with open(filename, 'w') as f: - f.write(asm_for_error_report) - raise Exception(f""" -torch-mlir linalg-on-tensors Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics: +Linalg-on-Tensors Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics: ## Exception: {e} @@ -148,9 +71,9 @@ torch-mlir linalg-on-tensors Backend lowering for {self.backend.__class__.__name backend_module = self.backend.load(artifact) result: Trace = [] for item in trace: - numpy_inputs = _recursively_convert_to_numpy(item.inputs) + numpy_inputs = recursively_convert_to_numpy(item.inputs) outputs = getattr(backend_module, item.symbol)(*numpy_inputs) - output = _recursively_convert_from_numpy(outputs) + output = recursively_convert_from_numpy(outputs) result.append( TraceItem(symbol=item.symbol, inputs=item.inputs, diff --git a/python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py b/python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py new file mode 100644 index 000000000..d4c57e9e8 --- /dev/null +++ b/python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py @@ -0,0 +1,81 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import sys +from typing import Any +from io import StringIO +import os +import tempfile + +import numpy as np +import torch + +from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend +from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem +from .utils import ( + recursively_convert_to_numpy, + recursively_convert_from_numpy, + convert_torchscript_module_to_torch_backend_contract_mlir, + run_pipeline_with_repro_report +) + + +class TosaBackendTestConfig(TestConfig): + """Base class for TestConfig's that are implemented with linalg-on-tensors. + + This class handles all the common lowering that torch-mlir does before + reaching the linalg-on-tensors abstraction level. + """ + def __init__(self, backend: TosaBackend): + super().__init__() + self.backend = backend + + def compile(self, program: torch.nn.Module) -> Any: + + module = convert_torchscript_module_to_torch_backend_contract_mlir( + program) + + run_pipeline_with_repro_report( + module, + "torch-backend-to-tosa-backend-pipeline", + "Lower Torch Backend IR -> TOSA Backend IR", + program.__class__.__name__) + + try: + sys.stderr = StringIO() + asm_for_error_report = module.operation.get_asm( + large_elements_limit=10, enable_debug_info=True) + return self.backend.compile(module) + except Exception as e: + filename = os.path.join(tempfile.gettempdir(), + program.__class__.__name__ + ".mlir") + with open(filename, 'w') as f: + f.write(asm_for_error_report) + raise Exception(f""" +TOSA Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics: +## Exception: +{e} + +## Stderr: +{sys.stderr.getvalue()} + +## Input IR has been saved in {filename} +""") from None + finally: + sys.stderr = sys.__stderr__ + + + def run(self, artifact: Any, trace: Trace) -> Trace: + backend_module = self.backend.load(artifact) + result: Trace = [] + for item in trace: + numpy_inputs = recursively_convert_to_numpy(item.inputs) + outputs = getattr(backend_module, item.symbol)(*numpy_inputs) + output = recursively_convert_from_numpy(outputs) + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + output=output)) + return result diff --git a/python/torch_mlir_e2e_test/torchscript/configs/utils.py b/python/torch_mlir_e2e_test/torchscript/configs/utils.py new file mode 100644 index 000000000..1a7237528 --- /dev/null +++ b/python/torch_mlir_e2e_test/torchscript/configs/utils.py @@ -0,0 +1,129 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import sys +from typing import Any +from io import StringIO +import os +import tempfile + +import numpy as np +import torch + +from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder +from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations +from torch_mlir.passmanager import PassManager +from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend +from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem + +def recursively_convert_to_numpy(o: Any): + if isinstance(o, torch.Tensor): + return o.numpy() + if isinstance(o, tuple): + return tuple(recursively_convert_to_numpy(x) for x in o) + if isinstance(o, list): + return [recursively_convert_to_numpy(x) for x in o] + if isinstance(o, dict): + return {k: recursively_convert_to_numpy(v) for k, v in o.items()} + # No-op cases. Explicitly enumerated to avoid things sneaking through. + if isinstance(o, str): + return o + if isinstance(o, float): + return o + if isinstance(o, int): + return o + raise Exception(f"Unexpected Python function input: {o}") + +def recursively_convert_from_numpy(o: Any): + if isinstance(o, np.ndarray): + return torch.from_numpy(o) + if isinstance(o, tuple): + return tuple(recursively_convert_from_numpy(x) for x in o) + if isinstance(o, list): + return [recursively_convert_from_numpy(x) for x in o] + if isinstance(o, dict): + return {k: recursively_convert_from_numpy(v) for k, v in o.items()} + # No-op cases. Explicitly enumerated to avoid things sneaking through. + if isinstance(o, str): + return o + if isinstance(o, float): + return o + if isinstance(o, int): + return o + raise Exception(f"Unexpected Python function output: {o}") + + +def run_pipeline_with_repro_report(module, + pipeline: str, + description: str, + module_name: str): + """Runs `pipeline` on `module`, with a nice repro report if it fails.""" + try: + sys.stderr = StringIO() + asm_for_error_report = module.operation.get_asm( + large_elements_limit=10, enable_debug_info=True) + # Lower module in place to make it ready for compiler backends. + with module.context: + pm = PassManager.parse(pipeline) + pm.run(module) + except Exception as e: + # TODO: More robust. + # - don't arbitrarily clutter up /tmp. When a test suite has many + # tests, this can be a big disk cost (also, /tmp/ is frequently a + # RAM fs, which increases worries about capacity). + # - don't have colliding filenames (hard to do without cluttering + # up /tmp) + # - if we do have have colliding filenames, writes should at least + # avoid being racy. + filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir") + with open(filename, 'w') as f: + f.write(asm_for_error_report) + raise Exception(f""" +{description} failed with the following diagnostics: +{sys.stderr.getvalue()} + +Error can be reproduced with: +$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename} +""") from None + finally: + sys.stderr = sys.__stderr__ + + +def convert_torchscript_module_to_torch_backend_contract_mlir(program: torch.nn.Module): + """Perform common lowering from TorchScript to Torch MLIR + + Returns an MLIR module that satisfies the Torch backend contract. + """ + mb = ModuleBuilder() + scripted = torch.jit.script(program) + class_annotator = ClassAnnotator() + + extract_annotations(program, scripted, class_annotator) + + + # TODO: Find a way to make each of these calls own its own + # "debuggable error report" situation. + try: + sys.stderr = StringIO() + # Import the TorchScript module to MLIR + mb.import_module(scripted._c, class_annotator) + except Exception as e: + raise Exception(f""" +PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: +Exception: +{e} +Diagnostics: +{sys.stderr.getvalue()} +""") from None + finally: + sys.stderr = sys.__stderr__ + + run_pipeline_with_repro_report( + mb.module, + "torchscript-module-to-torch-backend-pipeline", + "Lowering TorchScript Object Graph IR -> Torch Backend IR", + program.__class__.__name__) + + return mb.module diff --git a/python/torch_mlir_e2e_test/tosa_backends/__init__.py b/python/torch_mlir_e2e_test/tosa_backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/torch_mlir_e2e_test/tosa_backends/abc.py b/python/torch_mlir_e2e_test/tosa_backends/abc.py new file mode 100644 index 000000000..8d1e251dc --- /dev/null +++ b/python/torch_mlir_e2e_test/tosa_backends/abc.py @@ -0,0 +1,46 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import abc +from typing import TypeVar + +import torch + +from torch_mlir.ir import Module + +# A type shared between the result of `TosaBackend.compile` and the +# input to `TosaBackend.load`. Each backend will likely have a +# different definition of this type. +CompiledArtifact = TypeVar('CompiledArtifact') + +# A wrapper around a backend-specific loaded program representation +# that uniformly translates the `x.method(...)` interface expected of +# Torch modules into appropriate lower-level operations. +Invoker = TypeVar('Invoker') + + +class TosaBackend(abc.ABC): + """The interface to an TOSA backend. + """ + @abc.abstractmethod + def compile(self, module: Module) -> CompiledArtifact: + """Compile the provided MLIR module into a compiled artifact. + + The module adheres to the TOSA backend contract + (see the VerifyTosaBackendContract pass). + + The compiled artifact can be any type, but must be correctly + interpreted by the `load` method. + """ + + @abc.abstractmethod + def load(self, artifact: CompiledArtifact) -> Invoker: + """Load the compiled artifact into a uniformly invokable form. + + The compiled artifact is the result of a previous call to `compile`. + + See the description of `Invoker` for the requirements on the returned + type. + """ diff --git a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py new file mode 100644 index 000000000..493c7c457 --- /dev/null +++ b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py @@ -0,0 +1,48 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from torch_mlir.ir import * +from torch_mlir.passmanager import * +# Imported for side effects. +import torch_mlir.all_passes_registration + +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend + +from .abc import TosaBackend + +__all__ = [ + "LinalgOnTensorsTosaBackend", +] + +class LinalgOnTensorsTosaBackend(TosaBackend): + """Main entry-point for the linalg-on-tensors based TOSA backend. + + This currently uses the linalg-on-tensors RefBackend for actual execution. + """ + def __init__(self): + super().__init__() + self.refbackend = RefBackendLinalgOnTensorsBackend() + + def compile(self, imported_module: Module): + """Compiles an imported module that satisfied the TOSA backend contract. + + Args: + imported_module: The MLIR module consisting of funcs in the TOSA + dialect. + Returns: + An opaque, backend specific compiled artifact object that can be + passed to `load`. + """ + # TODO: Error/repro reporting. + # We should store the program name as the symbol name of the MLIR + # module so we don't have to have access to the original program for it. + with imported_module.context: + pm = PassManager.parse("builtin.func(tosa-to-linalg-on-tensors)") + pm.run(imported_module) + return self.refbackend.compile(imported_module) + + def load(self, module): + """Loads a compiled artifact into the runtime.""" + return self.refbackend.load(module) diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir new file mode 100644 index 000000000..25e420dcb --- /dev/null +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -0,0 +1,12 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @torch.aten.tanh$basic( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.tanh"(%[[ARG_BUILTIN]]) : (tensor) -> tensor +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> +func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} diff --git a/test/Dialect/TorchConversion/verify-linalg-on-tensors-backend-contract.mlir b/test/Dialect/TorchConversion/verify-linalg-on-tensors-backend-contract.mlir index 2cae004c7..b5254e9c5 100644 --- a/test/Dialect/TorchConversion/verify-linalg-on-tensors-backend-contract.mlir +++ b/test/Dialect/TorchConversion/verify-linalg-on-tensors-backend-contract.mlir @@ -36,7 +36,7 @@ module { // // The reporting we inherit from dialect conversion is not precise. // For example, here we want it to explicitly call out that -// `tensor` is the problem here, which suggests +// `!torch.tensor` is the problem here, which suggests // that type inference didn't succeed, or insufficient type information // was available. // @@ -46,8 +46,8 @@ module { // expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}} module { - func @disallowed(%arg0: tensor) -> tensor { + func @disallowed(%arg0: !torch.tensor) -> !torch.tensor { // expected-error@+1 {{failed to legalize operation 'std.return'}} - return %arg0 : tensor + return %arg0 : !torch.tensor } } diff --git a/test/Dialect/TorchConversion/verify-tosa-backend-contract.mlir b/test/Dialect/TorchConversion/verify-tosa-backend-contract.mlir new file mode 100644 index 000000000..93198dd38 --- /dev/null +++ b/test/Dialect/TorchConversion/verify-tosa-backend-contract.mlir @@ -0,0 +1,42 @@ +// RUN: torch-mlir-opt -torch-verify-tosa-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s + +// CHECK: func @tanh +func @tanh(%arg0: tensor) -> tensor { + %0 = "tosa.tanh"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// ----- + +// Basic check of error reporting. + +// expected-error@+1 {{Module does not conform to the TOSA backend contract.}} +module { + func @disallowed() { + // expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}} + "unknown_dialect.unknown_op"() : () -> () + return + } +} + +// ----- + +// TODO: Improve these errors to give more exact reporting. +// +// The reporting we inherit from dialect conversion is not precise. +// For example, here we want it to explicitly call out that +// `!torch.tensor` is the problem here, which suggests +// that type inference didn't succeed, or insufficient type information +// was available. +// +// Ultimately, the output of this pass needs to be conveyed to the user +// in an understandable way, such as suggesting a particular place where +// a shape annotation is needed. + +// expected-error@+1 {{Module does not conform to the TOSA backend contract.}} +module { + func @disallowed(%arg0: !torch.tensor) -> !torch.tensor { + // expected-error@+1 {{failed to legalize operation 'std.return'}} + return %arg0 : !torch.tensor + } +}