diff --git a/frontends/pytorch/csrc/CMakeLists.txt b/frontends/pytorch/csrc/CMakeLists.txt index a7e16e98c..1fad264d8 100644 --- a/frontends/pytorch/csrc/CMakeLists.txt +++ b/frontends/pytorch/csrc/CMakeLists.txt @@ -13,6 +13,13 @@ include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR} ${Python3_INCLUDE_DIRS} + # TODO: Fix implicit ordering. If PyTorch was built against an external + # pybind11, then it will not be in the above search path and must be + # resolved here, in the hope that it is the same one we were configured + # with (which it should be if installed via pip). This is really fragile, + # though, causing cast failures at runtime if we get it wrong. Come up with + # a way to strengthen this. + ${pybind11_INCLUDE_DIR} ) link_directories("${TORCH_INSTALL_PREFIX}/lib") @@ -28,7 +35,11 @@ target_link_libraries(NPCOMPTorchMLIRExt # NPCOMP shared library. NPCOMP ) - +add_dependencies(NPCOMPTorchMLIRExt + # Uses of the torch_mlir extension also require the npcomp extension to + # be built. + NPCOMPNativePyExt +) set_target_properties(NPCOMPTorchMLIRExt PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/python OUTPUT_NAME _torch_mlir diff --git a/frontends/pytorch/csrc/builder/CMakeLists.txt b/frontends/pytorch/csrc/builder/CMakeLists.txt index 0342d3713..c89104d1f 100644 --- a/frontends/pytorch/csrc/builder/CMakeLists.txt +++ b/frontends/pytorch/csrc/builder/CMakeLists.txt @@ -5,12 +5,20 @@ include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR} ${Python3_INCLUDE_DIRS} + # TODO: Fix implicit ordering. If PyTorch was built against an external + # pybind11, then it will not be in the above search path and must be + # resolved here, in the hope that it is the same one we were configured + # with (which it should be if installed via pip). This is really fragile, + # though, causing cast failures at runtime if we get it wrong. Come up with + # a way to strengthen this. 
+ ${pybind11_INCLUDE_DIR} ) link_directories("${TORCH_INSTALL_PREFIX}/lib") add_library(npcomp_torch_builder_bindings acap_dispatch.cpp debug.cpp func_builder.cpp + graph_importer.cpp module_builder.cpp python_bindings.cpp ) diff --git a/frontends/pytorch/csrc/builder/acap_dispatch.cpp b/frontends/pytorch/csrc/builder/acap_dispatch.cpp index e0552f308..9bb09c3ca 100644 --- a/frontends/pytorch/csrc/builder/acap_dispatch.cpp +++ b/frontends/pytorch/csrc/builder/acap_dispatch.cpp @@ -42,63 +42,17 @@ static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY; static c10::DispatchKey kAcapGradDispatchKey = c10::DispatchKey::ACAP_GRAD_DISPATCH_KEY; -AcapController::KernelCallBuilder::KernelCallBuilder( +AcapController::TracedKernelCallBuilder::TracedKernelCallBuilder( AcapController &parent, MlirContext context, MlirLocation loc, const c10::OperatorHandle &opHandle, llvm::Optional overrideKernelName) - : parent(parent), context(context), loc(loc), opHandle(opHandle), - state("torch.kernel_call", loc) { - (void)this->context; // Preserve for future. - const std::string &kernelName = - overrideKernelName ? *overrideKernelName : opHandle.operator_name().name; - MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet( - "kernelName", - mlirStringAttrGet(context, kernelName.size(), kernelName.data())); - mlirOperationStateAddAttributes(state, 1, &kernelNameAttr); - addSchemaAttrs(); -} + : KernelCallBuilder(context, loc, + overrideKernelName ? 
*overrideKernelName + : opHandle.operator_name().name, + opHandle.schema()), + parent(parent), opHandle(opHandle) {} -void AcapController::KernelCallBuilder::addSchemaAttrs() { - // Map the op schema to the kernel_call attributes: - // sigArgTypes - // sigRetTypes - // sigIsVararg - // sigIsVarret - // sigIsMutable - const c10::FunctionSchema &schema = opHandle.schema(); - llvm::SmallVector attrs; - attrs.push_back(mlirNamedAttributeGet( - "sigIsMutable", mlirBoolAttrGet(context, schema.is_mutable()))); - attrs.push_back(mlirNamedAttributeGet( - "sigIsVararg", mlirBoolAttrGet(context, schema.is_vararg()))); - attrs.push_back(mlirNamedAttributeGet( - "sigIsVarret", mlirBoolAttrGet(context, schema.is_varret()))); - - // Arg types. - llvm::SmallVector args; - for (auto &arg : schema.arguments()) { - const std::string &typeStr = arg.type()->str(); - args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data())); - } - attrs.push_back(mlirNamedAttributeGet( - "sigArgTypes", mlirArrayAttrGet(context, args.size(), args.data()))); - - // Return types. - llvm::SmallVector returns; - for (auto &ret : schema.returns()) { - const std::string &typeStr = ret.type()->str(); - returns.push_back( - mlirStringAttrGet(context, typeStr.size(), typeStr.data())); - } - attrs.push_back(mlirNamedAttributeGet( - "sigRetTypes", - mlirArrayAttrGet(context, returns.size(), returns.data()))); - - // Add attrs to op. 
- mlirOperationStateAddAttributes(state, attrs.size(), attrs.data()); -} - -void AcapController::KernelCallBuilder::addOperand(const IValue &value) { +void AcapController::TracedKernelCallBuilder::addOperand(const IValue &value) { MlirValue mlirValue = parent.mapIValueToMlirValue(loc, value); if (mlirValueIsNull(mlirValue)) { std::stringstream out; @@ -107,10 +61,10 @@ void AcapController::KernelCallBuilder::addOperand(const IValue &value) { << value.tagKind() << "): " << value; throw std::invalid_argument(out.str()); } - mlirOperationStateAddOperands(state, 1, &mlirValue); + KernelCallBuilder::addOperand(mlirValue); } -void AcapController::KernelCallBuilder::addResult(const IValue &value) { +void AcapController::TracedKernelCallBuilder::addResult(const IValue &value) { MlirType resultType = parent.mapIValueToMlirType(loc, value); if (mlirTypeIsNull(resultType)) { std::stringstream out; @@ -122,12 +76,11 @@ void AcapController::KernelCallBuilder::addResult(const IValue &value) { if (value.isTensor()) { resultIndexToTensorMap.emplace_back(resultCount++, value.toTensor()); } - mlirOperationStateAddResults(state, 1, &resultType); + KernelCallBuilder::addResultType(resultType); } -MlirOperation AcapController::KernelCallBuilder::create() { - // Create operation. - MlirOperation op = state.createOperation(); +MlirOperation AcapController::TracedKernelCallBuilder::create() { + MlirOperation op = KernelCallBuilder::create(); parent.funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(op); // Map result tensors. 
@@ -264,7 +217,7 @@ at::Tensor AcapController::convolutionKernel( MlirContext context = current->funcBuilder->getContext(); MlirLocation loc = current->getCurrentLocation(); std::string kernelName{"aten::convolution"}; - KernelCallBuilder callBuilder{*current, context, loc, *opHandle}; + TracedKernelCallBuilder callBuilder{*current, context, loc, *opHandle}; callBuilder.addOperand(IValue(input)); callBuilder.addOperand(IValue(weight)); @@ -344,8 +297,8 @@ AcapController::mklConvolutionBackward( ""}; auto emitOpHandle = dispatcher.findOp(emitOpName); assert(emitOpHandle && "could not find convolution_backward_overrideable op"); - KernelCallBuilder callBuilder{*current, context, loc, *emitOpHandle, - kernelName}; + TracedKernelCallBuilder callBuilder{*current, context, loc, *emitOpHandle, + kernelName}; callBuilder.addOperand(IValue(grad_output)); callBuilder.addOperand(IValue(input)); @@ -398,7 +351,7 @@ at::Tensor &AcapController::copyUnderKernel(at::Tensor &self, MlirContext context = current->funcBuilder->getContext(); MlirLocation loc = current->getCurrentLocation(); - KernelCallBuilder callBuilder{*current, context, loc, *opHandle}; + TracedKernelCallBuilder callBuilder{*current, context, loc, *opHandle}; callBuilder.addOperand(IValue(self)); callBuilder.addOperand(IValue(src)); @@ -462,7 +415,7 @@ void AcapController::fallbackKernelImpl( MlirContext context = funcBuilder->getContext(); MlirLocation loc = getCurrentLocation(); auto kernelName = schema.name(); - KernelCallBuilder callBuilder{*this, context, loc, opHandle}; + TracedKernelCallBuilder callBuilder{*this, context, loc, opHandle}; // Map arguments to operands. 
// This must be accumulated into the OperationState prior to re-dispatch @@ -553,7 +506,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc, MlirType AcapController::mapIValueToMlirType(MlirLocation loc, const IValue &ival) { if (ival.isScalar()) { - return typeMapper.mapScalarType(ival.toScalar().type()); + return typeMapper.mapFromTorchScalarType(ival.toScalar().type()); } if (ival.isTensor()) { return typeMapper.forwardTensorToType(ival.toTensor()); @@ -601,7 +554,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) { // Basicpy bool type. elementType = mlirIntegerTypeGet(funcBuilder->getContext(), 1); } else { - elementType = typeMapper.mapScalarType(tensor.scalar_type()); + elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type()); } llvm::SmallVector shape(tensor.sizes().begin(), tensor.sizes().end()); diff --git a/frontends/pytorch/csrc/builder/acap_dispatch.h b/frontends/pytorch/csrc/builder/acap_dispatch.h index 9d7480fa1..0a9443e0b 100644 --- a/frontends/pytorch/csrc/builder/acap_dispatch.h +++ b/frontends/pytorch/csrc/builder/acap_dispatch.h @@ -80,9 +80,9 @@ public: private: /// Builds a kernel call step by step. 
- class KernelCallBuilder { + class TracedKernelCallBuilder : private KernelCallBuilder { public: - KernelCallBuilder( + TracedKernelCallBuilder( AcapController &parent, MlirContext context, MlirLocation loc, const c10::OperatorHandle &opHandle, llvm::Optional overrideKernelName = llvm::None); @@ -91,12 +91,8 @@ private: MlirOperation create(); private: - void addSchemaAttrs(); AcapController &parent; - MlirContext context; - MlirLocation loc; const c10::OperatorHandle &opHandle; - OperationStateHolder state; int resultCount = 0; llvm::SmallVector, 4> resultIndexToTensorMap; }; diff --git a/frontends/pytorch/csrc/builder/func_builder.cpp b/frontends/pytorch/csrc/builder/func_builder.cpp index 3e098cb50..f8a79fadd 100644 --- a/frontends/pytorch/csrc/builder/func_builder.cpp +++ b/frontends/pytorch/csrc/builder/func_builder.cpp @@ -7,6 +7,7 @@ #include "func_builder.h" +#include "mlir-c/Diagnostics.h" #include "mlir-c/StandardAttributes.h" #include "mlir-c/StandardTypes.h" #include "npcomp-c/Types.h" @@ -22,7 +23,90 @@ static MlirOperation createStandardConstant(MlirLocation loc, MlirType type, return s.createOperation(); } -MlirType TypeMapper::mapScalarType(c10::ScalarType scalarType) { +KernelCallBuilder::KernelCallBuilder(MlirContext context, MlirLocation loc, + llvm::StringRef kernelName, + const c10::FunctionSchema &schema) + : context(context), loc(loc), state("torch.kernel_call", loc), + kernelName(kernelName), schema(schema) { + (void)this->context; // Preserve for future. 
+ MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet( + "kernelName", + mlirStringAttrGet(context, kernelName.size(), kernelName.data())); + mlirOperationStateAddAttributes(state, 1, &kernelNameAttr); + addSchemaAttrs(); +} + +void KernelCallBuilder::addSchemaAttrs() { + // Map the op schema to the kernel_call attributes: + // sigArgTypes + // sigRetTypes + // sigIsVararg + // sigIsVarret + // sigIsMutable + llvm::SmallVector attrs; + attrs.push_back(mlirNamedAttributeGet( + "sigIsMutable", mlirBoolAttrGet(context, schema.is_mutable()))); + attrs.push_back(mlirNamedAttributeGet( + "sigIsVararg", mlirBoolAttrGet(context, schema.is_vararg()))); + attrs.push_back(mlirNamedAttributeGet( + "sigIsVarret", mlirBoolAttrGet(context, schema.is_varret()))); + + // Arg types. + llvm::SmallVector args; + for (auto &arg : schema.arguments()) { + const std::string &typeStr = arg.type()->str(); + args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data())); + } + attrs.push_back(mlirNamedAttributeGet( + "sigArgTypes", mlirArrayAttrGet(context, args.size(), args.data()))); + + // Return types. + llvm::SmallVector returns; + for (auto &ret : schema.returns()) { + const std::string &typeStr = ret.type()->str(); + returns.push_back( + mlirStringAttrGet(context, typeStr.size(), typeStr.data())); + } + attrs.push_back(mlirNamedAttributeGet( + "sigRetTypes", + mlirArrayAttrGet(context, returns.size(), returns.data()))); + + // Add attrs to op. 
+ mlirOperationStateAddAttributes(state, attrs.size(), attrs.data()); +} + +void KernelCallBuilder::addOperand(MlirValue operand) { + mlirOperationStateAddOperands(state, 1, &operand); +} + +void KernelCallBuilder::addResultType(MlirType resultType) { + mlirOperationStateAddResults(state, 1, &resultType); +} + +MlirOperation KernelCallBuilder::create() { return state.createOperation(); } + +MlirType TypeMapper::mapFromTorchScalarType(c10::ScalarType scalarType) { + auto type = rawMapFromTorchScalarType(scalarType); + if (mlirTypeIsNull(type)) { + std::stringstream message; + message << "unsupported PyTorch scalar type: " << c10::toString(scalarType); + throw std::invalid_argument(message.str()); + } + return type; +} + +MlirType TypeMapper::mapFromTorchScalarType(MlirLocation loc, + c10::ScalarType scalarType) { + auto type = rawMapFromTorchScalarType(scalarType); + if (mlirTypeIsNull(type)) { + std::stringstream message; + message << "unsupported PyTorch scalar type: " << c10::toString(scalarType); + mlirEmitError(loc, message.str().c_str()); + } + return type; +} + +MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) { using c10::ScalarType; switch (scalarType) { case ScalarType::Byte: @@ -49,10 +133,49 @@ MlirType TypeMapper::mapScalarType(c10::ScalarType scalarType) { return mlirBF16TypeGet(context); case ScalarType::Half: return mlirF16TypeGet(context); + default: { + return {nullptr}; + } + } +} + +MlirType TypeMapper::mapFromTorchType(MlirLocation loc, + const c10::TypePtr &torchType) { + using c10::TypeKind; + auto kind = torchType->kind(); + switch (kind) { + case TypeKind::TensorType: { + auto tensorType = torchType->cast(); + // Element type. + MlirType elementType; + if (tensorType->scalarType()) { + elementType = mapFromTorchScalarType(loc, *tensorType->scalarType()); + if (mlirTypeIsNull(elementType)) + return {nullptr}; + } else { + elementType = npcompAnyDtypeTypeGet(context); + } + // Sizes. 
+ auto &sizes = tensorType->symbolic_sizes(); + if (!sizes.rank()) { + // Unranked. + return npcompNdArrayTypeGetUnranked(elementType); + } + // Ranked with possibly dynamic dims. + auto &symbolicShape = tensorType->symbolic_sizes(); + llvm::SmallVector dims; + dims.resize(*sizes.rank()); + for (size_t i = 0; i < dims.size(); ++i) { + auto shapeSymbol = symbolicShape[i]; + dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1; + } + return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType); + } default: { std::stringstream message; - message << "unsupported PyTorch scalar type: " << c10::toString(scalarType); - throw std::invalid_argument(message.str()); + message << "unable to map Torch type " << torchType << " to MLIR type"; + mlirEmitError(loc, message.str().c_str()); + return {nullptr}; } } } @@ -64,7 +187,7 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) { return npcompNoneTypeGet(context); } - MlirType elementType = mapScalarType(tensor.scalar_type()); + MlirType elementType = mapFromTorchScalarType(tensor.scalar_type()); // TODO: Decide when it is necessary to take strides into account. Right now, // just erase them and let the compiler decide. @@ -73,9 +196,10 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) { } std::unique_ptr -FuncBuilder::createFunction(MlirContext context, MlirLocation location, - llvm::StringRef name, +FuncBuilder::createFunction(FuncBuilder::Inserter &inserter, + MlirLocation location, llvm::StringRef name, llvm::SmallVectorImpl &inputTypes) { + auto context = mlirLocationGetContext(location); // TODO: Create a dedicated API upstream for creating/manipulating func ops. // (this is fragile and reveals details that are not guaranteed). 
llvm::SmallVector funcAttrs; @@ -102,6 +226,7 @@ FuncBuilder::createFunction(MlirContext context, MlirLocation location, MlirRegion bodyRegion = mlirOperationGetRegion(funcOp, 0); MlirBlock entryBlock = mlirRegionGetFirstBlock(bodyRegion); + inserter(funcOp); return std::unique_ptr(new FuncBuilder( context, funcOp, BlockBuilder(entryBlock, /*returnOp=*/{nullptr}, true))); } diff --git a/frontends/pytorch/csrc/builder/func_builder.h b/frontends/pytorch/csrc/builder/func_builder.h index 093a559f1..0e1dda10d 100644 --- a/frontends/pytorch/csrc/builder/func_builder.h +++ b/frontends/pytorch/csrc/builder/func_builder.h @@ -13,6 +13,7 @@ #include "llvm/ADT/StringRef.h" #include +#include namespace torch_mlir { @@ -51,21 +52,30 @@ public: /// Gets a corresponding MlirType for the Torch ScalarType. /// Throws std::invalid_argument on failure. - MlirType mapScalarType(c10::ScalarType scalarType); + MlirType mapFromTorchScalarType(c10::ScalarType scalarType); /// Gets a corresponding MlirType for the forward component of a tensor. /// Throws std::invalid_argument on failure. MlirType forwardTensorToType(at::Tensor tensor); + /// Gets a corresponding MlirType for the Torch ScalarType. + /// Returns a null type on failure and emits a diagnostic. + MlirType mapFromTorchScalarType(MlirLocation loc, c10::ScalarType scalarType); + + /// Maps a torch type to a corresponding MlirType. Returns a null type + /// on failure and emits a diagnostic. + MlirType mapFromTorchType(MlirLocation loc, const c10::TypePtr &torchType); + private: + /// Maps from a scalar type and returns null if no match (no other error + /// reporting). + MlirType rawMapFromTorchScalarType(c10::ScalarType scalarType); MlirContext context; }; /// Wraps an MlirBlock under construction, primarily tracking the terminator /// and supporting manipulation of it. 
The terminator may be null if it has -/// not yet been constructed, although, for entry blocks, we always construct -/// the function with an appropriate return terminator (which can be changed -/// later). +/// not yet been constructed. class BlockBuilder { public: BlockBuilder(MlirBlock block, MlirOperation terminator, bool isReturn) @@ -86,15 +96,39 @@ private: bool isReturn; }; +/// Builds a kernel call step by step. +class KernelCallBuilder { +public: + KernelCallBuilder(MlirContext context, MlirLocation loc, + llvm::StringRef kernelName, + const c10::FunctionSchema &schema); + void addOperand(MlirValue operand); + void addResultType(MlirType resultType); + MlirOperation create(); + +protected: + MlirContext context; + MlirLocation loc; + +private: + void addSchemaAttrs(); + OperationStateHolder state; + llvm::StringRef kernelName; + const c10::FunctionSchema &schema; +}; + /// Wraps a 'func' MlirOperation and provides facilities for constructing /// IR from some stream of Torch operations. class FuncBuilder { public: + /// Callback for inserting a function. + using Inserter = std::function; + /// Creates a new func op with the given characteristics. The created /// operation is not attached. The caller must either destroy it or add it /// to a parent. static std::unique_ptr - createFunction(MlirContext context, MlirLocation location, + createFunction(Inserter &inserter, MlirLocation location, llvm::StringRef name, llvm::SmallVectorImpl &inputTypes); diff --git a/frontends/pytorch/csrc/builder/graph_importer.cpp b/frontends/pytorch/csrc/builder/graph_importer.cpp new file mode 100644 index 000000000..94dc47b97 --- /dev/null +++ b/frontends/pytorch/csrc/builder/graph_importer.cpp @@ -0,0 +1,303 @@ +//===- graph_importer.cpp -------------------------------------------------===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. 
+// +//===----------------------------------------------------------------------===// + +#include "graph_importer.h" + +#include "mlir-c/Diagnostics.h" +#include "mlir-c/StandardAttributes.h" +#include "mlir-c/StandardTypes.h" + +namespace py = pybind11; +using namespace torch_mlir; + +//------------------------------------------------------------------------------ +// GraphImporter::NodeScope implementation +//------------------------------------------------------------------------------ + +// A scope of Graph Value * to corresponding MlirValue. Scopes nest +// region-wise. Note that in PyTorch, the thing called 'Block' is analogous +// to a capturing MLIR region. +class GraphImporter::NodeScope {
public: + NodeScope() = default; + NodeScope(NodeScope *prev) : prev(prev) {} + + void bindValue(torch::jit::Value *torchValue, MlirValue value); + MlirValue findValue(torch::jit::Value *torchValue); + MlirValue findRequiredValue(MlirLocation loc, torch::jit::Value *torchValue); + +private: + llvm::DenseMap valueMap; + NodeScope *prev = nullptr; +}; + +void GraphImporter::NodeScope::bindValue(torch::jit::Value *torchValue, + MlirValue value) { + assert(valueMap.count(torchValue) == 0 && "duplicate torch Value bind"); + valueMap[torchValue] = value; +} + +MlirValue GraphImporter::NodeScope::findValue(torch::jit::Value *torchValue) { + auto foundIt = valueMap.find(torchValue); + if (foundIt == valueMap.end()) { + if (prev) + return prev->findValue(torchValue); + else + return {nullptr}; + } + return foundIt->second; +} + +MlirValue +GraphImporter::NodeScope::findRequiredValue(MlirLocation loc, + torch::jit::Value *torchValue) { + MlirValue value = findValue(torchValue); + if (mlirValueIsNull(value)) { + std::stringstream msg; + msg << "internal error: unmapped torch value: %" << torchValue->debugName(); + mlirEmitError(loc, msg.str().c_str()); + throw mlir_diagnostic_emitted(); + } + return value; +} + 
+//------------------------------------------------------------------------------ +// GraphImporter::NodeImporter implementation +//------------------------------------------------------------------------------ + +/// Helper class to import a torch::jit::Node into an MLIR function. +/// This class primarily exists to eliminate the need for large lists of +/// carried arguments related to doing the import. +class GraphImporter::NodeImporter { +public: + NodeImporter(torch::jit::Node *node, GraphImporter &parent, + FuncBuilder *funcBuilder, MlirBlock block, MlirOperation ip, + NodeScope *scope); + + void importNode(); + void importReturnOp(); + +private: + MlirContext context() { return parent.context(); } + void importPrimNode(); + MlirAttribute importValueAttribute(); + + torch::jit::Node *node; + GraphImporter &parent; + FuncBuilder *funcBuilder; + MlirBlock block; + MlirOperation ip; + NodeScope *scope; + MlirLocation loc; +}; + +GraphImporter::NodeImporter::NodeImporter(torch::jit::Node *node, + GraphImporter &parent, + FuncBuilder *funcBuilder, + MlirBlock block, MlirOperation ip, + NodeScope *scope) + : node(node), parent(parent), funcBuilder(funcBuilder), block(block), + ip(ip), scope(scope) { + loc = parent.extractCallstackLoc(node); +} + +void GraphImporter::NodeImporter::importNode() { + // Prim namespace handled specially. + auto kind = node->kind(); + if (kind.ns() == c10::namespaces::prim) { + importPrimNode(); + return; + } + + // Generic import. 
+ auto funcSchema = node->maybeSchema(); + if (funcSchema) { + KernelCallBuilder kcb(context(), loc, kind.toQualString(), *funcSchema); + for (auto *input : node->inputs()) { + kcb.addOperand(scope->findRequiredValue(loc, input)); + } + for (auto *output : node->outputs()) { + MlirType type = + parent.type_mapper().mapFromTorchType(loc, output->type()); + if (mlirTypeIsNull(type)) { + throw mlir_diagnostic_emitted(); + } + kcb.addResultType(type); + } + MlirOperation op = kcb.create(); + mlirBlockInsertOwnedOperationBefore(block, ip, op); + + // Map results. + for (auto it : llvm::enumerate(node->outputs())) { + scope->bindValue(it.value(), mlirOperationGetResult(op, it.index())); + } + return; + } + + // No soup for you. Not exactly sure when this can happen. + { + std::stringstream msg; + msg << "unhandled: generic operation " << kind.toDisplayString(); + mlirEmitError(loc, msg.str().c_str()); + throw mlir_diagnostic_emitted(); + } +} + +void GraphImporter::NodeImporter::importReturnOp() { + OperationStateHolder s("std.return", loc); + llvm::SmallVector returnsValues; + for (auto *input : node->inputs()) { + returnsValues.push_back(scope->findRequiredValue(loc, input)); + } + mlirOperationStateAddOperands(s, returnsValues.size(), returnsValues.data()); + mlirBlockInsertOwnedOperationBefore(block, ip, s.createOperation()); +} + +void GraphImporter::NodeImporter::importPrimNode() { + auto kind = node->kind(); + if (kind == c10::prim::Constant) { + auto output = node->output(); + MlirAttribute valueAttr = importValueAttribute(); + MlirValue constValue = funcBuilder->getGeneralConstant(loc, valueAttr); + scope->bindValue(output, constValue); + return; + } + + // Unhandled. 
+ { + std::stringstream msg; + msg << "unhandled: prim operation " << kind.toDisplayString(); + mlirEmitError(loc, msg.str().c_str()); + throw mlir_diagnostic_emitted(); + } +} + +MlirAttribute GraphImporter::NodeImporter::importValueAttribute() { + using torch::jit::AttributeKind; + auto s = c10::attr::value; + auto kind = node->kindOf(s); + switch (kind) { + case AttributeKind::i: + // TODO: This should be a signed int once we have a constant op that can + // do that. + return mlirIntegerAttrGet(mlirIntegerTypeGet(context(), 64), node->i(s)); + break; + case AttributeKind::f: + return mlirFloatAttrDoubleGet(context(), mlirF64TypeGet(context()), + node->f(s)); + break; + + default: { + std::stringstream msg; + msg << "unhandled: value attribute kind " << toString(kind); + mlirEmitError(loc, msg.str().c_str()); + throw mlir_diagnostic_emitted(); + } + } +} + +//------------------------------------------------------------------------------ +// GraphImporter implementation +//------------------------------------------------------------------------------ + +GraphImporter::GraphImporter(std::shared_ptr graph, + MlirMappingOptions mappingOptions) + : graph(std::move(graph)), mappingOptions(std::move(mappingOptions)) {} + +std::shared_ptr GraphImporter::forPythonJitFunc( + torch::jit::Function *function, + GraphImporter::MlirMappingOptions mappingOptions) { + // Disallow an attempt to compile a native function. 
+ if (!function->isGraphFunction()) { + throw std::invalid_argument( + "Expected a torch.jit.ScriptFunction with a graph"); + } + auto graph = function->graph(); + if (!mappingOptions.genericFuncName) { + mappingOptions.genericFuncName = function->name() + "$generic"; + } + if (!mappingOptions.funcName) { + mappingOptions.funcName = function->name() + "$generic"; + } + return std::make_shared(graph, std::move(mappingOptions)); +} + +void GraphImporter::initialize() { + defaultLoc = mlirLocationUnknownGet(context()); + // There is not a callstack associated with the graph so, try to grab + // a location from the first node that has one as a better than nothing + // thing. + // TODO: This doesn't actually seem to be working. Investigate when more + // examples are built out. + for (auto *node : graph->nodes()) { + MlirLocation nodeLoc = extractCallstackLoc(node, /*useDefault=*/false); + if (nodeLoc.ptr) { + defaultLoc = nodeLoc; + break; + } + } + + // Map inputs. + MlirLocation inputLoc = extractCallstackLoc(graph->param_node()); + for (const auto &input : graph->inputs()) { + MlirType t = type_mapper().mapFromTorchType(inputLoc, input->type()); + if (mlirTypeIsNull(t)) + throw mlir_diagnostic_emitted("could not convert function input type"); + genericFuncArgTypes.push_back(t); + } + + // Map outputs. + MlirLocation outputLoc = extractCallstackLoc(graph->return_node()); + for (const auto &output : graph->outputs()) { + MlirType t = type_mapper().mapFromTorchType(outputLoc, output->type()); + if (mlirTypeIsNull(t)) + throw mlir_diagnostic_emitted("could not convert function output type"); + genericFuncReturnTypes.push_back(t); + } +} + +void GraphImporter::importGenericFunc() { + auto funcBuilder = FuncBuilder::createFunction( + mappingOptions.inserter, defaultLoc, *mappingOptions.genericFuncName, + genericFuncArgTypes); + funcBuilder->rewriteFuncReturnTypes(genericFuncReturnTypes); + MlirBlock entryBlock = funcBuilder->getEntryBlock(); + + // Bind inputs. 
+ NodeScope scope; + for (const auto &it : llvm::enumerate(graph->inputs())) { + MlirValue value = mlirBlockGetArgument(entryBlock, it.index()); + scope.bindValue(it.value(), value); + } + + // Walk body nodes. + for (auto *node : graph->nodes()) { + NodeImporter importer{ + node, *this, funcBuilder.get(), entryBlock, /*ip=*/{nullptr}, &scope}; + importer.importNode(); + } + + // Map the output node to a return. + auto *outputNode = graph->return_node(); + NodeImporter returnImporter{outputNode, *this, + funcBuilder.get(), entryBlock, + /*ip=*/{nullptr}, &scope}; + returnImporter.importReturnOp(); +} + +MlirLocation GraphImporter::extractCallstackLoc(torch::jit::Node *node, + bool useDefault) { + auto flc = node->sourceRange().file_line_col(); + if (flc) { + const std::string &file = std::get<0>(*flc); + int line = std::get<1>(*flc); + int col = std::get<2>(*flc); + return mlirLocationFileLineColGet(context(), file.c_str(), line, col); + } + + return useDefault ? defaultLoc : MlirLocation{nullptr}; +} diff --git a/frontends/pytorch/csrc/builder/graph_importer.h b/frontends/pytorch/csrc/builder/graph_importer.h new file mode 100644 index 000000000..02e738369 --- /dev/null +++ b/frontends/pytorch/csrc/builder/graph_importer.h @@ -0,0 +1,91 @@ +//===- graph_importer.h -----------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See frontends/pytorch/LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H +#define NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H + +#include + +#include "../pybind.h" +#include "func_builder.h" + +#include "mlir-c/IR.h" + +#include +#include + +namespace torch_mlir { + +/// Main entry-point for importing torch::jit::Graph instances (and structures +/// surrounding them such as modules and methods). +/// +/// In torch terminology, a Graph is a function. 
Later in the compiler, we may +/// specialize multiple versions of it. +/// +/// Since graph functions typically have enough annotations for the most +/// generic form of every type (i.e. Tensor, List, etc), and since we often +/// want to multi-version specific specializations, we take the approach of +/// generating a '$generic' suffixed function at that level and then generate +/// the actual named function with using a 'numpy.generic_call' op to invoke +/// the generic function with metadata controlling how it is legal to +/// specialize. This leaves the process of inlining and expanding the +/// specializations to compiler passes. +class GraphImporter : public std::enable_shared_from_this { +public: + /// Options for mapping Graph concepts to MLIR. In addition to things like + /// names and type mappings, this includes various policy options such as + /// when to import globals as constants vs shared arrays, etc. + struct MlirMappingOptions { + MlirContext context; + llvm::Optional genericFuncName; + llvm::Optional funcName; + TypeMapper &typeMapper; + FuncBuilder::Inserter &inserter; + }; + /// Construct an importer. + GraphImporter(std::shared_ptr graph, + MlirMappingOptions mappingOptions); + + /// Helper to create a graph importer from a traced/scripted python function. + /// If the funcName of the mapping options is not set, it is set from the + /// function name. It is the responsibility of the caller to ensure that the + /// funcObj and associated graph outlives this instance. + static std::shared_ptr + forPythonJitFunc(torch::jit::Function *function, + MlirMappingOptions mappingOptions); + + /// Initialize for import. This is separate from the constructor purely for + /// ergonomics and must be called post-construction. Initialization activities + /// that throw go here. + void initialize(); + + /// Imports the generic function into the module. 
+ void importGenericFunc(); + +private: + class NodeScope; + class NodeImporter; + + MlirContext context() { return mappingOptions.context; } + TypeMapper &type_mapper() { return mappingOptions.typeMapper; } + MlirLocation extractCallstackLoc(torch::jit::Node *node, + bool useDefault = true); + std::shared_ptr graph; + MlirMappingOptions mappingOptions; + + /// Default function location, to be used when a more specific is not + /// available. + MlirLocation defaultLoc; + + /// Argument and return types for the generic func. + llvm::SmallVector genericFuncArgTypes; + llvm::SmallVector genericFuncReturnTypes; +}; + +} // namespace torch_mlir + +#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H diff --git a/frontends/pytorch/csrc/builder/module_builder.cpp b/frontends/pytorch/csrc/builder/module_builder.cpp index f3dbe3afc..2114b379b 100644 --- a/frontends/pytorch/csrc/builder/module_builder.cpp +++ b/frontends/pytorch/csrc/builder/module_builder.cpp @@ -7,6 +7,8 @@ #include "module_builder.h" +#include "graph_importer.h" + #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Registration.h" #include "mlir-c/StandardAttributes.h" @@ -78,11 +80,9 @@ ModuleBuilder::startCaptureFunction(std::string &name, } // TODO: Extract a traceback and use in place of unknownLoc. + auto inserter = createInserter(); auto funcBuilder = - FuncBuilder::createFunction(context, unknownLoc, name, inputTypes); - mlirBlockInsertOwnedOperationBefore(getBodyBlock(), terminator, - funcBuilder->getFuncOp()); - + FuncBuilder::createFunction(inserter, unknownLoc, name, inputTypes); // Map block arguments. 
MlirBlock entryBlock = funcBuilder->getEntryBlock(); assert(mlirBlockGetNumArguments(entryBlock) == @@ -95,6 +95,28 @@ ModuleBuilder::startCaptureFunction(std::string &name, return std::make_shared(typeMapper, std::move(funcBuilder)); } +void ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) { + auto inserter = createInserter(); + GraphImporter::MlirMappingOptions mappingOptions{ + context, + llvm::None, // genericFuncName (default to auto) + llvm::None, // funcName (default to auto) + typeMapper, inserter, + }; + auto graphImporter = GraphImporter::forPythonJitFunc( + function.function_, std::move(mappingOptions)); + graphImporter->initialize(); + graphImporter->importGenericFunc(); +} + +FuncBuilder::Inserter ModuleBuilder::createInserter() { + MlirBlock block = getBodyBlock(); + MlirOperation terminator = this->terminator; + return [=](MlirOperation op) { + mlirBlockInsertOwnedOperationBefore(block, terminator, op); + }; +} + MlirBlock ModuleBuilder::getBodyBlock() { MlirOperation moduleOp = mlirModuleGetOperation(module); return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0)); @@ -106,5 +128,6 @@ void ModuleBuilder::bind(py::module &m) { .def_property_readonly("context", &ModuleBuilder::getContextObj) .def_property_readonly("module", &ModuleBuilder::getModuleObj) .def("capture_function", &ModuleBuilder::startCaptureFunction, - py::keep_alive<0, 1>()); + py::keep_alive<0, 1>()) + .def("import_function", &ModuleBuilder::importFunction); } diff --git a/frontends/pytorch/csrc/builder/module_builder.h b/frontends/pytorch/csrc/builder/module_builder.h index 1b6d927d9..56bbcf935 100644 --- a/frontends/pytorch/csrc/builder/module_builder.h +++ b/frontends/pytorch/csrc/builder/module_builder.h @@ -16,6 +16,8 @@ #include "llvm/ADT/SmallVector.h" #include +#include +#include namespace torch_mlir { @@ -32,11 +34,16 @@ public: pybind11::object getModuleObj() { return moduleObj; } // Starts a device-capture based function. 
- // TODO: Add inputs. std::shared_ptr startCaptureFunction(std::string &name, std::vector args); + // Imports a traced function. Note that the python type + // torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr. + // Just a bit of naming cruft. + void importFunction(torch::jit::StrongFunctionPtr function); + private: + FuncBuilder::Inserter createInserter(); MlirBlock getBodyBlock(); // Capture references to the python-owned context and module. Ownership diff --git a/frontends/pytorch/csrc/builder/python_bindings.cpp b/frontends/pytorch/csrc/builder/python_bindings.cpp index 7aa2729a0..1d994ab15 100644 --- a/frontends/pytorch/csrc/builder/python_bindings.cpp +++ b/frontends/pytorch/csrc/builder/python_bindings.cpp @@ -140,4 +140,3 @@ void torch_mlir::InitBuilderBindings(py::module &m) { ModuleBuilder::bind(m); } - diff --git a/frontends/pytorch/csrc/pybind.h b/frontends/pytorch/csrc/pybind.h index 1d8143171..00940773e 100644 --- a/frontends/pytorch/csrc/pybind.h +++ b/frontends/pytorch/csrc/pybind.h @@ -14,4 +14,15 @@ #include +namespace torch_mlir { + +/// Thrown on failure when details are in MLIR emitted diagnostics. +class mlir_diagnostic_emitted : public std::runtime_error { +public: + mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {} + mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {} +}; + +} // namespace torch_mlir + #endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_PYBIND_H diff --git a/frontends/pytorch/test/acap_export/test_export_add3.py b/frontends/pytorch/test/acap_export/test_export_add3.py index 6aa471430..22500931b 100644 --- a/frontends/pytorch/test/acap_export/test_export_add3.py +++ b/frontends/pytorch/test/acap_export/test_export_add3.py @@ -2,6 +2,8 @@ # This file is licensed under a pytorch-style license # See frontends/pytorch/LICENSE for license information. 
+ + import torch import torch_mlir diff --git a/frontends/pytorch/test/graph_export/test_script_add3.py b/frontends/pytorch/test/graph_export/test_script_add3.py new file mode 100644 index 000000000..c490df9d3 --- /dev/null +++ b/frontends/pytorch/test/graph_export/test_script_add3.py @@ -0,0 +1,25 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import torch_mlir + +# RUN: %PYTHON %s | npcomp-opt | FileCheck %s + +@torch.jit.script +def add3(t0, t1, t2): + return t0 + t1 + t2 + +mb = torch_mlir.ModuleBuilder() +mb.import_function(add3) + +# Verify without debug info. +# CHECK-LABEL: func @add3$generic +# CHECK-SAME: (%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> { +# CHECK: %[[C1:.*]] = constant 1 : i64 +# CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]} +# CHECK: %[[A1:.*]] = torch.kernel_call "aten::add" %[[A0]], %arg2, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]} +# CHECK: return %[[A1]] : !numpy.ndarray<*:!numpy.any_dtype> +mb.module.operation.print() +print() diff --git a/frontends/pytorch/test/graph_export/test_script_debug_info.py b/frontends/pytorch/test/graph_export/test_script_debug_info.py new file mode 100644 index 000000000..f9a4e2a4e --- /dev/null +++ b/frontends/pytorch/test/graph_export/test_script_debug_info.py @@ -0,0 +1,26 @@ +# -*- Python 
-*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import torch_mlir + +# RUN: %PYTHON %s | FileCheck %s + +@torch.jit.script +def add3(t0, t1, t2): + intermediate = t0 + t1 + final = intermediate + t2 + return final + +mb = torch_mlir.ModuleBuilder() +mb.import_function(add3) + +# Verify again with debug info present. Just checking that it makes it in there. +# CHECK-LABEL: func @add3$generic +# CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py +# CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py +# CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py +# CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py +mb.module.operation.print(enable_debug_info=True) +print() diff --git a/include/npcomp-c/Types.h b/include/npcomp-c/Types.h index 3508e3509..a89c740da 100644 --- a/include/npcomp-c/Types.h +++ b/include/npcomp-c/Types.h @@ -64,6 +64,9 @@ MlirType npcompListTypeGet(MlirContext context); /** Checks whether the given type is an NdArray type. */ int npcompTypeIsANdArray(MlirType t); +/** Gets a numpy.NdArray type that is unranked. */ +MlirType npcompNdArrayTypeGetUnranked(MlirType elementType); + /** Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are * unknown. */ MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape, diff --git a/lib/CAPI/Types.cpp b/lib/CAPI/Types.cpp index fbacfb4a7..357b44f94 100644 --- a/lib/CAPI/Types.cpp +++ b/lib/CAPI/Types.cpp @@ -70,6 +70,10 @@ int npcompTypeIsANdArray(MlirType t) { return unwrap(t).isa(); } +MlirType npcompNdArrayTypeGetUnranked(MlirType elementType) { + return wrap(Numpy::NdArrayType::get(unwrap(elementType))); +} + MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape, MlirType elementType) { llvm::ArrayRef shapeArray(shape, rank);