From 1251c186b553003709c28ce64e7d2d9f808f39e1 Mon Sep 17 00:00:00 2001 From: Suraj Sudhir Date: Tue, 23 Nov 2021 22:25:59 -0800 Subject: [PATCH] [tosa] Add TosaMakeBroadcastable pass to torch-to-tosa pipeline. Fixes broken e2e test ElementwiseAddModule_basic Signed-off-by: Suraj Sudhir --- e2e_testing/torchscript/xfail_sets.py | 1 + lib/Dialect/TorchConversion/Transforms/Passes.cpp | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index f0009a509..3dde793e5 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -29,6 +29,7 @@ TOSA_PASS_SET = { "ElementwiseFloorModule_basic", "ElementwiseLogModule_basic", "TanhBackward_basic", + "ElementwiseAddModule_basic", "ReturnThreeTensorFloat32_basic", "AddCMulModule_basic", "AddCDivModule_basic", diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 89607c93e..ba513e6ff 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" @@ -21,6 +22,7 @@ using namespace mlir; using namespace mlir::torch; +using namespace mlir::tosa; //===----------------------------------------------------------------------===// // Pass registration @@ -89,6 +91,8 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline( TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); pm.addNestedPass(createConvertTorchToTosaPass()); + // Perform rank broadcasting so TosaToLinalg pass works + pm.addNestedPass(createTosaMakeBroadcastablePass()); if (options.optimize) { // Clean up any non-canonical code introduced above..