Fix for upstream Torch change.

After https://github.com/pytorch/pytorch/pull/65967, the `graph()` method
is only available on `torch::jit::GraphFunction`, so callers now have to go
through `torch::jit::toGraphFunction` to reach it.

Fixes https://github.com/llvm/torch-mlir/issues/388
Sean Silva 2021-10-28 17:46:09 +00:00
parent 7e4ef74774
commit b02b65cf6e
3 changed files with 7 additions and 5 deletions
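
As a reference for other call sites hitting the same breakage, here is a
minimal sketch (not code from this commit) of the migration pattern the
diffs below apply. It assumes `torch::jit::toGraphFunction` is declared in
`torch/csrc/jit/api/function_impl.h` and takes a `torch::jit::Function&`,
returning a `GraphFunction&` whose `graph()` yields the same
`std::shared_ptr<Graph>` the old `Function::graph()` did; the helper name
is hypothetical.

    // Sketch only: migrating a graph() call site after pytorch/pytorch#65967.
    // Assumes toGraphFunction is declared in function_impl.h.
    #include <torch/csrc/jit/api/function_impl.h>

    #include <memory>

    // Hypothetical helper, not part of the commit.
    std::shared_ptr<torch::jit::Graph>
    getFunctionGraph(torch::jit::Function *function) {
      // Old API (removed upstream): function->graph()
      // New API: graph() lives on GraphFunction, so go through toGraphFunction.
      return torch::jit::toGraphFunction(*function).graph();
    }

The same substitution is applied mechanically in each hunk below.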

@@ -63,8 +63,9 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
         appendToBlock, "std.return", loc,
         derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
   };
-  MlirBlock block =
-      importBlock(context, function->graph()->block(), createTerminator);
+  MlirBlock block = importBlock(
+      context, torch::jit::toGraphFunction(*function).graph()->block(),
+      createTerminator);
   mlirRegionAppendOwnedBlock(bodyRegion, block);
   return func;
 }

@@ -499,7 +499,7 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
     // format, even though they still cause import issues when importing
     // through the larger Python session where they originate.
     // std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n";
-    // std::cerr << *function->graph();
+    // std::cerr << *torch::jit::toGraphFunction(function).graph();
     MethodAnnotation *annotation =
         annotator.getMethodAnnotationForFunction(function);
     MlirOperation func = importJitFunctionAsFuncOp(

@@ -245,7 +245,8 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     auto classType = node->input(0)->type()->cast<c10::ClassType>();
     auto methodName = node->s(c10::attr::name);
     torch::jit::Function *function = classType->findMethod(methodName);
-    torch::jit::Block *calleeEntryBlock = function->graph()->block();
+    torch::jit::Block *calleeEntryBlock =
+        torch::jit::toGraphFunction(*function).graph()->block();
     auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
       return getMlirTypeFromTorchType(loc, v->type());
     });
@@ -263,7 +264,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
   if (kind == c10::prim::CallFunction) {
     auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
     torch::jit::Block *calleeEntryBlock =
-        functionType->function()->graph()->block();
+        torch::jit::toGraphFunction(*functionType->function()).graph()->block();
     auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
       return getMlirTypeFromTorchType(loc, v->type());
     });