mirror of https://github.com/llvm/torch-mlir
184 lines
5.5 KiB
C
184 lines
5.5 KiB
C
|
//===----------------------------------------------------------------------===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
//
|
||
|
// This is the public-facing interface for interacting with the npcomp
|
||
|
// runtime.
|
||
|
//
|
||
|
// This functionality is totally firewalled from the compiler codebase, so
|
||
|
// even if things superficially look similar, remember that there are no
|
||
|
// LLVM utilities here, memory allocation should be kept to a minimum, etc.
|
||
|
//
|
||
|
// npcomp/runtime/Support.h provides some minimal LLVM-like support code to keep
|
||
|
// the API familiar.
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#ifndef NPCOMP_RUNTIME_USERAPI_H
|
||
|
#define NPCOMP_RUNTIME_USERAPI_H
|
||
|
|
||
|
#include "npcomp/runtime/Support.h"
|
||
|
#include <atomic>
|
||
|
#include <cstdlib>
|
||
|
|
||
|
namespace npcomprt {
|
||
|
|
||
|
// Reference-counted handle to a type with a `refCount` member.
|
||
|
template <typename T> class Ref {
public:
  // Creates an empty Ref that points at nothing.
  Ref() { ptr = nullptr; }

  // Creates a Ref and increments the refcount by 1.
  // rawPtr must be allocated with std::malloc; the object is destroyed and the
  // memory released with std::free once the last Ref to it goes away.
  // A null rawPtr produces an empty Ref (instead of dereferencing null).
  Ref(T *rawPtr) {
    ptr = rawPtr;
    incref(ptr);
  }

  // Shares ownership of `other`'s object (if any).
  Ref(const Ref &other) {
    ptr = other.ptr;
    incref(ptr);
  }

  // Steals `other`'s reference; `other` is left empty. Never throws, so
  // containers of Ref can move (rather than copy) elements on reallocation.
  Ref(Ref &&other) noexcept { ptr = other.takePtr(); }

  Ref &operator=(const Ref &other) {
    if (&other == this)
      return *this;
    // Acquire the new reference before releasing the old one: if `other`
    // lives inside the object we currently own, releasing first could destroy
    // `other` before we have read from it.
    T *newPtr = other.ptr;
    incref(newPtr);
    T *oldPtr = ptr;
    ptr = newPtr;
    decref(oldPtr);
    return *this;
  }

  Ref &operator=(Ref &&other) noexcept {
    if (&other == this)
      return *this;
    // Detach from `other` before releasing our old object, for the same
    // reason as in copy assignment.
    T *oldPtr = ptr;
    ptr = other.takePtr();
    decref(oldPtr);
    return *this;
  }

  ~Ref() { decref(ptr); }

  T &operator*() const { return *ptr; }
  T *operator->() const { return ptr; }
  T *get() const { return ptr; }

  // Relinquishes this Ref's reference without touching the refcount; the
  // caller becomes responsible for it. This Ref is left empty.
  T *takePtr() {
    auto *ret = ptr;
    ptr = nullptr;
    return ret;
  }

private:
  static void incref(T *ptr) {
    if (!ptr)
      return;
    ptr->refCount += 1;
  }

  // Drops one reference; the last one destroys the object and frees its
  // malloc'ed storage.
  static void decref(T *ptr) {
    if (!ptr)
      return;
    if (ptr->refCount.fetch_sub(1) == 1) {
      ptr->~T();
      std::free(static_cast<void *>(ptr));
    }
  }

  T *ptr;
};
|
||
|
|
||
|
// The available data types.
|
||
|
// Element kinds a Tensor can store; the enum is backed by a 32-bit integer.
enum class ElementType : std::int32_t {
  F32 = 0,
};

// Number of bytes taken by one element of `type` (as the name indicates; the
// definition lives outside this header).
std::int32_t getElementTypeByteSize(ElementType type);
|
||
|
|
||
|
// Representation of a tensor.
|
||
|
class Tensor {
|
||
|
public:
|
||
|
// Due to tail-allocated objects, this struct should never be directly
|
||
|
// constructed.
|
||
|
Tensor() = delete;
|
||
|
|
||
|
// Create a Tensor with the given extents and element type, with a buffer
|
||
|
// holding a copy of `data`.
|
||
|
static Ref<Tensor> create(ArrayRef<std::int32_t> extents,
|
||
|
ElementType elementType, void *data);
|
||
|
// Same as `create`, but returns a raw pointer.
|
||
|
static Tensor *createRaw(ArrayRef<std::int32_t> extents,
|
||
|
ElementType elementType, void *data);
|
||
|
|
||
|
ElementType getElementType() const { return elementType; }
|
||
|
std::int32_t getRank() const { return rank; }
|
||
|
void *getData() const { return data; }
|
||
|
template <typename T> T *getData() const { return static_cast<T *>(data); }
|
||
|
std::int32_t getExtent(int dimension) const {
|
||
|
return getExtents()[dimension];
|
||
|
}
|
||
|
ArrayRef<std::int32_t> getExtents() const {
|
||
|
auto extents = const_cast<Tensor *>(this)->getMutableExtents();
|
||
|
return ArrayRef<std::int32_t>(extents.data(), extents.size());
|
||
|
}
|
||
|
// Returns the number of bytes occupied by the data representing this tensor.
|
||
|
// The total allocated amount might be higher to allow e.g. for alignment
|
||
|
// nudging.
|
||
|
std::int32_t getDataByteSize() const;
|
||
|
~Tensor() { std::free(allocatedPtr); }
|
||
|
|
||
|
private:
|
||
|
MutableArrayRef<std::int32_t> getMutableExtents() {
|
||
|
auto *tail = reinterpret_cast<std::int32_t *>(this + 1);
|
||
|
return MutableArrayRef<std::int32_t>(tail, rank);
|
||
|
}
|
||
|
// Reference count management.
|
||
|
template <typename T> friend class Ref;
|
||
|
std::atomic<int> refCount{0};
|
||
|
|
||
|
ElementType elementType;
|
||
|
// The number of dimensions of this Tensor.
|
||
|
// There are `rank` tail-allocated std::int32_t values representing the
|
||
|
// tensor extents.
|
||
|
std::int32_t rank;
|
||
|
// The buffer base.
|
||
|
void *data;
|
||
|
// The raw pointer returned by the allocator (currently assumed to be
|
||
|
// malloc), suitable for freeing the buffer.
|
||
|
void *allocatedPtr;
|
||
|
|
||
|
// Sizes are tail-allocated.
|
||
|
};
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Module loading.
|
||
|
// This is the main entry point that users interact with.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// Metadata for a particular function.
|
||
|
// TODO: Add arg types.
|
||
|
struct FunctionMetadata {
  // Number of input tensors the function takes. Zero-initialized so that a
  // default-constructed instance (e.g. one left untouched by a failed
  // lookup) never holds indeterminate values.
  std::int32_t numInputs = 0;
  // Number of output tensors the function produces.
  std::int32_t numOutputs = 0;
};
|
||
|
|
||
|
// Opaque forward declaration of module descriptor type. This is the type
|
||
|
// created by the compiler in the module binary.
|
||
|
struct ModuleDescriptor;

// Upper bound on a function's arity: neither its number of inputs nor its
// number of outputs may exceed this value.
static constexpr int kMaxArity = 20;
|
||
|
|
||
|
// Low-level invocation API. The number of inputs and outputs should be correct
|
||
|
// and match the results of getMetadata.
|
||
|
void invoke(ModuleDescriptor *moduleDescriptor, StringRef functionName,
|
||
|
ArrayRef<Ref<Tensor>> inputs, MutableArrayRef<Ref<Tensor>> outputs);
|
||
|
|
||
|
// Metadata for function `functionName`.
|
||
|
//
|
||
|
// Returns failure if functionName wasn't found.
|
||
|
LogicalResult getMetadata(ModuleDescriptor *moduleDescriptor,
|
||
|
StringRef functionName,
|
||
|
FunctionMetadata &outMetadata);
|
||
|
|
||
|
} // namespace npcomprt
|
||
|
|
||
|
#endif // NPCOMP_RUNTIME_USERAPI_H
|