Add NoneType support for ivalue_importer

PyTorch added a Global variable `_is_full_backward_hook` recently.

See https://github.com/pytorch/pytorch/pull/46163

Signed-off-by: Bairen Yi <yibairen.byron@bytedance.com>
pull/161/head
Bairen Yi 2021-02-18 21:56:15 +08:00 committed by Sean Silva
parent a38b7b72b2
commit 99d1db18d2
2 changed files with 6 additions and 1 deletions

View File

@ -251,6 +251,11 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
if (ivalue.isModule()) {
return importModule(ivalue.toModule());
}
if (ivalue.isNone()) {
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context));
return mlirOperationGetResult(operation, 0);
}
std::stringstream msg;
msg << "Unsupported ivalue: " << ivalue;
throw std::invalid_argument(msg.str());

View File

@ -27,7 +27,7 @@ class TestModule(torch.nn.Module):
# Modules with the same class can be selected between.
# CHECK: %[[MOD:.*]] = scf.if
s = self.s1 if b else self.s2
# CHECK: %[[N:.*]] = torch.prim.GetAttr %4["n"]
# CHECK: %[[N:.*]] = torch.prim.GetAttr %5["n"]
# CHECK: return %[[N]]
return s.n