mirror of https://github.com/llvm/torch-mlir
516 lines
18 KiB
C++
516 lines
18 KiB
C++
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "npcomp/RefBackend/Runtime/UserAPI.h"
|
|
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
|
|
|
|
#include "CompilerDataStructures.h"
|
|
|
|
using namespace refbackrt;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Memref descriptors for interacting with MLIR codegenerated code.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
// These definitions are based on the ones in
|
|
// `mlir/ExecutionEngine/CRunnerUtils.h` and the layouts need to be kept in
|
|
// sync.
|
|
//
|
|
// Those definitions are flawed though because they are overly templated.
|
|
struct MemrefDescriptor {
|
|
void *allocatedPtr;
|
|
void *dataPtr;
|
|
std::int64_t offset;
|
|
// Tail-allocated int64_t sizes followed by strides.
|
|
MutableArrayRef<std::int64_t> getSizes(int assumedRank) {
|
|
auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
|
|
return MutableArrayRef<std::int64_t>(tail, assumedRank);
|
|
}
|
|
MutableArrayRef<std::int64_t> getStrides(int assumedRank) {
|
|
auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
|
|
return MutableArrayRef<std::int64_t>(tail + assumedRank, assumedRank);
|
|
}
|
|
|
|
// Returns a malloc-allocated MemrefDescriptor with the specified extents and
|
|
// default striding.
|
|
static MemrefDescriptor *create(ArrayRef<std::int32_t> extents, void *data);
|
|
|
|
// Returns the number of elements in this MemrefDescriptor, assuming this
|
|
// descriptor has rank `assumedRank`.
|
|
std::int32_t getNumElements(int assumedRank) {
|
|
if (assumedRank == 0)
|
|
return 1;
|
|
return getSizes(assumedRank)[0] * getStrides(assumedRank)[0];
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
struct UnrankedMemref {
|
|
int64_t rank;
|
|
MemrefDescriptor *descriptor;
|
|
};
|
|
} // namespace
|
|
|
|
// Allocates (with malloc) a descriptor for `data` whose sizes equal `extents`
// and whose strides describe a dense, row-major layout. The caller owns the
// returned descriptor and must release it with free().
MemrefDescriptor *MemrefDescriptor::create(ArrayRef<std::int32_t> extents,
                                           void *data) {
  const auto rank = extents.size();
  // Header plus tail storage for `rank` sizes followed by `rank` strides.
  auto *descriptor = static_cast<MemrefDescriptor *>(
      std::malloc(sizeof(MemrefDescriptor) + 2 * rank * sizeof(std::int64_t)));
  descriptor->allocatedPtr = data;
  descriptor->dataPtr = data;
  descriptor->offset = 0;

  MutableArrayRef<std::int64_t> sizes = descriptor->getSizes(rank);
  MutableArrayRef<std::int64_t> strides = descriptor->getStrides(rank);
  // Walk from the innermost dimension outward: each dimension's stride is the
  // product of all extents to its right.
  std::int64_t runningStride = 1;
  for (int dim = static_cast<int>(rank) - 1; dim >= 0; --dim) {
    sizes[dim] = extents[dim];
    strides[dim] = runningStride;
    runningStride *= extents[dim];
  }
  return descriptor;
}
|
|
|
|
// Deep-copies `tensor`'s payload into a fresh malloc'ed buffer and wraps it in
// a dense memref descriptor. Both the buffer and the descriptor are owned by
// the caller (released at the end of `invoke`).
static UnrankedMemref convertRefbackrtTensorToUnrankedMemref(Tensor *tensor) {
  const auto numBytes = tensor->getDataByteSize();
  void *buffer = std::malloc(numBytes);
  std::memcpy(buffer, tensor->getData(), numBytes);

  UnrankedMemref result;
  result.descriptor = MemrefDescriptor::create(tensor->getExtents(), buffer);
  result.rank = tensor->getRank();
  return result;
}
|
|
|
|
// Builds a new refbackrt::Tensor (refcount 0) from a memref returned by
// generated code. Tensor::createRaw deep-copies the payload, so the memref's
// buffer and descriptor remain owned by the caller.
static Tensor *convertUnrankedMemrefToRefbackrtTensor(
    std::int64_t rank, MemrefDescriptor *descriptor, ElementType elementType) {
  // Launder from std::int64_t to std::int32_t through a fixed-size stack
  // buffer; larger ranks would overflow it.
  constexpr int kMaxRank = 20;
  assert(rank >= 0 && rank <= kMaxRank &&
         "memref rank exceeds the extent staging buffer");
  auto extents64 = descriptor->getSizes(rank);
  std::array<std::int32_t, kMaxRank> extents32Buf;
  for (int i = 0, e = extents64.size(); i < e; i++)
    extents32Buf[i] = extents64[i];
  // NOTE(review): only `dataPtr` is read; `descriptor->offset` and the strides
  // are ignored, so this assumes a dense layout with zero offset.
  return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
                           elementType, descriptor->dataPtr);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Tensor
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns the product of all extents (1 for a rank-0 shape). The product is
// accumulated in int32, so very large shapes can overflow.
static std::int32_t totalElements(ArrayRef<std::int32_t> extents) {
  std::int32_t product = 1;
  int dim = extents.size();
  while (dim-- > 0)
    product *= extents[dim];
  return product;
}
|
|
|
|
// Size in bytes of one element of `type`; NONE carries no storage.
std::int32_t refbackrt::getElementTypeByteSize(ElementType type) {
  switch (type) {
  case ElementType::F32:
    return 4;
  case ElementType::NONE:
    return 0;
  }
  llvm_unreachable("unsupported dtype");
}
|
|
|
|
// Returns the enumerator's spelling (e.g. "F32") for diagnostics/printing.
StringRef refbackrt::getElementTypeAsStringRef(ElementType type) {
  switch (type) {
  case ElementType::F32:
    return "F32";
  case ElementType::NONE:
    return "NONE";
  }
  llvm_unreachable("unsupported element type string");
}
|
|
|
|
// Returns the enumerator's spelling (e.g. "kTensor") for diagnostics/printing.
StringRef refbackrt::getArgTypeAsStringRef(ArgType type) {
  switch (type) {
  case ArgType::kTensor:
    return "kTensor";
  case ArgType::kF32:
    return "kF32";
  case ArgType::kF64:
    return "kF64";
  case ArgType::kNone:
    return "kNone";
  }
  llvm_unreachable("unsupported arg type string");
}
|
|
|
|
// Ref-counted counterpart of createRaw: builds the raw tensor (which copies
// `data`) and immediately hands it to a Ref.
Ref<Tensor> Tensor::create(ArrayRef<std::int32_t> extents, ElementType type,
                           void *data) {
  Tensor *rawTensor = createRaw(extents, type, data);
  return Ref<Tensor>(rawTensor);
}
|
|
|
|
// Allocates a Tensor with refcount 0 and copies `extents` and the payload at
// `data` (getElementTypeByteSize(type) * totalElements(extents) bytes) into
// storage the tensor owns. The caller keeps ownership of `data`.
Tensor *Tensor::createRaw(ArrayRef<std::int32_t> extents, ElementType type,
                          void *data) {
  // The header allocation is sized to also hold `rank` int32 extents.
  auto *tensor = static_cast<Tensor *>(
      std::malloc(sizeof(Tensor) + extents.size() * sizeof(std::int32_t)));

  tensor->refCount = 0;
  tensor->elementType = type;
  tensor->rank = extents.size();
  auto byteSize = getElementTypeByteSize(type) * totalElements(extents);
  // TODO: Align the buffer.
  tensor->allocatedPtr = std::malloc(byteSize);
  tensor->data = tensor->allocatedPtr;
  // memcpy requires valid pointers even for a zero-byte copy; with an empty
  // payload, malloc(0) and/or the caller's `data` may be null.
  if (byteSize > 0)
    std::memcpy(tensor->data, data, byteSize);
  for (int i = 0, e = extents.size(); i < e; i++)
    tensor->getMutableExtents()[i] = extents[i];
  return tensor;
}
|
|
|
|
std::int32_t Tensor::getDataByteSize() const {
|
|
return getElementTypeByteSize(getElementType()) * totalElements(getExtents());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Module metadata descriptors.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Module operations.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Erases the pointee type (dropping any const) so addresses can be stored in
// the type-erased `void *` argument arrays handed to generated code.
template <typename T> static void *ToVoidPtr(T *ptr) {
  const void *erased = ptr;
  return const_cast<void *>(erased);
}
|
|
// Linear search of the module's function table by name; returns nullptr if no
// function called `name` exists.
static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor,
                                         StringRef name) {
  const int numFuncs = moduleDescriptor->numFuncDescriptors;
  for (int idx = 0; idx < numFuncs; ++idx) {
    FuncDescriptor *candidate = &moduleDescriptor->functionDescriptors[idx];
    StringRef candidateName(candidate->name, candidate->nameLen);
    if (candidateName == name)
      return candidate;
  }
  return nullptr;
}
|
|
|
|
// HACK: A memref can point into statically allocated memory that we can't
// pass to `free`, such as the result of lowering a tensor-valued
// `std.constant` to `std.global_memref`. The LLVM lowering of
// std.global_memref sets the allocated pointer to the magic value
// 0xDEADBEEF, which we sniff for here. This is yet another strong signal
// that memref is really not the right abstraction for ABI's.
static bool isStaticallyAllocatedBuffer(void *allocatedPtr) {
  return reinterpret_cast<std::intptr_t>(allocatedPtr) == 0xDEADBEEF;
}

// Calls `functionName` in `moduleDescriptor` with `inputs`, writing results
// into `outputs` (whose entries must already hold values of the expected
// kinds, e.g. from createRtValueFromOutputArgInfo). Tensor inputs/outputs are
// deep-copied across the boundary; every temporary buffer and descriptor
// created here or returned by generated code is freed before returning.
void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
                       StringRef functionName, ArrayRef<RtValue> inputs,
                       MutableArrayRef<RtValue> outputs) {
  auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
  assert(descriptor && "unknown function name");
  // The staging arrays below have exactly kMaxArity slots, so an arity equal
  // to kMaxArity still fits.
  assert(inputs.size() <= kMaxArity && "number of inputs exceeds kMaxArity");
  assert(outputs.size() <= kMaxArity && "number of outputs exceeds kMaxArity");

  // We haven't committed to using "vector" in this runtime code, so use
  // a fixed-sized array.
  std::array<UnrankedMemref, kMaxArity> inputUnrankedMemrefs;
  std::array<UnrankedMemref, kMaxArity> outputUnrankedMemrefs;
  std::array<void *, kMaxArity * 2> packedInputs;
  std::array<void *, kMaxArity> packedOutputs;

  // Deepcopy the refbackrt::Tensor's into UnrankedMemref's.
  // TODO: Avoid the deep copy. It makes the later lifetime management code
  // more complex though (and maybe impossible given the current abstractions).
  //
  // Create a type-erased list of "packed inputs" to pass to the
  // LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
  // one LLVM/C ABI argument to the underlying function.
  //
  // The ABI lowering on StandardToLLVM conversion side will
  // "explode" the unranked memref descriptors on the underlying function
  // into separate arguments for the rank and pointer-to-descriptor. So a
  // tensor consumes two consecutive slots and a scalar consumes one; slots are
  // assigned with a running counter so that no gap follows a scalar.
  int numPackedInputs = 0;
  for (int i = 0, e = inputs.size(); i < e; i++) {
    if (inputs[i].isTensor()) {
      inputUnrankedMemrefs[i] =
          convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
      packedInputs[numPackedInputs++] =
          ToVoidPtr(&inputUnrankedMemrefs[i].rank);
      packedInputs[numPackedInputs++] =
          ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
    } else if (inputs[i].isScalar()) {
      // NOTE(review): this passes the address of the RtValue itself, which
      // assumes the scalar payload sits at offset 0 of RtValue — confirm.
      packedInputs[numPackedInputs++] = ToVoidPtr(&inputs[i]);
    } else {
      assert(false && "unsupported input RtValue type");
    }
  }

  // Create a type-erased list of "packed output" to pass to the
  // LLVM/C ABI wrapper function.
  //
  // Due to how StandardToLLVM lowering works, each packedOutput pointer
  // corresponds to a single UnrankedMemref (not "exploded").
  for (int i = 0, e = outputs.size(); i < e; i++) {
    if (outputs[i].isTensor()) {
      packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
    } else if (outputs[i].isScalar()) {
      packedOutputs[i] = ToVoidPtr(&outputs[i]);
    } else {
      assert(false && "unsupported output RtValue type");
    }
  }

  // Actually invoke the function!
  descriptor->functionPtr(packedInputs.data(), packedOutputs.data());

  // Copy out the result data into refbackrt::Tensor's.
  // TODO: Avoid needing to make a deep copy.
  for (int i = 0, e = outputs.size(); i < e; i++) {
    // TODO: Have compiler emit the element type in the metadata.
    if (outputs[i].isTensor()) {
      auto elementType = ElementType::F32;
      Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
          outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
          elementType);
      outputs[i] = RtValue(Ref<Tensor>(tensor));
    } else if (outputs[i].isFloat()) {
      outputs[i] = RtValue(*(reinterpret_cast<float *>(packedOutputs[i])));
    }
  }

  // Now, we just need to free all the UnrankedMemref's that we created.
  // This is complicated by the fact that multiple input/output UnrankedMemref's
  // can end up with the same backing buffer (`allocatedPtr`), and we need
  // to avoid double-freeing.
  // Output buffers might alias any other input or output buffer.
  // Input buffers are guaranteed to not alias each other.

  // Free the output buffers. Multiple returned memrefs can point into the
  // same underlying malloc allocation, so each buffer is freed only at its
  // first occurrence among the ref-typed outputs.
  for (int i = 0, e = outputs.size(); i < e; i++) {
    if (!outputs[i].isRef())
      continue;
    void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
    bool bufferNeedsFreeing = !isStaticallyAllocatedBuffer(allocatedPtr);
    for (int j = 0; j < i; j++) {
      // Non-ref outputs never populated outputUnrankedMemrefs[j], so their
      // descriptor pointer is uninitialized and must not be read.
      if (outputs[j].isRef() &&
          allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
        bufferNeedsFreeing = false;
    }
    if (bufferNeedsFreeing)
      std::free(allocatedPtr);
  }

  // Free the input buffers, except those an output aliases (already freed by
  // the loop above).
  for (int i = 0, e = inputs.size(); i < e; i++) {
    if (!inputs[i].isRef())
      continue;
    void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
    bool bufferNeedsFreeing = !isStaticallyAllocatedBuffer(allocatedPtr);
    for (int j = 0, je = outputs.size(); j < je; j++) {
      if (!outputs[j].isRef())
        continue;
      if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
        bufferNeedsFreeing = false;
    }
    if (bufferNeedsFreeing)
      std::free(allocatedPtr);
  }

  // Free the output descriptors.
  for (int i = 0, e = outputs.size(); i < e; i++) {
    if (!outputs[i].isRef())
      continue;
    // The LLVM lowering guarantees that each returned unranked memref
    // descriptor is separately malloc'ed, so no need to do anything special
    // like we had to do for the allocatedPtr's.
    std::free(outputUnrankedMemrefs[i].descriptor);
  }
  // Free the input descriptors.
  for (int i = 0, e = inputs.size(); i < e; i++) {
    if (!inputs[i].isRef())
      continue;
    std::free(inputUnrankedMemrefs[i].descriptor);
  }
}
|
|
|
|
// Converts a compiler-emitted input descriptor into the user-facing
// InputArgInfo: the ABI type determines argType/elementType (only memrefs
// carry an element type, currently always F32), and the rank and extents are
// copied verbatim.
static InputArgInfo
getExternalInputArgInfo(const refbackrt::InputDescriptor &inputDescriptor) {
  InputArgInfo info;

  switch (inputDescriptor.abiType) {
  case ABIArgType::kMemref:
    info.argType = ArgType::kTensor;
    info.elementType = ElementType::F32;
    break;
  case ABIArgType::kF32:
    info.argType = ArgType::kF32;
    info.elementType = ElementType::NONE;
    break;
  case ABIArgType::kF64:
    info.argType = ArgType::kF64;
    info.elementType = ElementType::NONE;
    break;
  case ABIArgType::kNone:
    info.argType = ArgType::kNone;
    info.elementType = ElementType::NONE;
    break;
  }

  info.rank = inputDescriptor.rank;
  for (int dim = 0; dim < inputDescriptor.rank; ++dim)
    info.extents[dim] = inputDescriptor.extents[dim];
  return info;
}
|
|
|
|
// Output-side twin of getExternalInputArgInfo: maps the ABI type to the
// user-facing argType/elementType and copies rank and extents verbatim.
static OutputArgInfo
getExternalOutputArgInfo(const refbackrt::OutputDescriptor &outputDescriptor) {
  OutputArgInfo result;

  switch (outputDescriptor.abiType) {
  case ABIArgType::kMemref:
    result.argType = ArgType::kTensor;
    result.elementType = ElementType::F32;
    break;
  case ABIArgType::kF32:
    result.argType = ArgType::kF32;
    result.elementType = ElementType::NONE;
    break;
  case ABIArgType::kF64:
    result.argType = ArgType::kF64;
    result.elementType = ElementType::NONE;
    break;
  case ABIArgType::kNone:
    result.argType = ArgType::kNone;
    result.elementType = ElementType::NONE;
    break;
  }

  result.rank = outputDescriptor.rank;
  for (int dim = 0; dim < outputDescriptor.rank; ++dim)
    result.extents[dim] = outputDescriptor.extents[dim];
  return result;
}
|
|
|
|
// Fills `outMetadata` with arity and per-argument info for `functionName`.
// Fails (leaving `outMetadata` untouched) if the function is unknown.
LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
                                     StringRef functionName,
                                     FunctionMetadata &outMetadata) {
  FuncDescriptor *funcDescriptor =
      getFuncDescriptor(moduleDescriptor, functionName);
  if (funcDescriptor == nullptr)
    return failure();

  outMetadata.numInputs = funcDescriptor->numInputs;
  outMetadata.numOutputs = funcDescriptor->numOutputs;

  for (int argIdx = 0; argIdx < funcDescriptor->numInputs; ++argIdx)
    outMetadata.inputArgInfos[argIdx] =
        getExternalInputArgInfo(funcDescriptor->inputDescriptors[argIdx]);
  for (int resultIdx = 0; resultIdx < funcDescriptor->numOutputs; ++resultIdx)
    outMetadata.outputArgInfos[resultIdx] =
        getExternalOutputArgInfo(funcDescriptor->outputDescriptors[resultIdx]);

  return success();
}
|
|
|
|
// Verifies that a tensor `value` matches the shape recorded in `info`.
// Non-tensor values and unranked expectations (negative rank) always pass.
// Dynamic dimensions are encoded as a negative extent (-1) and are not
// checked; every static extent, including 0, must match exactly.
LogicalResult refbackrt::checkRtValueShapes(const RtValue &value,
                                            const InputArgInfo &info) {
  if (!value.isTensor())
    return success();

  // Don't bother checking shapes for unranked tensors
  if (info.rank < 0)
    return success();

  auto refTensor = value.toTensor();
  if (refTensor->getRank() != info.rank)
    return failure();

  auto tensorExtents = refTensor->getExtents();
  for (int i = 0; i < info.rank; i++) {
    // `>= 0` (not `> 0`) so that static zero-sized dims are still checked.
    if (info.extents[i] >= 0 && info.extents[i] != tensorExtents[i])
      return failure();
  }
  return success();
}
|
|
|
|
// Verifies that `value`'s kind agrees with `info.argType`, and for ref-counted
// values that their contents are supported (currently only f32 tensors).
LogicalResult refbackrt::checkRtValueArgTypes(const RtValue &value,
                                              const InputArgInfo &info) {
  const bool tensorMismatch =
      value.isTensor() && info.argType != ArgType::kTensor;
  const bool floatMismatch = value.isFloat() && info.argType != ArgType::kF32;
  if (tensorMismatch || floatMismatch)
    return failure();

  if (!value.isRef())
    return success();

  // Ref-counted values need type-specific checks; tensors are the only kind
  // handled so far.
  if (!value.isTensor()) {
    assert(false && "Unsupported input type checking for Ref type");
    return success();
  }
  if (value.toTensor()->getElementType() != ElementType::F32)
    return failure();
  return success();
}
|
|
|
|
// Builds a placeholder RtValue of the kind described by `info`, suitable for
// passing as an output slot to `invoke`. Tensor outputs are zero-filled, with
// non-positive (dynamic) extents replaced by kDynamicConstantShape.
RtValue refbackrt::createRtValueFromOutputArgInfo(const OutputArgInfo &info) {
  constexpr int32_t kDynamicConstantShape = 100;
  switch (info.argType) {
  case ArgType::kTensor: {
    // HACK: for dynamic dims the shape will be negative, so for now we are
    // just going to create a tensor of size kDynamicConstantShape
    assert(info.rank <= kMaxRank && "output rank exceeds kMaxRank");
    std::array<int32_t, kMaxRank> tensorShape;
    for (int i = 0; i < info.rank; i++) {
      tensorShape[i] =
          info.extents[i] > 0 ? info.extents[i] : kDynamicConstantShape;
    }
    refbackrt::ArrayRef<int32_t> shape(tensorShape.data(), info.rank);
    int numel = 1;
    for (int i = 0; i < info.rank; i++)
      numel *= shape[i];

    // Zero-filled temporary buffer. Alignment is irrelevant because
    // Tensor::create copies the bytes into its own allocation (and
    // aligned_alloc would additionally require the size to be a multiple of
    // the alignment).
    void *data = nullptr;
    switch (info.elementType) {
    case ElementType::F32:
      data = std::calloc(numel, sizeof(float));
      break;
    default:
      assert(false && "unknown output tensor type");
      return RtValue();
    }

    // The Tensor::create function will malloc and memcpy the data
    // into the Tensor object, so after we need to free our
    // temporary data buffer
    assert(data && "data ptr must exist");
    auto refTensor = Tensor::create(shape, ElementType::F32, data);
    std::free(data);
    return RtValue(refTensor);
  }
  case ArgType::kF32: {
    // Placeholder value; `invoke` passes this slot's address to generated
    // code, which presumably overwrites it with the real result.
    return RtValue(-20.0f);
  }
  default: {
    assert(false && "Don't know how to handle this artType");
    return RtValue();
  }
  }
}
|