[refbackrt] Update Invoke API to support more than just Tensor's (#181)

pull/186/head
Bryce Arden 2021-03-10 17:39:26 -06:00 committed by GitHub
parent 8f9d4f917d
commit e7a8fd76e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 187 additions and 26 deletions

View File

@ -40,9 +40,9 @@ public:
fromCompiledModule(mlir::ModuleOp module, fromCompiledModule(mlir::ModuleOp module,
llvm::ArrayRef<llvm::StringRef> sharedLibs); llvm::ArrayRef<llvm::StringRef> sharedLibs);
llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>> llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
invoke(llvm::StringRef functionName, invoke(llvm::StringRef functionName,
llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs); llvm::ArrayRef<refbackrt::RtValue> inputs);
private: private:
JITModule(); JITModule();

View File

@ -27,6 +27,17 @@
namespace refbackrt { namespace refbackrt {
struct RtValue;
// Base class for any RefCounted object type
//
// Derived classes (e.g. Tensor) inherit an intrusive reference count that
// Ref<T> manages through its incref/decref helpers — hence the friend
// declaration below. The count starts at zero; the first Ref constructed
// around the object brings it to one.
class RefTarget {
protected:
// Only Ref<T> may manipulate the count directly.
template <typename T> friend class Ref;
// `mutable` so const handles can still be retained/released;
// std::atomic makes concurrent retain/release safe.
mutable std::atomic<size_t> refCount;
// constexpr + noexcept so derived objects can be constant-initialized.
constexpr RefTarget() noexcept : refCount(0) {}
};
// Reference-counted handle to a type with a `refCount` member. // Reference-counted handle to a type with a `refCount` member.
template <typename T> class Ref { template <typename T> class Ref {
public: public:
@ -36,7 +47,7 @@ public:
Ref(T *rawPtr) { Ref(T *rawPtr) {
assert(rawPtr->refCount >= 0 && "expected non-negative refcount to start!"); assert(rawPtr->refCount >= 0 && "expected non-negative refcount to start!");
ptr = rawPtr; ptr = rawPtr;
ptr->refCount += 1; incref(ptr);
} }
Ref(const Ref &other) { Ref(const Ref &other) {
ptr = other.ptr; ptr = other.ptr;
@ -73,11 +84,14 @@ public:
int debugGetRefCount() { return ptr->refCount; } int debugGetRefCount() { return ptr->refCount; }
private: private:
friend struct RtValue;
static void incref(T *ptr) { static void incref(T *ptr) {
if (!ptr) if (!ptr)
return; return;
ptr->refCount += 1; ptr->refCount += 1;
} }
friend struct RtValue;
static void decref(T *ptr) { static void decref(T *ptr) {
if (!ptr) if (!ptr)
return; return;
@ -96,7 +110,7 @@ enum class ElementType : std::int32_t {
std::int32_t getElementTypeByteSize(ElementType type); std::int32_t getElementTypeByteSize(ElementType type);
// Representation of a tensor. // Representation of a tensor.
class Tensor { class Tensor : public RefTarget {
public: public:
// Due to tail-allocated objects, this struct should never be directly // Due to tail-allocated objects, this struct should never be directly
// constructed. // constructed.
@ -132,9 +146,6 @@ private:
auto *tail = reinterpret_cast<std::int32_t *>(this + 1); auto *tail = reinterpret_cast<std::int32_t *>(this + 1);
return MutableArrayRef<std::int32_t>(tail, rank); return MutableArrayRef<std::int32_t>(tail, rank);
} }
// Reference count management.
template <typename T> friend class Ref;
std::atomic<int> refCount{0};
ElementType elementType; ElementType elementType;
// The number of dimensions of this Tensor. // The number of dimensions of this Tensor.
@ -150,6 +161,153 @@ private:
// Sizes are tail-allocated. // Sizes are tail-allocated.
}; };
// RtValue is a generic tagged union used to hold all value types
// The tag determines the type, and the payload represents the stored
// contents of an object. If an object is not trivially destructible,
// then it must be refcounted and must have a refCount.
#define NPCOMP_FORALL_PRIM_TAGS(_)                                             \
  _(None)                                                                      \
  _(Bool)                                                                      \
  _(Int)                                                                       \
  _(Double)

#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor)

#define NPCOMP_FORALL_TAGS(_)                                                  \
  NPCOMP_FORALL_PRIM_TAGS(_)                                                   \
  NPCOMP_FORALL_REF_TAGS(_)

struct RtValue final {
  RtValue() : payload{0}, tag(Tag::None) {}

  // Bool
  RtValue(bool b) : tag(Tag::Bool) { payload.asBool = b; }
  bool isBool() const { return Tag::Bool == tag; }
  bool toBool() const {
    assert(isBool());
    return payload.asBool;
  }

  // Int
  RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; }
  RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {}
  bool isInt() const { return Tag::Int == tag; }
  // Returns the full stored 64-bit integer.
  // NOTE(fix): this accessor was previously declared with a `bool` return
  // type, silently truncating every stored integer to 0/1. Callers that
  // used the result as a truthy value remain source-compatible.
  std::int64_t toInt() const {
    assert(isInt());
    return payload.asInt;
  }

  // Double
  RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; }
  bool isDouble() const { return Tag::Double == tag; }
  double toDouble() const {
    assert(isDouble());
    return payload.asDouble;
  }

  // Tensor
  // Takes over the caller's reference: the Ref's +1 is transferred into the
  // stored payload pointer (released again by ~RtValue via decref).
  RtValue(Ref<Tensor> tensor) : tag(Tag::Tensor) {
    payload.asVoidPtr = reinterpret_cast<void *>(tensor.takePtr());
  }
  bool isTensor() const { return Tag::Tensor == tag; }
  // Returns a new strong reference (Ref's raw-pointer ctor increfs).
  Ref<Tensor> toTensor() const {
    assert(isTensor());
    return Ref<Tensor>(reinterpret_cast<Tensor *>(payload.asVoidPtr));
  }

  // Ref
  // True iff the payload holds a refcounted pointer, i.e. copy/destroy must
  // perform incref/decref.
  bool isRef() const {
#define DEFINE_IS_REF(x)                                                       \
  if (is##x()) {                                                               \
    return true;                                                               \
  }
    NPCOMP_FORALL_REF_TAGS(DEFINE_IS_REF)
#undef DEFINE_IS_REF
    return false;
  }

  // RtValue (downcast)
  const RtValue &toRtValue() const { return *this; }
  RtValue &toRtValue() { return *this; }

  // Stringify tag for debugging.
  StringRef tagKind() const {
    switch (tag) {
#define DEFINE_CASE(x)                                                         \
  case Tag::x:                                                                 \
    return #x;
      NPCOMP_FORALL_TAGS(DEFINE_CASE)
#undef DEFINE_CASE
    }
    // TODO(brycearden): Print tag here
    return "InvalidTag!";
  }

  // Copy: primitives are copied bitwise through the private payload/tag
  // ctor; ref types additionally bump the shared refcount.
  RtValue(const RtValue &rhs) : RtValue(rhs.payload, rhs.tag) {
    if (isRef()) {
#define DEFINE_INCREF(x)                                                       \
  if (is##x()) {                                                               \
    Ref<x>::incref(static_cast<x *>(payload.asVoidPtr));                       \
    return;                                                                    \
  }
      NPCOMP_FORALL_REF_TAGS(DEFINE_INCREF)
#undef DEFINE_INCREF
      assert(false && "Unsupported RtValue type");
    }
  }
  // Move: steal rhs's payload, leaving rhs as None (so its dtor is a no-op
  // for ref types).
  RtValue(RtValue &&rhs) noexcept : RtValue() { swap(rhs); }

  RtValue &operator=(RtValue &&rhs) & noexcept {
    RtValue(std::move(rhs)).swap(*this); // this also sets rhs to None
    return *this;
  }
  RtValue &operator=(RtValue const &rhs) & {
    RtValue(rhs).swap(*this);
    return *this;
  }

  // Destroy: ref types drop their strong reference; primitives need no
  // cleanup.
  ~RtValue() {
    if (isRef()) {
#define DEFINE_DECREF(x)                                                       \
  if (is##x()) {                                                               \
    Ref<x>::decref(static_cast<x *>(payload.asVoidPtr));                       \
    return;                                                                    \
  }
      NPCOMP_FORALL_REF_TAGS(DEFINE_DECREF)
#undef DEFINE_DECREF
      assert(false && "Unsupported RtValue type");
    }
  }

private:
  // Exchanges payload and tag; used to implement move operations.
  void swap(RtValue &rhs) {
    std::swap(payload, rhs.payload);
    std::swap(tag, rhs.tag);
  }

  // NOTE: Runtime tags are intentionally private.
  // Please use the helper functions above to query information about the type
  // of a RtValue.
  enum class Tag : std::uint32_t {
#define DEFINE_TAG(x) x,
    NPCOMP_FORALL_TAGS(DEFINE_TAG)
#undef DEFINE_TAG
  };

  union Payload {
    bool asBool;
    int64_t asInt;
    double asDouble;
    void *asVoidPtr;
  };

  // Raw (no incref) construction from a payload/tag pair; used by the copy
  // ctor, which performs the incref itself afterwards.
  RtValue(Payload pl, Tag tag) : payload(pl), tag(tag) {}

  Payload payload;
  Tag tag;
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Module loading. // Module loading.
// This is the main entry point that users interact with. // This is the main entry point that users interact with.
@ -172,7 +330,7 @@ constexpr static int kMaxArity = 20;
// Low-level invocation API. The number of inputs and outputs should be correct // Low-level invocation API. The number of inputs and outputs should be correct
// and match the results of getMetadata. // and match the results of getMetadata.
void invoke(ModuleDescriptor *moduleDescriptor, StringRef functionName, void invoke(ModuleDescriptor *moduleDescriptor, StringRef functionName,
ArrayRef<Ref<Tensor>> inputs, MutableArrayRef<Ref<Tensor>> outputs); ArrayRef<RtValue> inputs, MutableArrayRef<RtValue> outputs);
// Metadata for function `functionName`. // Metadata for function `functionName`.
// //

View File

@ -22,6 +22,7 @@ using llvm::Twine;
using refback::JITModule; using refback::JITModule;
using refbackrt::Ref; using refbackrt::Ref;
using refbackrt::Tensor; using refbackrt::Tensor;
using refbackrt::RtValue;
template <typename T> template <typename T>
static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) { static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) {
@ -106,18 +107,18 @@ void npcomp::python::defineBackendRefJitModule(py::module &m) {
[](JITModule &self, std::string functionName, [](JITModule &self, std::string functionName,
std::vector<py::buffer> inputs) { std::vector<py::buffer> inputs) {
// Prepare inputs. // Prepare inputs.
llvm::SmallVector<Ref<Tensor>, 4> inputTensors; llvm::SmallVector<RtValue, 4> inputValues;
inputTensors.reserve(inputs.size()); inputValues.reserve(inputs.size());
for (py::buffer &inputBuffer : inputs) { for (py::buffer &inputBuffer : inputs) {
inputTensors.push_back(copyBufferToTensor(inputBuffer)); inputValues.push_back(copyBufferToTensor(inputBuffer));
} }
auto outputs = checkError(self.invoke(functionName, inputTensors), auto outputs = checkError(self.invoke(functionName, inputValues),
"error invoking JIT function: "); "error invoking JIT function: ");
std::vector<py::array> outputArrays; std::vector<py::array> outputArrays;
outputArrays.reserve(outputs.size()); outputArrays.reserve(outputs.size());
for (Ref<Tensor> &outputTensor : outputs) { for (RtValue &outputTensor : outputs) {
outputArrays.push_back(wrapTensorAsArray(outputTensor)); outputArrays.push_back(wrapTensorAsArray(outputTensor.toTensor()));
} }
return outputArrays; return outputArrays;
}, },

View File

@ -73,14 +73,14 @@ static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
return refbackrt::MutableArrayRef<T>(a.data(), a.size()); return refbackrt::MutableArrayRef<T>(a.data(), a.size());
} }
llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>> llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
JITModule::invoke(llvm::StringRef functionName, JITModule::invoke(llvm::StringRef functionName,
llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs) { llvm::ArrayRef<refbackrt::RtValue> inputs) {
refbackrt::FunctionMetadata metadata; refbackrt::FunctionMetadata metadata;
if (refbackrt::failed(refbackrt::getMetadata( if (refbackrt::failed(refbackrt::getMetadata(
descriptor, toRefbackrt(functionName), metadata))) descriptor, toRefbackrt(functionName), metadata)))
return make_string_error("unknown function: " + Twine(functionName)); return make_string_error("unknown function: " + Twine(functionName));
SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> outputs( SmallVector<refbackrt::RtValue, 6> outputs(
metadata.numOutputs); metadata.numOutputs);
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size())) if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
return make_string_error("invoking '" + Twine(functionName) + return make_string_error("invoking '" + Twine(functionName) +

View File

@ -131,7 +131,7 @@ Tensor *Tensor::createRaw(ArrayRef<std::int32_t> extents, ElementType type,
auto *tensor = static_cast<Tensor *>( auto *tensor = static_cast<Tensor *>(
std::malloc(sizeof(Tensor) + extents.size() * sizeof(std::int32_t))); std::malloc(sizeof(Tensor) + extents.size() * sizeof(std::int32_t)));
tensor->refCount.store(0); tensor->refCount = 0;
tensor->elementType = type; tensor->elementType = type;
tensor->rank = extents.size(); tensor->rank = extents.size();
auto byteSize = getElementTypeByteSize(type) * totalElements(extents); auto byteSize = getElementTypeByteSize(type) * totalElements(extents);
@ -172,8 +172,8 @@ static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor,
} }
void refbackrt::invoke(ModuleDescriptor *moduleDescriptor, void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
StringRef functionName, ArrayRef<Ref<Tensor>> inputs, StringRef functionName, ArrayRef<RtValue> inputs,
MutableArrayRef<Ref<Tensor>> outputs) { MutableArrayRef<RtValue> outputs) {
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName); auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
assert(descriptor && "unknown function name"); assert(descriptor && "unknown function name");
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity"); assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
@ -191,7 +191,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// more complex though (and maybe impossible given the current abstractions). // more complex though (and maybe impossible given the current abstractions).
for (int i = 0, e = inputs.size(); i < e; i++) { for (int i = 0, e = inputs.size(); i < e; i++) {
inputUnrankedMemrefs[i] = inputUnrankedMemrefs[i] =
convertRefbackrtTensorToUnrankedMemref(inputs[i].get()); convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
} }
// Create a type-erased list of "packed inputs" to pass to the // Create a type-erased list of "packed inputs" to pass to the
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to // LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
@ -224,7 +224,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor( Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor, outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
elementType); elementType);
outputs[i] = Ref<Tensor>(tensor); outputs[i] = RtValue(Ref<Tensor>(tensor));
} }
// Now, we just need to free all the UnrankedMemref's that we created. // Now, we just need to free all the UnrankedMemref's that we created.

View File

@ -58,14 +58,15 @@ convertAttrToTensor(Attribute attr) {
return make_string_error("unhandled argument"); return make_string_error("unhandled argument");
} }
static Expected<SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>> static Expected<SmallVector<refbackrt::RtValue, 6>>
createInputs(ArrayRef<StringRef> argValues) { createInputs(ArrayRef<StringRef> argValues) {
MLIRContext context; MLIRContext context;
SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> ret; SmallVector<refbackrt::RtValue, 6> ret;
for (auto argValue : argValues) { for (auto argValue : argValues) {
auto attr = parseAttribute(argValue, &context); auto attr = parseAttribute(argValue, &context);
if (!attr) if (!attr)
return make_string_error(Twine("could not parse arg value: ") + argValue); return make_string_error(Twine("could not parse arg value: ") + argValue);
// TODO(brycearden): Handle multiple input types
auto expectedTensor = convertAttrToTensor(attr); auto expectedTensor = convertAttrToTensor(attr);
if (!expectedTensor) if (!expectedTensor)
return expectedTensor.takeError(); return expectedTensor.takeError();
@ -111,11 +112,12 @@ static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) {
attr.print(os); attr.print(os);
} }
static void printOutputs(ArrayRef<refbackrt::Ref<refbackrt::Tensor>> outputs, static void printOutputs(ArrayRef<refbackrt::RtValue> outputs,
llvm::raw_ostream &os) { llvm::raw_ostream &os) {
for (auto output : llvm::enumerate(outputs)) { for (auto output : llvm::enumerate(outputs)) {
assert(output.value().isTensor() && "only tensor outputs are supported.");
os << "output #" << output.index() << ": "; os << "output #" << output.index() << ": ";
printOutput(*output.value(), os); printOutput(*output.value().toTensor().get(), os);
os << "\n"; os << "\n";
} }
} }