[tosa] Add TosaMakeBroadcastable pass to torch-to-tosa pipeline.

Fixes broken e2e test ElementwiseAddModule_basic

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
pull/451/head
Suraj Sudhir 2021-11-23 22:25:59 -08:00 committed by Sean Silva
parent e6675a50d3
commit 1251c186b5
2 changed files with 5 additions and 0 deletions

View File

@ -29,6 +29,7 @@ TOSA_PASS_SET = {
"ElementwiseFloorModule_basic", "ElementwiseFloorModule_basic",
"ElementwiseLogModule_basic", "ElementwiseLogModule_basic",
"TanhBackward_basic", "TanhBackward_basic",
"ElementwiseAddModule_basic",
"ReturnThreeTensorFloat32_basic", "ReturnThreeTensorFloat32_basic",
"AddCMulModule_basic", "AddCMulModule_basic",
"AddCDivModule_basic", "AddCDivModule_basic",

View File

@ -11,6 +11,7 @@
#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
@ -21,6 +22,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::tosa;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pass registration // Pass registration
@ -89,6 +91,8 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
pm.addNestedPass<FuncOp>(createConvertTorchToTosaPass()); pm.addNestedPass<FuncOp>(createConvertTorchToTosaPass());
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
if (options.optimize) { if (options.optimize) {
// Clean up any non-canonical code introduced above.. // Clean up any non-canonical code introduced above..