mirror of https://github.com/llvm/torch-mlir
Handle rank-0 annotations properly.
parent 145d4ae23c
commit 49b5b7272b
@@ -508,12 +508,18 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
       MlirType dtype = TypeMapper(context).mapFromTorchScalarType(
           mlirLocationUnknownGet(context), *maybeDtype);
       MlirType typeBound;
+      // `std::vector`'s `.data()` method can return nullptr when the
+      // size is 0. This triggers the "nothing known about sizes" case in
+      // the C API constructor, when we want the "we know we have 0 sizes"
+      // case. So use a dummy data pointer.
+      int64_t dummy;
+      int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
       if (hasValueSemantics) {
         typeBound = npcompTorchValueTensorTypeGet(context, shape.size(),
-                                                  shape.data(), dtype);
+                                                  shapeData, dtype);
       } else {
         typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
-                                                     shape.data(), dtype);
+                                                     shapeData, dtype);
       }

       MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
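Note on the fix: whether `std::vector::data()` returns nullptr for an empty vector is implementation-defined, so a rank-0 shape could be misread as "nothing known about sizes". Below is a minimal standalone C++ sketch of the hazard and the dummy-pointer workaround; `describeShape` is a hypothetical stand-in for the C API constructor's null-versus-non-null `sizes` convention, not a real npcomp function.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the C API convention: a null `sizes`
// pointer means "nothing known about sizes"; a non-null pointer
// with rank 0 means "we know we have 0 sizes".
static void describeShape(intptr_t rank, const int64_t *sizes) {
  if (!sizes)
    std::printf("unranked (nothing known about sizes)\n");
  else
    std::printf("ranked, rank %lld\n", static_cast<long long>(rank));
}

int main() {
  std::vector<int64_t> shape; // a rank-0 annotation: size() == 0
  // shape.data() may be nullptr here, which describeShape would
  // misread as "unranked". Substitute a dummy pointer so rank 0
  // survives the call. The dummy is never dereferenced when rank
  // is 0; it only needs to be non-null.
  int64_t dummy;
  int64_t *shapeData = shape.empty() ? &dummy : shape.data();
  describeShape(shape.size(), shape.data()); // possibly "unranked"
  describeShape(shape.size(), shapeData);    // "ranked, rank 0"
}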
@@ -14,7 +14,7 @@ mb = torch_mlir.ModuleBuilder()
 class TestModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
-    def forward(self, tensor):
+    def forward(self, a, b):
         return

 test_module = TestModule()
@@ -24,11 +24,13 @@ annotator = torch_mlir.ClassAnnotator()
 class_type = recursivescriptmodule._c._type()
 # CHECK: func private @__torch__.TestModule.forward(
 # CHECK-SAME: %arg0: !torch.nn.Module<"__torch__.TestModule">,
-# CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>}
+# CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>},
+# CHECK-SAME: %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[],f32>}
 # CHECK-SAME: ) -> !torch.none
 annotator.annotateArgs(class_type, ['forward'], [
     None,
     ((-1, 1024), torch.int8, True),
+    ((), torch.float, True),
 ])

 # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
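For the new rank-0 annotation, the shape tuple `()` reaches the importer as an empty `std::vector<int64_t>`, so the type bound is built through the new `shapeData` path. A sketch of the resulting call, assuming an initialized `context` and a `dtype` already mapped from `torch.float` (both taken as given here); it mirrors the patched code above rather than adding anything new:

// Sketch under the above assumptions; mirrors the patched code path.
std::vector<int64_t> shape;  // annotation shape () -> rank 0, no dims
int64_t dummy;
int64_t *shapeData = shape.empty() ? &dummy : shape.data();
// Rank 0 with a non-null pointer yields !torch.vtensor<[],f32>,
// not an unranked tensor type.
MlirType typeBound =
    npcompTorchValueTensorTypeGet(context, shape.size(), shapeData, dtype);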