// File: torch-mlir/lib/RefBackend/Runtime/Runtime.cpp (C++, 298 lines, 12 KiB)

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/RefBackend/Runtime/UserAPI.h"
#include "llvm/Support/ErrorHandling.h"
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>
#include "CompilerDataStructures.h"
using namespace refbackrt;
//===----------------------------------------------------------------------===//
// Memref descriptors for interacting with MLIR codegenerated code.
//===----------------------------------------------------------------------===//
namespace {
// These definitions are based on the ones in
// `mlir/ExecutionEngine/CRunnerUtils.h` and the layouts need to be kept in
// sync.
//
// Those definitions are flawed though because they are overly templated.
// Ranked memref descriptor: three fixed fields followed in memory by
// `rank` int64_t sizes and then `rank` int64_t strides. The rank itself is
// not stored here, so every accessor takes the rank the caller assumes.
struct MemrefDescriptor {
  void *allocatedPtr;
  void *dataPtr;
  std::int64_t offset;
  // Tail-allocated int64_t sizes followed by strides.

  // Address of the first tail-allocated int64_t, i.e. the storage that begins
  // immediately after the fixed fields above.
  std::int64_t *getTrailingStorage() {
    return reinterpret_cast<std::int64_t *>(this + 1);
  }

  // View of the `assumedRank` dimension sizes at the start of the tail.
  MutableArrayRef<std::int64_t> getSizes(int assumedRank) {
    return MutableArrayRef<std::int64_t>(getTrailingStorage(), assumedRank);
  }

  // View of the `assumedRank` strides, stored right after the sizes.
  MutableArrayRef<std::int64_t> getStrides(int assumedRank) {
    std::int64_t *firstStride = getTrailingStorage() + assumedRank;
    return MutableArrayRef<std::int64_t>(firstStride, assumedRank);
  }

  // Returns a malloc-allocated MemrefDescriptor with the specified extents and
  // default striding.
  static MemrefDescriptor *create(ArrayRef<std::int32_t> extents, void *data);

  // Returns the number of elements in this MemrefDescriptor, assuming this
  // descriptor has rank `assumedRank`. Rank 0 always yields 1; otherwise the
  // result is the outermost size times the outermost stride (which is the
  // element count for the dense striding that `create` sets up).
  std::int32_t getNumElements(int assumedRank) {
    if (assumedRank != 0) {
      const std::int64_t outerSize = getSizes(assumedRank)[0];
      const std::int64_t outerStride = getStrides(assumedRank)[0];
      return outerSize * outerStride;
    }
    return 1;
  }
};
} // namespace
namespace {
// Pairs a runtime rank with a pointer to the ranked descriptor of that rank.
// `invoke` hands inputs to the compiled function as separate pointers to
// `rank` and `descriptor`, and receives each output as one whole
// UnrankedMemref, so the field order and types here must not change.
struct UnrankedMemref {
// Rank to assume when reading the sizes/strides stored after `descriptor`.
int64_t rank;
// The ranked descriptor this unranked memref refers to.
MemrefDescriptor *descriptor;
};
} // namespace
// Builds a malloc-allocated descriptor whose sizes are `extents` and whose
// strides describe a dense layout with the last dimension contiguous.
// `data` is stored as both the allocated and the data pointer, with offset 0.
// The malloc result is not checked for null.
MemrefDescriptor *MemrefDescriptor::create(ArrayRef<std::int32_t> extents,
                                           void *data) {
  auto rank = extents.size();
  // Fixed header plus two int64_t tail arrays (sizes, then strides).
  auto allocSize = sizeof(MemrefDescriptor) + sizeof(std::int64_t) * 2 * rank;
  auto *result = static_cast<MemrefDescriptor *>(std::malloc(allocSize));
  result->offset = 0;
  result->dataPtr = data;
  result->allocatedPtr = data;

  auto sizes = result->getSizes(rank);
  auto strides = result->getStrides(rank);
  // Walk from the innermost dimension outwards so that each stride is the
  // product of all extents to its right.
  std::int64_t runningStride = 1;
  for (int dim = static_cast<int>(rank) - 1; dim >= 0; --dim) {
    const std::int64_t extent = extents[dim];
    sizes[dim] = extent;
    strides[dim] = runningStride;
    runningStride *= extent;
  }
  return result;
}
// Deep-copies `tensor`'s data into a fresh malloc'ed buffer and wraps that
// copy in a newly created descriptor. Both the buffer and the descriptor are
// malloc-allocated and must be freed by the caller.
static UnrankedMemref convertRefbackrtTensorToUnrankedMemref(Tensor *tensor) {
  const auto numBytes = tensor->getDataByteSize();
  void *bufferCopy = std::malloc(numBytes);
  std::memcpy(bufferCopy, tensor->getData(), numBytes);

  MemrefDescriptor *rankedDescriptor =
      MemrefDescriptor::create(tensor->getExtents(), bufferCopy);
  return UnrankedMemref{tensor->getRank(), rankedDescriptor};
}
// Creates a new Tensor from a ranked memref descriptor of rank `rank`.
//
// The descriptor's int64_t sizes are narrowed to the int32_t extents that
// Tensor uses, and Tensor::createRaw copies the bytes at
// `descriptor->dataPtr` into the tensor's own storage. Neither the descriptor
// nor its buffer is freed here; that remains the caller's job.
//
// `rank` must lie in [0, kMaxRank]; this is asserted because the extents are
// staged in a fixed-size stack buffer.
static Tensor *convertUnrankedMemrefToRefbackrtTensor(
    std::int64_t rank, MemrefDescriptor *descriptor, ElementType elementType) {
  constexpr int kMaxRank = 20;
  // Guard the fixed-size staging buffer below: std::array::operator[] does no
  // bounds checking, so an oversized rank would silently corrupt the stack.
  assert(rank >= 0 && rank <= kMaxRank && "memref rank exceeds kMaxRank");
  // Launder from std::int64_t to std::int32_t.
  auto extents64 = descriptor->getSizes(rank);
  std::array<std::int32_t, kMaxRank> extents32Buf;
  for (int i = 0, e = extents64.size(); i < e; i++)
    extents32Buf[i] = static_cast<std::int32_t>(extents64[i]);
  return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
                           elementType, descriptor->dataPtr);
}
//===----------------------------------------------------------------------===//
// Tensor
//===----------------------------------------------------------------------===//
// Returns the product of all entries of `extents`, multiplied in order from
// the first to the last; an empty list yields 1.
static std::int32_t totalElements(ArrayRef<std::int32_t> extents) {
  const int numDims = extents.size();
  std::int32_t product = 1;
  for (int dim = 0; dim != numDims; ++dim)
    product = product * extents[dim];
  return product;
}
// Returns the size in bytes of one element of `type`. Only F32 is handled;
// any other value falls through to llvm_unreachable.
std::int32_t refbackrt::getElementTypeByteSize(ElementType type) {
  // Deliberately no `default:` label, so the compiler can still warn when a
  // new enumerator is added without a case here.
  switch (type) {
  case ElementType::F32: {
    constexpr std::int32_t kF32ByteSize = 4;
    return kF32ByteSize;
  }
  }
  llvm_unreachable("unsupported dtype");
}
// Same as createRaw, but hands the freshly allocated tensor back wrapped in a
// Ref<Tensor>.
Ref<Tensor> Tensor::create(ArrayRef<std::int32_t> extents, ElementType type,
                           void *data) {
  Tensor *rawTensor = createRaw(extents, type, data);
  return Ref<Tensor>(rawTensor);
}
// Allocates a Tensor with malloc and fills it in: reference count 0, the given
// element type, rank and extents, and a private malloc'ed copy of the bytes at
// `data`. The number of bytes copied is the element byte size times the
// product of `extents`. Neither malloc result is checked for null.
Tensor *Tensor::createRaw(ArrayRef<std::int32_t> extents, ElementType type,
                          void *data) {
  // The allocation leaves room after the Tensor for one int32_t per extent.
  const auto headerBytes =
      sizeof(Tensor) + extents.size() * sizeof(std::int32_t);
  auto *result = static_cast<Tensor *>(std::malloc(headerBytes));
  result->refCount = 0;
  result->elementType = type;
  result->rank = extents.size();

  // Record the shape (done after `rank` has been set).
  auto shape = result->getMutableExtents();
  const int numDims = extents.size();
  for (int dim = 0; dim != numDims; ++dim)
    shape[dim] = extents[dim];

  // Take a private copy of the element data.
  auto payloadBytes = getElementTypeByteSize(type) * totalElements(extents);
  // TODO: Align the buffer.
  result->allocatedPtr = std::malloc(payloadBytes);
  result->data = result->allocatedPtr;
  std::memcpy(result->data, data, payloadBytes);
  return result;
}
std::int32_t Tensor::getDataByteSize() const {
return getElementTypeByteSize(getElementType()) * totalElements(getExtents());
}
//===----------------------------------------------------------------------===//
// Module metadata descriptors.
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Module operations.
//===----------------------------------------------------------------------===//
// Type-erases `ptr` to a plain `void *`, dropping any const qualification on
// the pointee along the way.
template <typename T> static void *ToVoidPtr(T *ptr) {
  const void *erased = ptr;
  return const_cast<void *>(erased);
}
// Linear search through the module's function descriptors for the first one
// whose (name, nameLen) string equals `name`. Returns nullptr if none matches.
static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor,
                                         StringRef name) {
  const int numFuncs = moduleDescriptor->numFuncDescriptors;
  for (int idx = 0; idx != numFuncs; ++idx) {
    auto &candidate = moduleDescriptor->functionDescriptors[idx];
    StringRef candidateName(candidate.name, candidate.nameLen);
    if (candidateName == name)
      return &candidate;
  }
  return nullptr;
}
// Invokes the function called `functionName` in `moduleDescriptor`.
//
// Steps:
//   1. Deep-copy every input tensor into a malloc'ed buffer wrapped in an
//      UnrankedMemref.
//   2. Pack type-erased argument/result pointers and call the compiled
//      function through its descriptor's `functionPtr`.
//   3. Deep-copy each returned memref into a new Tensor stored in `outputs`.
//   4. Free every buffer and descriptor involved in the call exactly once.
//
// Preconditions (checked by assertion only): the function exists, and neither
// `inputs` nor `outputs` has more than kMaxArity entries.
void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
                       StringRef functionName, ArrayRef<RtValue> inputs,
                       MutableArrayRef<RtValue> outputs) {
  auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
  assert(descriptor && "unknown function name");
  // The scratch arrays below hold exactly kMaxArity entries, so an arity equal
  // to kMaxArity is still fine.
  assert(inputs.size() <= kMaxArity && "number of inputs exceeds kMaxArity");
  assert(outputs.size() <= kMaxArity && "number of outputs exceeds kMaxArity");

  // We haven't committed to using "vector" in this runtime code, so use
  // a fixed-sized array.
  std::array<UnrankedMemref, kMaxArity> inputUnrankedMemrefs;
  std::array<UnrankedMemref, kMaxArity> outputUnrankedMemrefs;
  std::array<void *, kMaxArity * 2> packedInputs;
  std::array<void *, kMaxArity> packedOutputs;

  // Deepcopy the refbackrt::Tensor's into UnrankedMemref's.
  // TODO: Avoid the deep copy. It makes the later lifetime management code
  // more complex though (and maybe impossible given the current abstractions).
  for (int i = 0, e = inputs.size(); i < e; i++) {
    inputUnrankedMemrefs[i] =
        convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
  }

  // Create a type-erased list of "packed inputs" to pass to the
  // LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
  // one LLVM/C ABI argument to the underlying function.
  //
  // The ABI lowering on StandardToLLVM conversion side will
  // "explode" the unranked memref descriptors on the underlying function
  // into separate arguments for the rank and pointer-to-descriptor.
  for (int i = 0, e = inputs.size(); i < e; i++) {
    packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
    packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
  }

  // Create a type-erased list of "packed output" to pass to the
  // LLVM/C ABI wrapper function.
  //
  // Due to how StandardToLLVM lowering works, each packedOutput pointer
  // corresponds to a single UnrankedMemref (not "exploded").
  for (int i = 0, e = outputs.size(); i < e; i++) {
    packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
  }

  // Actually invoke the function!
  descriptor->functionPtr(packedInputs.data(), packedOutputs.data());

  // Copy out the result data into refbackrt::Tensor's.
  // TODO: Avoid needing to make a deep copy.
  for (int i = 0, e = outputs.size(); i < e; i++) {
    // TODO: Have compiler emit the element type in the metadata.
    auto elementType = ElementType::F32;
    Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
        outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
        elementType);
    outputs[i] = RtValue(Ref<Tensor>(tensor));
  }

  // Now, we just need to free all the UnrankedMemref's that we created.
  // This is complicated by the fact that multiple input/output UnrankedMemref's
  // can end up with the same backing buffer (`allocatedPtr`), and we need
  // to avoid double-freeing.
  // Output buffers might alias any other input or output buffer.
  // Input buffers are guaranteed to not alias each other.

  // Free the output buffers.
  for (int i = 0, e = outputs.size(); i < e; i++) {
    void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
    // Multiple returned memrefs can point into the same underlying
    // malloc allocation. Do a linear scan to see if any of the previously
    // deallocated buffers already freed this pointer.
    bool bufferNeedsFreeing = true;
    for (int j = 0; j < i; j++) {
      if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
        bufferNeedsFreeing = false;
    }
    // HACK: The returned memref can point into statically allocated memory that
    // we can't pass to `free`, such as the result of lowering a tensor-valued
    // `std.constant` to `std.global_memref`. The LLVM lowering of
    // std.global_memref sets the allocated pointer to the magic value
    // 0xDEADBEEF, which we sniff for here. This is yet another strong signal
    // that memref is really not the right abstraction for ABI's.
    if (reinterpret_cast<std::intptr_t>(allocatedPtr) == 0xDEADBEEF)
      bufferNeedsFreeing = false;
    if (bufferNeedsFreeing)
      std::free(allocatedPtr);
  }

  // Free the input buffers.
  for (int i = 0, e = inputs.size(); i < e; i++) {
    void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
    // An input buffer that was returned as (part of) an output has already
    // been released by the output loop above, so it must be skipped here.
    bool bufferNeedsFreeing = true;
    for (int j = 0, je = outputs.size(); j < je; j++) {
      if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
        bufferNeedsFreeing = false;
    }
    if (bufferNeedsFreeing)
      std::free(allocatedPtr);
  }

  // Free the output descriptors.
  for (int i = 0, e = outputs.size(); i < e; i++) {
    // The LLVM lowering guarantees that each returned unranked memref
    // descriptor is separately malloc'ed, so no need to do anything special
    // like we had to do for the allocatedPtr's.
    std::free(outputUnrankedMemrefs[i].descriptor);
  }

  // Free the input descriptors.
  for (int i = 0, e = inputs.size(); i < e; i++) {
    std::free(inputUnrankedMemrefs[i].descriptor);
  }
}
// Looks up `functionName` in `moduleDescriptor` and, when found, copies its
// input and output counts into `outMetadata` and returns success(). Returns
// failure() and leaves `outMetadata` untouched when no such function exists.
LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
                                     StringRef functionName,
                                     FunctionMetadata &outMetadata) {
  if (auto *func = getFuncDescriptor(moduleDescriptor, functionName)) {
    outMetadata.numInputs = func->numInputs;
    outMetadata.numOutputs = func->numOutputs;
    return success();
  }
  return failure();
}