mirror of https://github.com/llvm/torch-mlir
Fix for upstream Torch change.
After https://github.com/pytorch/pytorch/pull/65967 the `graph()` method is only available on `torch::jit::GraphFunction` now. Fixes https://github.com/llvm/torch-mlir/issues/388pull/390/head
parent
7e4ef74774
commit
b02b65cf6e
|
@ -63,8 +63,9 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
|||
appendToBlock, "std.return", loc,
|
||||
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
|
||||
};
|
||||
MlirBlock block =
|
||||
importBlock(context, function->graph()->block(), createTerminator);
|
||||
MlirBlock block = importBlock(
|
||||
context, torch::jit::toGraphFunction(*function).graph()->block(),
|
||||
createTerminator);
|
||||
mlirRegionAppendOwnedBlock(bodyRegion, block);
|
||||
return func;
|
||||
}
|
||||
|
|
|
@ -499,7 +499,7 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
|||
// format, even though they still cause import issues when importing
|
||||
// through the larger Python session where they originate.
|
||||
// std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n";
|
||||
// std::cerr << *function->graph();
|
||||
// std::cerr << *torch::jit::toGraphFunction(function).graph();
|
||||
MethodAnnotation *annotation =
|
||||
annotator.getMethodAnnotationForFunction(function);
|
||||
MlirOperation func = importJitFunctionAsFuncOp(
|
||||
|
|
|
@ -245,7 +245,8 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
auto classType = node->input(0)->type()->cast<c10::ClassType>();
|
||||
auto methodName = node->s(c10::attr::name);
|
||||
torch::jit::Function *function = classType->findMethod(methodName);
|
||||
torch::jit::Block *calleeEntryBlock = function->graph()->block();
|
||||
torch::jit::Block *calleeEntryBlock =
|
||||
torch::jit::toGraphFunction(*function).graph()->block();
|
||||
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
|
||||
return getMlirTypeFromTorchType(loc, v->type());
|
||||
});
|
||||
|
@ -263,7 +264,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
if (kind == c10::prim::CallFunction) {
|
||||
auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
|
||||
torch::jit::Block *calleeEntryBlock =
|
||||
functionType->function()->graph()->block();
|
||||
torch::jit::toGraphFunction(*functionType->function()).graph()->block();
|
||||
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
|
||||
return getMlirTypeFromTorchType(loc, v->type());
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue