From 49b5b7272bbdc46801826714e57dec5f984fd722 Mon Sep 17 00:00:00 2001
From: Sean Silva
Date: Wed, 23 Jun 2021 10:24:53 -0700
Subject: [PATCH] Handle rank-0 annotations properly.

---
 frontends/pytorch/csrc/builder/ivalue_importer.cpp      | 10 ++++++++--
 .../ivalue_import/annotations/arg-tensor-type-bound.py  |  6 ++++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/frontends/pytorch/csrc/builder/ivalue_importer.cpp b/frontends/pytorch/csrc/builder/ivalue_importer.cpp
index f07a4f507..8b87ed873 100644
--- a/frontends/pytorch/csrc/builder/ivalue_importer.cpp
+++ b/frontends/pytorch/csrc/builder/ivalue_importer.cpp
@@ -508,12 +508,18 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
       MlirType dtype = TypeMapper(context).mapFromTorchScalarType(
           mlirLocationUnknownGet(context), *maybeDtype);
       MlirType typeBound;
+      // `std::vector`'s `.data()` method can return nullptr when the
+      // size is 0. This triggers the "nothing known about sizes" case in
+      // the C API constructor, when we want the "we know we have 0 sizes"
+      // case. So use a dummy data pointer.
+      int64_t dummy;
+      int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
       if (hasValueSemantics) {
         typeBound = npcompTorchValueTensorTypeGet(context, shape.size(),
-                                                  shape.data(), dtype);
+                                                  shapeData, dtype);
       } else {
         typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
-                                                     shape.data(), dtype);
+                                                     shapeData, dtype);
       }
 
       MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
diff --git a/frontends/pytorch/test/ivalue_import/annotations/arg-tensor-type-bound.py b/frontends/pytorch/test/ivalue_import/annotations/arg-tensor-type-bound.py
index c883fbd86..1b46128ec 100644
--- a/frontends/pytorch/test/ivalue_import/annotations/arg-tensor-type-bound.py
+++ b/frontends/pytorch/test/ivalue_import/annotations/arg-tensor-type-bound.py
@@ -14,7 +14,7 @@ mb = torch_mlir.ModuleBuilder()
 class TestModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
-    def forward(self, tensor):
+    def forward(self, a, b):
         return
 
 test_module = TestModule()
@@ -24,11 +24,13 @@ annotator = torch_mlir.ClassAnnotator()
 class_type = recursivescriptmodule._c._type()
 # CHECK: func private @__torch__.TestModule.forward(
 # CHECK-SAME: %arg0: !torch.nn.Module<"__torch__.TestModule">,
-# CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>}
+# CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>},
+# CHECK-SAME: %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[],f32>}
 # CHECK-SAME: ) -> !torch.none
 annotator.annotateArgs(class_type, ['forward'], [
     None,
     ((-1, 1024), torch.int8, True),
+    ((), torch.float, True),
 ])
 
 # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
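
Note: the C++ standard permits `std::vector::data()` to return nullptr when
the vector is empty, and the C API constructors
(`npcompTorchValueTensorTypeGet` / `npcompTorchNonValueTensorTypeGet`) take a
null sizes pointer to mean "nothing known about sizes", so a rank-0 type
bound would be imported as if its rank were unknown. Below is a minimal
standalone sketch of the gotcha and the dummy-pointer workaround from this
patch; `isRanked` is a hypothetical stand-in for the real constructors, for
illustration only:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for a C API constructor that treats a null
    // `sizes` pointer as "nothing known about sizes" and a non-null
    // pointer as "ranked with `rank` sizes" (including rank == 0).
    static bool isRanked(std::intptr_t rank, const std::int64_t *sizes) {
      (void)rank;
      return sizes != nullptr;
    }

    int main() {
      std::vector<std::int64_t> shape; // rank-0: zero sizes, rank known
      // `shape.data()` may legally be nullptr here, which would be
      // misread as "rank unknown". Substitute a dummy non-null pointer.
      std::int64_t dummy;
      std::int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
      assert(isRanked(shape.size(), shapeData)); // ranked, with rank 0
      return 0;
    }

An alternative would be an explicit "sizes are known" flag in the C API
signatures, but substituting a dummy pointer at this one call site is the
smaller change.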