//===- node_importer.cpp --------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//

#include "node_importer.h"

#include <unordered_map>

#include "mlir_utils.h"
#include "op_builder.h"

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h"

namespace py = pybind11;
using namespace torch_mlir;

using Value = torch::jit::Value;
using Block = torch::jit::Block;
using Node = torch::jit::Node;

namespace {
class NodeImporter {
public:
  NodeImporter(MlirContext context) : context(context) {}

  void importNode(Node *node, MlirBlock appendToBlock);
  MlirBlock importBlock(Block *jitBlock, CreateTerminatorFn createTerminator);

private:
  MlirBlock createBlockFor(Block *jitBlock);
  void mapValue(Value *jitValue, MlirValue value);
  void mapResults(Node *node, MlirOperation operation);
  MlirValue lookupMappedValue(Value *jitValue);
  std::vector<MlirValue> lookupMappedValues(c10::ArrayRef<Value *> values);

  MlirContext context;
  std::unordered_map<Value *, MlirValue> valueMap;
};
} // namespace

using InputsTransformFn =
    std::function<std::vector<MlirValue>(std::vector<MlirValue> &)>;

// The inputs of a `DictConstruct` node in TorchScript IR are interleaved as
// k0, v0, k1, v1. Rearrange them so that all key operands come first, followed
// by all value operands (k0, k1, v0, v1), which is the operand order expected
// by the corresponding MLIR op.
static std::vector<MlirValue>
rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
  if (inputs.empty())
    return inputs;
  assert(inputs.size() % 2 == 0 &&
         "DictConstruct must have an even number of operands");

  std::vector<MlirValue> rearranged;
  std::vector<MlirValue> values;
  for (auto it = inputs.begin(); it != inputs.end(); it++) {
    rearranged.push_back(*it);
    values.push_back(*++it);
  }
  rearranged.insert(rearranged.end(), values.begin(), values.end());
  return rearranged;
}

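// Imports a single TorchScript node into `appendToBlock`, dispatching on the
// node kind: ops with a registered schema, builtin interpreter ops
// (list/tuple/dict construction and attribute access), prim::Constant,
// control flow (prim::Loop, prim::If), and calls (prim::CallMethod,
// prim::CallFunction). Anything else is reported via an error diagnostic.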
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
  TypeMapper typeMapper(context);
  MlirLocation loc = getMlirLocationFromNode(context, node);
  auto kind = node->kind();

  auto createAndMapTrivialNode = [&](Node *node, const std::string &opName,
                                     InputsTransformFn t) {
    std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
    MlirOperation operation =
        createMlirOperationAtEnd(appendToBlock, opName, loc,
                                 getMlirTypesFromValues(loc, node->outputs()),
                                 t ? t(mappedInputs) : mappedInputs);
    mapResults(node, operation);
  };

  auto createAndMapNodeWithAttribute =
      [&](Node *node, const std::string &opName, const std::string &attrName,
          MlirAttribute attr) {
        MlirOperation operation = createMlirOperationAtEnd(
            appendToBlock, opName, loc,
            getMlirTypesFromValues(loc, node->outputs()),
            lookupMappedValues(node->inputs()),
            toMlirNamedAttribute(attrName.c_str(), attr));
        mapResults(node, operation);
      };

  // Trivial ops with schema.
  auto maybeSchema = node->maybeSchema();
  if (maybeSchema) {
    MlirOperation operation = createOperationFromSchema(
        appendToBlock, loc, node->schema(),
        getMlirTypesFromValues(loc, node->outputs()),
        lookupMappedValues(node->inputs()));
    mapResults(node, operation);
    return;
  }

  // Builtin interpreter ops with no operator/schema.
  InputsTransformFn transformer =
      kind != c10::prim::DictConstruct ? nullptr : rearrangeDictConstructInputs;
  switch (kind) {
  case c10::prim::ListUnpack:
  case c10::prim::ListConstruct:
  case c10::prim::TupleConstruct:
  case c10::prim::DictConstruct: {
    createAndMapTrivialNode(
        node, "torch.prim." + std::string(kind.toUnqualString()), transformer);
    return;
  }
  case c10::prim::GetAttr:
  case c10::prim::SetAttr: {
    createAndMapNodeWithAttribute(
        node, "torch.prim." + std::string(kind.toUnqualString()), "name",
        importAttribute(loc, node, c10::attr::name));
    return;
  }
  }

  if (kind == c10::prim::Constant) {
    auto output = node->output();
    MlirOperation op;
    OpBuilder builder(context);
    if (output->type()->cast<c10::NoneType>()) {
      op = builder.createNoneConstant(loc);
    } else if (output->type()->cast<c10::BoolType>()) {
      op = builder.createBoolConstant(
          loc, static_cast<bool>(node->i(c10::attr::value)));
    } else if (output->type()->cast<c10::IntType>()) {
      op = createMlirOperation(
          "torch.constant.int", loc,
          typeMapper.mapFromTorchType(loc, output->type()),
          toMlirNamedAttribute("value",
                               importAttribute(loc, node, c10::attr::value)));
    } else if (output->type()->cast<c10::FloatType>()) {
      op = createMlirOperation(
          "torch.constant.float", loc,
          typeMapper.mapFromTorchType(loc, output->type()),
          toMlirNamedAttribute("value",
                               importAttribute(loc, node, c10::attr::value)));
    } else if (output->type()->cast<c10::StringType>()) {
      op = createMlirOperation(
          "torch.constant.str", loc, npcompTorchStringTypeGet(context),
          toMlirNamedAttribute(
              "value", mlirStringAttrGet(context, toMlirStringRef(node->s(
                                                      c10::attr::value)))));
    } else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
      torch::jit::Function *function = functionType->function();
      const std::string &symName = function->qualname().qualifiedName();
      op = createMlirOperation(
          "std.constant", loc,
          getFunctionTypeFromSchema(context, function->getSchema()),
          toMlirNamedAttribute(
              "value",
              mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
    } else {
      std::stringstream msg;
      msg << "unhandled prim::Constant node: ";
      node->print(msg, 0, nullptr);
      mlirEmitError(getMlirLocationFromNode(context, node), msg.str().c_str());
      throw mlir_diagnostic_emitted();
    }
    mlirBlockAppendOwnedOperation(appendToBlock, op);
    mapResults(node, op);
    return;
  }

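  // Control-flow nodes map onto region-carrying ops: each TorchScript block
  // is imported recursively via importBlock(), which appends a kind-specific
  // terminator (torch.prim.Loop.condition / torch.prim.If.yield). The
  // derefineValues() helper reconciles the types of the forwarded/yielded
  // values with the types expected by the enclosing op.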
  if (kind == c10::prim::Loop) {
    std::vector<MlirType> resultTypes =
        getMlirTypesFromValues(loc, node->outputs());
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "torch.prim.Loop", loc, resultTypes,
        lookupMappedValues(node->inputs().slice(0, 2)),
        derefineValues(lookupMappedValues(node->inputs().slice(2)), resultTypes,
                       loc, appendToBlock),
        mlirRegionCreate());
    mapResults(node, operation);

    std::vector<MlirType> terminatorOperandTypes = {
        npcompTorchBoolTypeGet(context)};
    terminatorOperandTypes.insert(terminatorOperandTypes.end(),
                                  resultTypes.begin(), resultTypes.end());
    auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
                                MlirBlock appendToBlock) {
      createMlirOperationAtEnd(appendToBlock, "torch.prim.Loop.condition", loc,
                               derefineValues(yieldedValues,
                                              terminatorOperandTypes, loc,
                                              appendToBlock));
    };
    mlirRegionAppendOwnedBlock(
        mlirOperationGetRegion(operation, 0),
        importBlock(node->blocks()[0], createTerminator));
    return;
  }

  if (kind == c10::prim::If) {
    std::vector<MlirType> resultTypes =
        getMlirTypesFromValues(loc, node->outputs());
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "torch.prim.If", loc, lookupMappedValue(node->input()),
        resultTypes, mlirRegionCreate(), mlirRegionCreate());
    mapResults(node, operation);
    auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
                                MlirBlock appendToBlock) {
      createMlirOperationAtEnd(
          appendToBlock, "torch.prim.If.yield", loc,
          derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
    };
    mlirRegionAppendOwnedBlock(
        mlirOperationGetRegion(operation, 0),
        importBlock(node->blocks()[0], createTerminator));
    mlirRegionAppendOwnedBlock(
        mlirOperationGetRegion(operation, 1),
        importBlock(node->blocks()[1], createTerminator));
    return;
  }

  if (kind == c10::prim::CallMethod) {
    auto classType = node->input(0)->type()->cast<c10::ClassType>();
    auto methodName = node->s(c10::attr::name);
    torch::jit::Function *function = classType->findMethod(methodName);
    torch::jit::Block *calleeEntryBlock = function->graph()->block();
    auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
      return typeMapper.mapFromTorchType(loc, v->type());
    });
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "torch.prim.CallMethod", loc,
        getMlirTypesFromValues(loc, node->outputs()),
        derefineValues(lookupMappedValues(node->inputs()), expectedTypes, loc,
                       appendToBlock),
        toMlirNamedAttribute("name",
                             importAttribute(loc, node, c10::attr::name)));
    mapResults(node, operation);
    return;
  }

  if (kind == c10::prim::CallFunction) {
    auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
    torch::jit::Block *calleeEntryBlock =
        functionType->function()->graph()->block();
    auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
      return typeMapper.mapFromTorchType(loc, v->type());
    });
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "std.call_indirect", loc,
        getMlirTypesFromValues(loc, node->outputs()),
        lookupMappedValue(node->input(0)),
        derefineValues(lookupMappedValues(node->inputs().slice(1)),
                       expectedTypes, loc, appendToBlock));
    mapResults(node, operation);
    return;
  }

  {
    std::stringstream msg;
    msg << "unhandled: could not import node: ";
    node->print(msg, 0, nullptr);
    mlirEmitError(getMlirLocationFromNode(context, node), msg.str().c_str());
    throw mlir_diagnostic_emitted();
  }
}

MlirBlock NodeImporter::importBlock(Block *jitBlock,
                                    CreateTerminatorFn createTerminator) {
  MlirBlock block = createBlockFor(jitBlock);
  for (Node *node : jitBlock->nodes()) {
    importNode(node, block);
  }
  Node *returnNode = jitBlock->return_node();
  createTerminator(lookupMappedValues(returnNode->inputs()), block);
  return block;
}

MlirBlock NodeImporter::createBlockFor(Block *jitBlock) {
  Node *paramNode = jitBlock->param_node();
  MlirLocation loc = getMlirLocationFromNode(context, paramNode);
  std::vector<MlirType> blockArgTypes =
      getMlirTypesFromValues(loc, paramNode->outputs());
  MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data());
  for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
    Value *jitValue = paramNode->outputs()[i];
    MlirValue value = mlirBlockGetArgument(block, i);
    mapValue(jitValue, value);
  }
  return block;
}

void NodeImporter::mapValue(Value *jitValue, MlirValue value) {
  auto it = valueMap.find(jitValue);
  (void)it;
  assert(it == valueMap.end() && "jitValue has already been mapped");
  valueMap[jitValue] = value;
}

void NodeImporter::mapResults(Node *node, MlirOperation operation) {
  assert(node->outputs().size() ==
         (size_t)mlirOperationGetNumResults(operation));
  for (int i = 0, e = node->outputs().size(); i < e; i++) {
    mapValue(node->outputs()[i], mlirOperationGetResult(operation, i));
  }
}

MlirValue NodeImporter::lookupMappedValue(Value *jitValue) {
  auto it = valueMap.find(jitValue);
  assert(it != valueMap.end() &&
         "trying to get mapping for jitValue that is not mapped yet!");
  return it->second;
}

std::vector<MlirValue>
NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
  std::vector<MlirValue> ret;
  for (Value *value : values) {
    ret.push_back(lookupMappedValue(value));
  }
  return ret;
}

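// Entry point declared in node_importer.h: imports `jitBlock` into a fresh
// MLIR block using a new NodeImporter instance, with `createTerminator`
// supplying the block terminator.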
MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
                                  CreateTerminatorFn createTerminator) {
  NodeImporter importer(context);
  return importer.importBlock(jitBlock, createTerminator);
}