Handle rank-0 annotations properly.

pull/241/head
Sean Silva 2021-06-23 10:24:53 -07:00
parent 145d4ae23c
commit 49b5b7272b
2 changed files with 12 additions and 4 deletions


@@ -508,12 +508,18 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
     MlirType dtype = TypeMapper(context).mapFromTorchScalarType(
         mlirLocationUnknownGet(context), *maybeDtype);
     MlirType typeBound;
+    // `std::vector`'s `.data()` method can return nullptr when the
+    // size is 0. This triggers the "nothing known about sizes" case in
+    // the C API constructor, when we want the "we know we have 0 sizes"
+    // case. So use a dummy data pointer.
+    int64_t dummy;
+    int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
     if (hasValueSemantics) {
       typeBound = npcompTorchValueTensorTypeGet(context, shape.size(),
-                                                shape.data(), dtype);
+                                                shapeData, dtype);
     } else {
       typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
-                                                   shape.data(), dtype);
+                                                   shapeData, dtype);
     }
     MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
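
The added comment captures the subtlety being fixed: for an empty `std::vector`, `.data()` may legally return nullptr, and the C API constructor reads a null sizes pointer as "nothing known about sizes" rather than "we know we have 0 sizes". Below is a minimal, self-contained C++ sketch of the same dummy-pointer pattern; `makeTensorType` is a hypothetical stand-in, not the real `npcompTorchValueTensorTypeGet` / `npcompTorchNonValueTensorTypeGet` signature.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a C-API-style type constructor: a null `sizes`
// pointer means "nothing known about sizes"; a non-null pointer with
// numSizes == 0 means "known rank 0". Not the real npcomp API.
static void makeTensorType(intptr_t numSizes, const int64_t *sizes) {
  if (!sizes)
    std::printf("unranked (sizes unknown)\n");
  else
    std::printf("ranked, rank = %lld\n", static_cast<long long>(numSizes));
}

int main() {
  std::vector<int64_t> shape; // a rank-0 annotation: shape is ()
  // For an empty vector, .data() may return nullptr, which the constructor
  // would misread as "unranked". A dummy, never-dereferenced pointer keeps
  // the "known rank 0" meaning intact.
  int64_t dummy;
  const int64_t *shapeData = shape.empty() ? &dummy : shape.data();
  makeTensorType(static_cast<intptr_t>(shape.size()), shapeData);
  return 0;
}

For a non-empty shape the substitution is a no-op (`.data()` is passed through unchanged), so only the rank-0 edge case changes behavior.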


@@ -14,7 +14,7 @@ mb = torch_mlir.ModuleBuilder()
 class TestModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
-    def forward(self, tensor):
+    def forward(self, a, b):
         return
 test_module = TestModule()
@@ -24,11 +24,13 @@ annotator = torch_mlir.ClassAnnotator()
 class_type = recursivescriptmodule._c._type()
 # CHECK: func private @__torch__.TestModule.forward(
 # CHECK-SAME: %arg0: !torch.nn.Module<"__torch__.TestModule">,
-# CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>}
+# CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>},
+# CHECK-SAME: %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[],f32>}
 # CHECK-SAME: ) -> !torch.none
 annotator.annotateArgs(class_type, ['forward'], [
     None,
     ((-1, 1024), torch.int8, True),
+    ((), torch.float, True),
 ])
 # # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.