[refbackrt] Update Invoke API to support more than just Tensors (#181)

pull/186/head
Bryce Arden 2021-03-10 17:39:26 -06:00 committed by GitHub
parent 8f9d4f917d
commit e7a8fd76e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 187 additions and 26 deletions

View File

@ -40,9 +40,9 @@ public:
fromCompiledModule(mlir::ModuleOp module,
llvm::ArrayRef<llvm::StringRef> sharedLibs);
llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
invoke(llvm::StringRef functionName,
llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs);
llvm::ArrayRef<refbackrt::RtValue> inputs);
private:
JITModule();

View File

@ -27,6 +27,17 @@
namespace refbackrt {
struct RtValue;
// Base class for any RefCounted object type
class RefTarget {
protected:
template <typename T> friend class Ref;
mutable std::atomic<size_t> refCount;
constexpr RefTarget() noexcept : refCount(0) {}
};
// Reference-counted handle to a type with a `refCount` member.
template <typename T> class Ref {
public:
@ -36,7 +47,7 @@ public:
// Take shared ownership of `rawPtr` by bumping its refcount.
// NOTE(review): `rawPtr` must be non-null here -- the assert dereferences
// it -- even though incref() itself tolerates null.
// BUGFIX: the increment appeared twice (both the old `ptr->refCount += 1;`
// and the new `incref(ptr);`), which would double-count; keep only incref.
Ref(T *rawPtr) {
  assert(rawPtr->refCount >= 0 && "expected non-negative refcount to start!");
  ptr = rawPtr;
  incref(ptr);
}
Ref(const Ref &other) {
ptr = other.ptr;
@ -73,11 +84,14 @@ public:
int debugGetRefCount() { return ptr->refCount; }
private:
friend struct RtValue;
// Bump the refcount of `ptr`. Null is tolerated as a no-op so callers
// (e.g. RtValue's copy constructor) do not need to guard.
// BUGFIX: dropped the redundant duplicate `friend struct RtValue;`
// declaration that followed this method (already declared above).
static void incref(T *ptr) {
  if (!ptr)
    return;
  ptr->refCount += 1;
}
static void decref(T *ptr) {
if (!ptr)
return;
@ -96,7 +110,7 @@ enum class ElementType : std::int32_t {
std::int32_t getElementTypeByteSize(ElementType type);
// Representation of a tensor.
class Tensor {
class Tensor : public RefTarget {
public:
// Due to tail-allocated objects, this struct should never be directly
// constructed.
@ -132,9 +146,6 @@ private:
auto *tail = reinterpret_cast<std::int32_t *>(this + 1);
return MutableArrayRef<std::int32_t>(tail, rank);
}
// Reference count management.
template <typename T> friend class Ref;
std::atomic<int> refCount{0};
ElementType elementType;
// The number of dimensions of this Tensor.
@ -150,6 +161,153 @@ private:
// Sizes are tail-allocated.
};
// RtValue is a generic tagged union used to hold all value types
// The tag determines the type, and the payload represents the stored
// contents of an object. If an object is not trivially destructible,
// then it must be refcounted and must have a refCount.
#define NPCOMP_FORALL_PRIM_TAGS(_)                                             \
  _(None)                                                                      \
  _(Bool)                                                                      \
  _(Int)                                                                       \
  _(Double)

#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor)

#define NPCOMP_FORALL_TAGS(_)                                                  \
  NPCOMP_FORALL_PRIM_TAGS(_)                                                   \
  NPCOMP_FORALL_REF_TAGS(_)

struct RtValue final {
  // Default-constructed RtValue holds None (payload zeroed).
  RtValue() : payload{0}, tag(Tag::None) {}

  // Bool
  RtValue(bool b) : tag(Tag::Bool) { payload.asBool = b; }
  bool isBool() const { return Tag::Bool == tag; }
  bool toBool() const {
    assert(isBool());
    return payload.asBool;
  }

  // Int
  RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; }
  RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {}
  bool isInt() const { return Tag::Int == tag; }
  // BUGFIX: was declared `bool toInt()`, which silently truncated the stored
  // 64-bit payload to 0/1. Return the full integer value instead.
  std::int64_t toInt() const {
    assert(isInt());
    return payload.asInt;
  }

  // Double
  RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; }
  bool isDouble() const { return Tag::Double == tag; }
  double toDouble() const {
    assert(isDouble());
    return payload.asDouble;
  }

  // Tensor. Takes ownership: the Ref's pointer is stolen (takePtr), so the
  // refcount is transferred into this RtValue rather than bumped.
  RtValue(Ref<Tensor> tensor) : tag(Tag::Tensor) {
    payload.asVoidPtr = reinterpret_cast<void *>(tensor.takePtr());
  }
  bool isTensor() const { return Tag::Tensor == tag; }
  // Returns a new owning handle; the Ref constructor bumps the refcount.
  Ref<Tensor> toTensor() const {
    assert(isTensor());
    return Ref<Tensor>(reinterpret_cast<Tensor *>(payload.asVoidPtr));
  }

  // True iff this value holds any refcounted payload (needs incref/decref
  // on copy/destruction).
  bool isRef() const {
#define DEFINE_IS_REF(x)                                                       \
  if (is##x()) {                                                               \
    return true;                                                               \
  }
    NPCOMP_FORALL_REF_TAGS(DEFINE_IS_REF)
#undef DEFINE_IS_REF
    return false;
  }

  // RtValue (downcast)
  const RtValue &toRtValue() const { return *this; }
  RtValue &toRtValue() { return *this; }

  // Stringify tag for debugging.
  StringRef tagKind() const {
    switch (tag) {
#define DEFINE_CASE(x)                                                         \
  case Tag::x:                                                                 \
    return #x;
      NPCOMP_FORALL_TAGS(DEFINE_CASE)
#undef DEFINE_CASE
    }
    // TODO(brycearden): Print tag here
    return "InvalidTag!";
  }

  // Copy: share the payload; refcounted payloads get an extra owner.
  RtValue(const RtValue &rhs) : RtValue(rhs.payload, rhs.tag) {
    if (isRef()) {
#define DEFINE_INCREF(x)                                                       \
  if (is##x()) {                                                               \
    Ref<x>::incref(static_cast<x *>(payload.asVoidPtr));                       \
    return;                                                                    \
  }
      NPCOMP_FORALL_REF_TAGS(DEFINE_INCREF)
#undef DEFINE_INCREF
      assert(false && "Unsupported RtValue type");
    }
  }

  // Move: steal rhs's contents, leaving rhs as None (via swap with a
  // default-constructed *this).
  RtValue(RtValue &&rhs) noexcept : RtValue() { swap(rhs); }
  RtValue &operator=(RtValue &&rhs) & noexcept {
    RtValue(std::move(rhs)).swap(*this); // this also sets rhs to None
    return *this;
  }
  RtValue &operator=(RtValue const &rhs) & {
    RtValue(rhs).swap(*this);
    return *this;
  }

  // Release ownership of any refcounted payload.
  ~RtValue() {
    if (isRef()) {
#define DEFINE_DECREF(x)                                                       \
  if (is##x()) {                                                               \
    Ref<x>::decref(static_cast<x *>(payload.asVoidPtr));                       \
    return;                                                                    \
  }
      NPCOMP_FORALL_REF_TAGS(DEFINE_DECREF)
#undef DEFINE_DECREF
      assert(false && "Unsupported RtValue type");
    }
  }

private:
  void swap(RtValue &rhs) {
    std::swap(payload, rhs.payload);
    std::swap(tag, rhs.tag);
  }

  // NOTE: Runtime tags are intentionally private.
  // Please use the helper functions above to query information about the type
  // of a RtValue.
  enum class Tag : std::uint32_t {
#define DEFINE_TAG(x) x,
    NPCOMP_FORALL_TAGS(DEFINE_TAG)
#undef DEFINE_TAG
  };

  union Payload {
    bool asBool;
    int64_t asInt;
    double asDouble;
    void *asVoidPtr;
  };

  // Raw constructor used by the copy constructor; does NOT incref.
  RtValue(Payload pl, Tag tag) : payload(pl), tag(tag) {}

  Payload payload;
  Tag tag;
};
//===----------------------------------------------------------------------===//
// Module loading.
// This is the main entry point that users interact with.
@ -172,7 +330,7 @@ constexpr static int kMaxArity = 20;
// Low-level invocation API. The number of inputs and outputs should be correct
// and match the results of getMetadata.
void invoke(ModuleDescriptor *moduleDescriptor, StringRef functionName,
ArrayRef<Ref<Tensor>> inputs, MutableArrayRef<Ref<Tensor>> outputs);
ArrayRef<RtValue> inputs, MutableArrayRef<RtValue> outputs);
// Metadata for function `functionName`.
//

View File

@ -22,6 +22,7 @@ using llvm::Twine;
using refback::JITModule;
using refbackrt::Ref;
using refbackrt::Tensor;
using refbackrt::RtValue;
template <typename T>
static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) {
@ -106,18 +107,18 @@ void npcomp::python::defineBackendRefJitModule(py::module &m) {
[](JITModule &self, std::string functionName,
std::vector<py::buffer> inputs) {
// Prepare inputs.
llvm::SmallVector<Ref<Tensor>, 4> inputTensors;
inputTensors.reserve(inputs.size());
llvm::SmallVector<RtValue, 4> inputValues;
inputValues.reserve(inputs.size());
for (py::buffer &inputBuffer : inputs) {
inputTensors.push_back(copyBufferToTensor(inputBuffer));
inputValues.push_back(copyBufferToTensor(inputBuffer));
}
auto outputs = checkError(self.invoke(functionName, inputTensors),
auto outputs = checkError(self.invoke(functionName, inputValues),
"error invoking JIT function: ");
std::vector<py::array> outputArrays;
outputArrays.reserve(outputs.size());
for (Ref<Tensor> &outputTensor : outputs) {
outputArrays.push_back(wrapTensorAsArray(outputTensor));
for (RtValue &outputTensor : outputs) {
outputArrays.push_back(wrapTensorAsArray(outputTensor.toTensor()));
}
return outputArrays;
},

View File

@ -73,14 +73,14 @@ static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
return refbackrt::MutableArrayRef<T>(a.data(), a.size());
}
llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
JITModule::invoke(llvm::StringRef functionName,
llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs) {
llvm::ArrayRef<refbackrt::RtValue> inputs) {
refbackrt::FunctionMetadata metadata;
if (refbackrt::failed(refbackrt::getMetadata(
descriptor, toRefbackrt(functionName), metadata)))
return make_string_error("unknown function: " + Twine(functionName));
SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> outputs(
SmallVector<refbackrt::RtValue, 6> outputs(
metadata.numOutputs);
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
return make_string_error("invoking '" + Twine(functionName) +

View File

@ -131,7 +131,7 @@ Tensor *Tensor::createRaw(ArrayRef<std::int32_t> extents, ElementType type,
auto *tensor = static_cast<Tensor *>(
std::malloc(sizeof(Tensor) + extents.size() * sizeof(std::int32_t)));
tensor->refCount.store(0);
tensor->refCount = 0;
tensor->elementType = type;
tensor->rank = extents.size();
auto byteSize = getElementTypeByteSize(type) * totalElements(extents);
@ -172,8 +172,8 @@ static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor,
}
void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
StringRef functionName, ArrayRef<Ref<Tensor>> inputs,
MutableArrayRef<Ref<Tensor>> outputs) {
StringRef functionName, ArrayRef<RtValue> inputs,
MutableArrayRef<RtValue> outputs) {
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
assert(descriptor && "unknown function name");
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
@ -191,7 +191,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// more complex though (and maybe impossible given the current abstractions).
for (int i = 0, e = inputs.size(); i < e; i++) {
inputUnrankedMemrefs[i] =
convertRefbackrtTensorToUnrankedMemref(inputs[i].get());
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
}
// Create a type-erased list of "packed inputs" to pass to the
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
@ -224,7 +224,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
elementType);
outputs[i] = Ref<Tensor>(tensor);
outputs[i] = RtValue(Ref<Tensor>(tensor));
}
// Now, we just need to free all the UnrankedMemref's that we created.

View File

@ -58,14 +58,15 @@ convertAttrToTensor(Attribute attr) {
return make_string_error("unhandled argument");
}
static Expected<SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
static Expected<SmallVector<refbackrt::RtValue, 6>>
createInputs(ArrayRef<StringRef> argValues) {
MLIRContext context;
SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> ret;
SmallVector<refbackrt::RtValue, 6> ret;
for (auto argValue : argValues) {
auto attr = parseAttribute(argValue, &context);
if (!attr)
return make_string_error(Twine("could not parse arg value: ") + argValue);
// TODO(brycearden): Handle multiple input types
auto expectedTensor = convertAttrToTensor(attr);
if (!expectedTensor)
return expectedTensor.takeError();
@ -111,11 +112,12 @@ static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) {
attr.print(os);
}
static void printOutputs(ArrayRef<refbackrt::Ref<refbackrt::Tensor>> outputs,
static void printOutputs(ArrayRef<refbackrt::RtValue> outputs,
llvm::raw_ostream &os) {
for (auto output : llvm::enumerate(outputs)) {
assert(output.value().isTensor() && "only tensor outputs are supported.");
os << "output #" << output.index() << ": ";
printOutput(*output.value(), os);
printOutput(*output.value().toTensor().get(), os);
os << "\n";
}
}