//===- function_importer.cpp ----------------------------------------------===// // // This file is licensed under a pytorch-style license // See frontends/pytorch/LICENSE for license information. // //===----------------------------------------------------------------------===// #include "function_importer.h" #include #include "mlir_utils.h" #include "torch_to_mlir_utils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" namespace py = pybind11; using namespace torch_mlir; MlirOperation torch_mlir::importJitFunctionAsFuncOp( MlirContext context, torch::jit::Function *function, std::function getArgAttribute) { // Useful for debugging: // graph->dump(); MlirLocation loc = mlirLocationUnknownGet(context); MlirType functionType = getFunctionTypeFromSchema(context, function->getSchema()); // Use the function's qualified name from the compilation unit. // This is a stable linkage name that matches Python module lookup // conventions (see compilation unit import in IValueImporter for more details // on qualified names). MlirAttribute symNameAttr = mlirStringAttrGet( context, toMlirStringRef(function->qualname().qualifiedName())); MlirOperation func = createMlirOperation( "builtin.func", loc, mlirRegionCreate(), toMlirNamedAttribute("type", mlirTypeAttrGet(functionType)), toMlirNamedAttribute("sym_name", symNameAttr)); std::vector argAttrDicts; for (int i = 0, e = mlirFunctionTypeGetNumInputs(functionType); i != e; i++) { MlirAttribute argAttrDict = getArgAttribute(i); if (mlirAttributeIsNull(argAttrDict)) { argAttrDicts.push_back(mlirDictionaryAttrGet(context, 0, nullptr)); } else { argAttrDicts.push_back(argAttrDict); } } mlirOperationSetAttributeByName( func, toMlirStringRef("arg_attrs"), mlirArrayAttrGet(context, argAttrDicts.size(), argAttrDicts.data())); MlirRegion bodyRegion = mlirOperationGetRegion(func, 0); std::vector resultTypes; for (int i = 0, e = mlirFunctionTypeGetNumResults(functionType); i != e; i++) { resultTypes.push_back(mlirFunctionTypeGetResult(functionType, i)); } auto createTerminator = [&](c10::ArrayRef yieldedValues, MlirBlock appendToBlock) { createMlirOperationAtEnd( appendToBlock, "std.return", loc, derefineValues(yieldedValues, resultTypes, loc, appendToBlock)); }; MlirBlock block = importBlock(context, function->graph()->block(), createTerminator); mlirRegionAppendOwnedBlock(bodyRegion, block); return func; }