mirror of https://github.com/llvm/torch-mlir
Add NoneType support for ivalue_importer
PyTorch added a Global variable `_is_full_backward_hook` recently. See https://github.com/pytorch/pytorch/pull/46163 Signed-off-by: Bairen Yi <yibairen.byron@bytedance.com>pull/161/head
parent
a38b7b72b2
commit
99d1db18d2
|
@ -251,6 +251,11 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
if (ivalue.isModule()) {
|
||||
return importModule(ivalue.toModule());
|
||||
}
|
||||
if (ivalue.isNone()) {
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context));
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
std::stringstream msg;
|
||||
msg << "Unsupported ivalue: " << ivalue;
|
||||
throw std::invalid_argument(msg.str());
|
||||
|
|
|
@ -27,7 +27,7 @@ class TestModule(torch.nn.Module):
|
|||
# Modules with the same class can be selected between.
|
||||
# CHECK: %[[MOD:.*]] = scf.if
|
||||
s = self.s1 if b else self.s2
|
||||
# CHECK: %[[N:.*]] = torch.prim.GetAttr %4["n"]
|
||||
# CHECK: %[[N:.*]] = torch.prim.GetAttr %5["n"]
|
||||
# CHECK: return %[[N]]
|
||||
return s.n
|
||||
|
||||
|
|
Loading…
Reference in New Issue