diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index 91ee4c14c..289e5722e 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -99,12 +99,12 @@ class ModelInfo: assert model_proto.graph, "Model must contain a main Graph" self.main_graph = GraphInfo(self, model_proto.graph) - def create_module(self, context: Optional[Context] = None) -> Operation: + def create_module(self, context: Optional[Context] = None) -> Module: if not context: context = Context() - module_op = Module.create(Location.unknown(context)) + module = Module.create(Location.unknown(context)) # TODO: Populate module level metadata from the ModelProto - return module_op + return module class GraphInfo: