diff --git a/frontends/pytorch/csrc/builder/ivalue_importer.cpp b/frontends/pytorch/csrc/builder/ivalue_importer.cpp index b2541a855..80fe5f7ee 100644 --- a/frontends/pytorch/csrc/builder/ivalue_importer.cpp +++ b/frontends/pytorch/csrc/builder/ivalue_importer.cpp @@ -251,6 +251,11 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { if (ivalue.isModule()) { return importModule(ivalue.toModule()); } + if (ivalue.isNone()) { + MlirOperation operation = createMlirOperationAtEnd( + importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context)); + return mlirOperationGetResult(operation, 0); + } std::stringstream msg; msg << "Unsupported ivalue: " << ivalue; throw std::invalid_argument(msg.str()); diff --git a/frontends/pytorch/test/module_import/submodules-select.py b/frontends/pytorch/test/module_import/submodules-select.py index 664f50dc5..c6ff7ee1c 100644 --- a/frontends/pytorch/test/module_import/submodules-select.py +++ b/frontends/pytorch/test/module_import/submodules-select.py @@ -27,7 +27,7 @@ class TestModule(torch.nn.Module): # Modules with the same class can be selected between. # CHECK: %[[MOD:.*]] = scf.if s = self.s1 if b else self.s2 - # CHECK: %[[N:.*]] = torch.prim.GetAttr %4["n"] + # CHECK: %[[N:.*]] = torch.prim.GetAttr %5["n"] # CHECK: return %[[N]] return s.n