mirror of https://github.com/llvm/torch-mlir
Handle rank-0 annotations properly.
parent
145d4ae23c
commit
49b5b7272b
|
@ -508,12 +508,18 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
|||
MlirType dtype = TypeMapper(context).mapFromTorchScalarType(
|
||||
mlirLocationUnknownGet(context), *maybeDtype);
|
||||
MlirType typeBound;
|
||||
// `std::vector`'s `.data()` method can return nullptr when the
|
||||
// size is 0. This triggers the "nothing known about sizes" case in
|
||||
// the C API constructor, when we want the "we know we have 0 sizes"
|
||||
// case. So use a dummy data pointer.
|
||||
int64_t dummy;
|
||||
int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
|
||||
if (hasValueSemantics) {
|
||||
typeBound = npcompTorchValueTensorTypeGet(context, shape.size(),
|
||||
shape.data(), dtype);
|
||||
shapeData, dtype);
|
||||
} else {
|
||||
typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
|
||||
shape.data(), dtype);
|
||||
shapeData, dtype);
|
||||
}
|
||||
|
||||
MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
|
||||
|
|
|
@ -14,7 +14,7 @@ mb = torch_mlir.ModuleBuilder()
|
|||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def forward(self, tensor):
|
||||
def forward(self, a, b):
|
||||
return
|
||||
|
||||
test_module = TestModule()
|
||||
|
@ -24,11 +24,13 @@ annotator = torch_mlir.ClassAnnotator()
|
|||
class_type = recursivescriptmodule._c._type()
|
||||
# CHECK: func private @__torch__.TestModule.forward(
|
||||
# CHECK-SAME: %arg0: !torch.nn.Module<"__torch__.TestModule">,
|
||||
# CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>}
|
||||
# CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>},
|
||||
# CHECK-SAME: %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[],f32>}
|
||||
# CHECK-SAME: ) -> !torch.none
|
||||
annotator.annotateArgs(class_type, ['forward'], [
|
||||
None,
|
||||
((-1, 1024), torch.int8, True),
|
||||
((), torch.float, True),
|
||||
])
|
||||
|
||||
# # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||
|
|
Loading…
Reference in New Issue