mirror of https://github.com/llvm/torch-mlir
[tosa] Add TosaMakeBroadcastable pass to torch-to-tosa pipeline.
Fixes broken e2e test ElementwiseAddModule_basic Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>pull/451/head
parent
e6675a50d3
commit
1251c186b5
|
@ -29,6 +29,7 @@ TOSA_PASS_SET = {
|
||||||
"ElementwiseFloorModule_basic",
|
"ElementwiseFloorModule_basic",
|
||||||
"ElementwiseLogModule_basic",
|
"ElementwiseLogModule_basic",
|
||||||
"TanhBackward_basic",
|
"TanhBackward_basic",
|
||||||
|
"ElementwiseAddModule_basic",
|
||||||
"ReturnThreeTensorFloat32_basic",
|
"ReturnThreeTensorFloat32_basic",
|
||||||
"AddCMulModule_basic",
|
"AddCMulModule_basic",
|
||||||
"AddCDivModule_basic",
|
"AddCDivModule_basic",
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
#include "mlir/Dialect/Linalg/Passes.h"
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||||
|
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
|
@ -21,6 +22,7 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::tosa;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pass registration
|
// Pass registration
|
||||||
|
@ -89,6 +91,8 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
||||||
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
|
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
|
||||||
|
|
||||||
pm.addNestedPass<FuncOp>(createConvertTorchToTosaPass());
|
pm.addNestedPass<FuncOp>(createConvertTorchToTosaPass());
|
||||||
|
// Perform rank broadcasting so TosaToLinalg pass works
|
||||||
|
pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
|
||||||
|
|
||||||
if (options.optimize) {
|
if (options.optimize) {
|
||||||
// Clean up any non-canonical code introduced above..
|
// Clean up any non-canonical code introduced above..
|
||||||
|
|
Loading…
Reference in New Issue