diff --git a/frontends/pytorch/csrc/builder/ivalue_importer.cpp b/frontends/pytorch/csrc/builder/ivalue_importer.cpp index a71f95f85..7f9ad4c2e 100644 --- a/frontends/pytorch/csrc/builder/ivalue_importer.cpp +++ b/frontends/pytorch/csrc/builder/ivalue_importer.cpp @@ -18,6 +18,8 @@ #include "mlir-c/Diagnostics.h" #include "npcomp-c/Types.h" +#include "caffe2/core/scope_guard.h" + using namespace torch_mlir; // Hashing functionality for IValue's. @@ -146,7 +148,8 @@ private: }; } // namespace -MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { +MlirValue +IValueImporter::importModule(torch::jit::Module currentModule) { // TODO: Can we do better? MlirLocation loc = mlirLocationUnknownGet(context); @@ -170,6 +173,10 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0); mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr)); MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion); + auto inserter = caffe2::MakeGuard([&]() { + mlirBlockInsertOwnedOperationBefore( + importBlock, mlirBlockGetTerminator(importBlock), nnModule); + }); if (!rootModuleName.has_value()) { rootModuleName = moduleTypeName; @@ -198,8 +205,6 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { } createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc); - mlirBlockInsertOwnedOperationBefore( - importBlock, mlirBlockGetTerminator(importBlock), nnModule); return mlirOperationGetResult(nnModule, 0); } diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/ref_backend.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/ref_backend.py index 12745b64b..9e19d5f48 100644 --- a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/ref_backend.py +++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/ref_backend.py @@ -39,7 +39,10 @@ class RefBackendTestConfig(TestConfig): mb.import_module(scripted._c, class_annotator) except Exception as e: raise Exception(f""" -PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with the following diagnostics: +PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with: +Exception: +{e} +Diagnostics: {sys.stderr.getvalue()} """) from None finally: