diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index f0009a509..3dde793e5 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -29,6 +29,7 @@ TOSA_PASS_SET = { "ElementwiseFloorModule_basic", "ElementwiseLogModule_basic", "TanhBackward_basic", + "ElementwiseAddModule_basic", "ReturnThreeTensorFloat32_basic", "AddCMulModule_basic", "AddCDivModule_basic", diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 89607c93e..ba513e6ff 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" @@ -21,6 +22,7 @@ using namespace mlir; using namespace mlir::torch; +using namespace mlir::tosa; //===----------------------------------------------------------------------===// // Pass registration @@ -89,6 +91,8 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline( TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); pm.addNestedPass(createConvertTorchToTosaPass()); + // Perform rank broadcasting so TosaToLinalg pass works + pm.addNestedPass(createTosaMakeBroadcastablePass()); if (options.optimize) { // Clean up any non-canonical code introduced above..