//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/RefBackend/Runtime/UserAPI.h" #include #include #include #include #include "CompilerDataStructures.h" using namespace refbackrt; //===----------------------------------------------------------------------===// // Memref descriptors for interacting with MLIR codegenerated code. //===----------------------------------------------------------------------===// namespace { // These definitions are based on the ones in // `mlir/ExecutionEngine/CRunnerUtils.h` and the layouts need to be kept in // sync. // // Those definitions are flawed though because they are overly templated. struct MemrefDescriptor { void *allocatedPtr; void *dataPtr; std::int64_t offset; // Tail-allocated int64_t sizes followed by strides. MutableArrayRef getSizes(int assumedRank) { auto *tail = reinterpret_cast(this + 1); return MutableArrayRef(tail, assumedRank); } MutableArrayRef getStrides(int assumedRank) { auto *tail = reinterpret_cast(this + 1); return MutableArrayRef(tail + assumedRank, assumedRank); } // Returns a malloc-allocated MemrefDescriptor with the specified extents and // default striding. static MemrefDescriptor *create(ArrayRef extents, void *data); // Returns the number of elements in this MemrefDescriptor, assuming this // descriptor has rank `assumedRank`. std::int32_t getNumElements(int assumedRank) { if (assumedRank == 0) return 1; return getSizes(assumedRank)[0] * getStrides(assumedRank)[0]; } }; } // namespace namespace { struct UnrankedMemref { int64_t rank; MemrefDescriptor *descriptor; }; } // namespace MemrefDescriptor *MemrefDescriptor::create(ArrayRef extents, void *data) { auto rank = extents.size(); auto allocSize = sizeof(MemrefDescriptor) + sizeof(std::int64_t) * 2 * rank; auto *descriptor = static_cast(std::malloc(allocSize)); descriptor->allocatedPtr = data; descriptor->dataPtr = data; descriptor->offset = 0; // Iterate in reverse, copying the dimension sizes (i.e. extents) and // calculating the strides for a standard dense layout. std::int64_t stride = 1; for (int i = 0, e = rank; i < e; i++) { auto revIdx = e - i - 1; descriptor->getSizes(rank)[revIdx] = extents[revIdx]; descriptor->getStrides(rank)[revIdx] = stride; stride *= extents[revIdx]; } return descriptor; } static UnrankedMemref convertRefbackrtTensorToUnrankedMemref(Tensor *tensor) { auto byteSize = tensor->getDataByteSize(); void *data = std::malloc(byteSize); std::memcpy(data, tensor->getData(), byteSize); auto *descriptor = MemrefDescriptor::create(tensor->getExtents(), data); return UnrankedMemref{tensor->getRank(), descriptor}; } static Tensor *convertUnrankedMemrefToRefbackrtTensor( std::int64_t rank, MemrefDescriptor *descriptor, ElementType elementType) { // Launder from std::int64_t to std::int32_t. auto extents64 = descriptor->getSizes(rank); constexpr int kMaxRank = 20; std::array extents32Buf; for (int i = 0, e = extents64.size(); i < e; i++) extents32Buf[i] = extents64[i]; return Tensor::createRaw(ArrayRef(extents32Buf.data(), rank), elementType, descriptor->dataPtr); } //===----------------------------------------------------------------------===// // Tensor //===----------------------------------------------------------------------===// static std::int32_t totalElements(ArrayRef extents) { std::int32_t ret = 1; for (int i = 0, e = extents.size(); i < e; i++) { ret *= extents[i]; } return ret; } std::int32_t refbackrt::getElementTypeByteSize(ElementType type) { switch (type) { case ElementType::F32: return 4; } } Ref Tensor::create(ArrayRef extents, ElementType type, void *data) { return Ref(createRaw(extents, type, data)); } Tensor *Tensor::createRaw(ArrayRef extents, ElementType type, void *data) { auto *tensor = static_cast( std::malloc(sizeof(Tensor) + extents.size() * sizeof(std::int32_t))); tensor->refCount = 0; tensor->elementType = type; tensor->rank = extents.size(); auto byteSize = getElementTypeByteSize(type) * totalElements(extents); // TODO: Align the buffer. tensor->allocatedPtr = std::malloc(byteSize); tensor->data = tensor->allocatedPtr; std::memcpy(tensor->data, data, byteSize); for (int i = 0, e = extents.size(); i < e; i++) tensor->getMutableExtents()[i] = extents[i]; return tensor; } std::int32_t Tensor::getDataByteSize() const { return getElementTypeByteSize(getElementType()) * totalElements(getExtents()); } //===----------------------------------------------------------------------===// // Module metadata descriptors. //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // Module operations. //===----------------------------------------------------------------------===// template static void *ToVoidPtr(T *ptr) { return const_cast(static_cast(ptr)); } static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor, StringRef name) { for (int i = 0, e = moduleDescriptor->numFuncDescriptors; i < e; i++) { auto &functionDescriptor = moduleDescriptor->functionDescriptors[i]; if (StringRef(functionDescriptor.name, functionDescriptor.nameLen) == name) { return &functionDescriptor; } } return nullptr; } void refbackrt::invoke(ModuleDescriptor *moduleDescriptor, StringRef functionName, ArrayRef inputs, MutableArrayRef outputs) { auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName); assert(descriptor && "unknown function name"); assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity"); assert(outputs.size() < kMaxArity && "number of outputs exceeds kMaxArity"); // We haven't committed to using "vector" in this runtime code, so use // a fixed-sized array. std::array inputUnrankedMemrefs; std::array outputUnrankedMemrefs; std::array packedInputs; std::array packedOutputs; // Deepcopy the refbackrt::Tensor's into UnrankedMemref's. // TODO: Avoid the deep copy. It makes the later lifetime management code // more complex though (and maybe impossible given the current abstractions). for (int i = 0, e = inputs.size(); i < e; i++) { inputUnrankedMemrefs[i] = convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get()); } // Create a type-erased list of "packed inputs" to pass to the // LLVM/C ABI wrapper function. Each packedInput pointer corresponds to // one LLVM/C ABI argument to the underlying function. // // The ABI lowering on StandardToLLVM conversion side will // "explode" the unranked memref descriptors on the underlying function // into separate arguments for the rank and pointer-to-descriptor. for (int i = 0, e = inputs.size(); i < e; i++) { packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank); packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor); } // Create a type-erased list of "packed output" to pass to the // LLVM/C ABI wrapper function. // // Due to how StandardToLLVM lowering works, each packedOutput pointer // corresponds to a single UnrankedMemref (not "exploded"). for (int i = 0, e = outputs.size(); i < e; i++) { packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]); } // Actually invoke the function! descriptor->functionPtr(packedInputs.data(), packedOutputs.data()); // Copy out the result data into refbackrt::Tensor's. // TODO: Avoid needing to make a deep copy. for (int i = 0, e = outputs.size(); i < e; i++) { // TODO: Have compiler emit the element type in the metadata. auto elementType = ElementType::F32; Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor( outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor, elementType); outputs[i] = RtValue(Ref(tensor)); } // Now, we just need to free all the UnrankedMemref's that we created. // This is complicated by the fact that multiple input/output UnrankedMemref's // can end up with the same backing buffer (`allocatedPtr`), and we need // to avoid double-freeing. // Output buffers might alias any other input or output buffer. // Input buffers are guaranteed to not alias each other. // Free the output buffers. for (int i = 0, e = outputs.size(); i < e; i++) { void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr; // Multiple returned memrefs can point into the same underlying // malloc allocation. Do a linear scan to see if any of the previously // deallocated buffers already freed this pointer. bool bufferNeedsFreeing = true; for (int j = 0; j < i; j++) { if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr) bufferNeedsFreeing = false; } if (!bufferNeedsFreeing) std::free(allocatedPtr); } // Free the input buffers. for (int i = 0, e = inputs.size(); i < e; i++) { void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr; bool bufferNeedsFreeing = true; for (int j = 0, je = outputs.size(); j < je; j++) { if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr) bufferNeedsFreeing = false; } // HACK: The returned memref can point into statically allocated memory that // we can't pass to `free`, such as the result of lowering a tensor-valued // `std.constant` to `std.global_memref`. The LLVM lowering of // std.global_memref sets the allocated pointer to the magic value // 0xDEADBEEF, which we sniff for here. This is yet another strong signal // that memref is really not the right abstraction for ABI's. if (reinterpret_cast(allocatedPtr) == 0xDEADBEEF) bufferNeedsFreeing = false; if (!bufferNeedsFreeing) std::free(allocatedPtr); } // Free the output descriptors. for (int i = 0, e = outputs.size(); i < e; i++) { // The LLVM lowering guarantees that each returned unranked memref // descriptor is separately malloc'ed, so no need to do anything special // like we had to do for the allocatedPtr's. std::free(outputUnrankedMemrefs[i].descriptor); } // Free the input descriptors. for (int i = 0, e = inputs.size(); i < e; i++) { std::free(inputUnrankedMemrefs[i].descriptor); } } LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor, StringRef functionName, FunctionMetadata &outMetadata) { auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName); if (!descriptor) return failure(); outMetadata.numInputs = descriptor->numInputs; outMetadata.numOutputs = descriptor->numOutputs; return success(); }