diff --git a/include/npcomp/RefBackend/JITHelpers/JITModule.h b/include/npcomp/RefBackend/JITHelpers/JITModule.h index 1476a7d9a..b252cec59 100644 --- a/include/npcomp/RefBackend/JITHelpers/JITModule.h +++ b/include/npcomp/RefBackend/JITHelpers/JITModule.h @@ -40,9 +40,9 @@ public: fromCompiledModule(mlir::ModuleOp module, llvm::ArrayRef<llvm::StringRef> sharedLibs); - llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>> + llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>> invoke(llvm::StringRef functionName, - llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs); + llvm::ArrayRef<refbackrt::RtValue> inputs); private: JITModule(); diff --git a/include/npcomp/RefBackend/Runtime/UserAPI.h b/include/npcomp/RefBackend/Runtime/UserAPI.h index 8ba68cb1e..b0a696af9 100644 --- a/include/npcomp/RefBackend/Runtime/UserAPI.h +++ b/include/npcomp/RefBackend/Runtime/UserAPI.h @@ -27,6 +27,17 @@ namespace refbackrt { +struct RtValue; + +// Base class for any RefCounted object type +class RefTarget { +protected: + template <typename T> friend class Ref; + mutable std::atomic<int> refCount; + + constexpr RefTarget() noexcept : refCount(0) {} +}; + // Reference-counted handle to a type with a `refCount` member. template <typename T> class Ref { public: @@ -36,7 +47,7 @@ public: Ref(T *rawPtr) { assert(rawPtr->refCount >= 0 && "expected non-negative refcount to start!"); ptr = rawPtr; - ptr->refCount += 1; + incref(ptr); } Ref(const Ref &other) { ptr = other.ptr; @@ -73,11 +84,14 @@ public: int debugGetRefCount() { return ptr->refCount; } private: + friend struct RtValue; static void incref(T *ptr) { if (!ptr) return; ptr->refCount += 1; } + + friend struct RtValue; static void decref(T *ptr) { if (!ptr) return; @@ -96,7 +110,7 @@ enum class ElementType : std::int32_t { std::int32_t getElementTypeByteSize(ElementType type); // Representation of a tensor. -class Tensor { +class Tensor : public RefTarget { public: // Due to tail-allocated objects, this struct should never be directly // constructed. @@ -132,9 +146,6 @@ private: auto *tail = reinterpret_cast<std::int32_t *>(this + 1); return MutableArrayRef<std::int32_t>(tail, rank); } - // Reference count management. 
- template <typename T> friend class Ref; - std::atomic<int> refCount{0}; ElementType elementType; // The number of dimensions of this Tensor. @@ -150,6 +161,153 @@ private: // Sizes are tail-allocated. }; +// RtValue is a generic tagged union used to hold all value types +// The tag determines the type, and the payload represents the stored +// contents of an object. If an object is not trivially destructible, +// then it must be refcounted and must have a refCount. +#define NPCOMP_FORALL_PRIM_TAGS(_) \ _(None) \ _(Bool) \ _(Int) \ _(Double) + +#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor) + +#define NPCOMP_FORALL_TAGS(_) \ NPCOMP_FORALL_PRIM_TAGS(_) \ NPCOMP_FORALL_REF_TAGS(_) + +struct RtValue final { + + RtValue() : payload{0}, tag(Tag::None) {} + + // Bool + RtValue(bool b) : tag(Tag::Bool) { payload.asBool = b; } + bool isBool() const { return Tag::Bool == tag; } + bool toBool() const { + assert(isBool()); + return payload.asBool; + } + + // Int + RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; } + RtValue(std::int32_t i) : RtValue(static_cast<std::int64_t>(i)) {} + bool isInt() const { return Tag::Int == tag; } + bool toInt() const { + assert(isInt()); + return payload.asInt; + } + + // Double + RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; } + bool isDouble() const { return Tag::Double == tag; } + bool toDouble() const { + assert(isDouble()); + return payload.asDouble; + } + + // Tensor + RtValue(Ref<Tensor> tensor) : tag(Tag::Tensor) { + payload.asVoidPtr = reinterpret_cast<void *>(tensor.takePtr()); + } + bool isTensor() const { return Tag::Tensor == tag; } + Ref<Tensor> toTensor() const { + assert(isTensor()); + return Ref<Tensor>(reinterpret_cast<Tensor *>(payload.asVoidPtr)); + } + + // Ref + bool isRef() const { +#define DEFINE_IS_REF(x) \ if (is##x()) { \ return true; \ } + NPCOMP_FORALL_REF_TAGS(DEFINE_IS_REF) +#undef DEFINE_IS_REF + return false; + } + + // RtValue (downcast) + const RtValue &toRtValue() const { return *this; } + RtValue &toRtValue() { return *this; } + + // 
Stringify tag for debugging. + StringRef tagKind() const { + switch (tag) { +#define DEFINE_CASE(x) \ case Tag::x: \ return #x; + NPCOMP_FORALL_TAGS(DEFINE_CASE) +#undef DEFINE_CASE + } + // TODO(brycearden): Print tag here + return "InvalidTag!"; + } + + RtValue(const RtValue &rhs) : RtValue(rhs.payload, rhs.tag) { + if (isRef()) { +#define DEFINE_INCREF(x) \ if (is##x()) { \ Ref<x>::incref(static_cast<x *>(payload.asVoidPtr)); \ return; \ } + NPCOMP_FORALL_REF_TAGS(DEFINE_INCREF) +#undef DEFINE_INCREF + assert(false && "Unsupported RtValue type"); + } + } + RtValue(RtValue &&rhs) noexcept : RtValue() { swap(rhs); } + + RtValue &operator=(RtValue &&rhs) & noexcept { + RtValue(std::move(rhs)).swap(*this); // this also sets rhs to None + return *this; + } + RtValue &operator=(RtValue const &rhs) & { + RtValue(rhs).swap(*this); + return *this; + } + + ~RtValue() { + if (isRef()) { +#define DEFINE_DECREF(x) \ if (is##x()) { \ Ref<x>::decref(static_cast<x *>(payload.asVoidPtr)); \ return; \ } + NPCOMP_FORALL_REF_TAGS(DEFINE_DECREF) +#undef DEFINE_DECREF + assert(false && "Unsupported RtValue type"); + } + } + +private: + void swap(RtValue &rhs) { + std::swap(payload, rhs.payload); + std::swap(tag, rhs.tag); + } + + // NOTE: Runtime tags are intentionally private. + // Please use the helper functions above to query information about the type + // of a RtValue. + enum class Tag : std::uint32_t { +#define DEFINE_TAG(x) x, + NPCOMP_FORALL_TAGS(DEFINE_TAG) +#undef DEFINE_TAG + }; + + union Payload { + bool asBool; + int64_t asInt; + double asDouble; + void *asVoidPtr; + }; + + RtValue(Payload pl, Tag tag) : payload(pl), tag(tag) {} + + Payload payload; + Tag tag; +}; + //===----------------------------------------------------------------------===// // Module loading. // This is the main entry point that users interact with. @@ -172,7 +330,7 @@ constexpr static int kMaxArity = 20; // Low-level invocation API. 
The number of inputs and outputs should be correct // and match the results of getMetadata. void invoke(ModuleDescriptor *moduleDescriptor, StringRef functionName, - ArrayRef<Ref<Tensor>> inputs, MutableArrayRef<Ref<Tensor>> outputs); + ArrayRef<RtValue> inputs, MutableArrayRef<RtValue> outputs); // Metadata for function `functionName`. // diff --git a/lib/Backend/RefJIT/PythonModule.cpp b/lib/Backend/RefJIT/PythonModule.cpp index bdfb760ca..8bb57c6d9 100644 --- a/lib/Backend/RefJIT/PythonModule.cpp +++ b/lib/Backend/RefJIT/PythonModule.cpp @@ -22,6 +22,7 @@ using llvm::Twine; using refback::JITModule; using refbackrt::Ref; using refbackrt::Tensor; +using refbackrt::RtValue; template <typename T> static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) { @@ -106,18 +107,18 @@ void npcomp::python::defineBackendRefJitModule(py::module &m) { [](JITModule &self, std::string functionName, std::vector<py::buffer> inputs) { // Prepare inputs. - llvm::SmallVector<Ref<Tensor>, 4> inputTensors; - inputTensors.reserve(inputs.size()); + llvm::SmallVector<RtValue, 4> inputValues; + inputValues.reserve(inputs.size()); for (py::buffer &inputBuffer : inputs) { - inputTensors.push_back(copyBufferToTensor(inputBuffer)); + inputValues.push_back(copyBufferToTensor(inputBuffer)); } - auto outputs = checkError(self.invoke(functionName, inputTensors), + auto outputs = checkError(self.invoke(functionName, inputValues), "error invoking JIT function: "); std::vector<py::array> outputArrays; outputArrays.reserve(outputs.size()); - for (Ref<Tensor> &outputTensor : outputs) { - outputArrays.push_back(wrapTensorAsArray(outputTensor)); + for (RtValue &outputTensor : outputs) { + outputArrays.push_back(wrapTensorAsArray(outputTensor.toTensor())); } return outputArrays; }, diff --git a/lib/RefBackend/JITHelpers/JITModule.cpp b/lib/RefBackend/JITHelpers/JITModule.cpp index 6ddabd2b2..0530df2bd 100644 --- a/lib/RefBackend/JITHelpers/JITModule.cpp +++ b/lib/RefBackend/JITHelpers/JITModule.cpp @@ -73,14 +73,14 @@ static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) { return 
refbackrt::MutableArrayRef<T>(a.data(), a.size()); } -llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>> +llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>> JITModule::invoke(llvm::StringRef functionName, - llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs) { + llvm::ArrayRef<refbackrt::RtValue> inputs) { refbackrt::FunctionMetadata metadata; if (refbackrt::failed(refbackrt::getMetadata( descriptor, toRefbackrt(functionName), metadata))) return make_string_error("unknown function: " + Twine(functionName)); - SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> outputs( + SmallVector<refbackrt::RtValue, 6> outputs( metadata.numOutputs); if (metadata.numInputs != static_cast<std::int32_t>(inputs.size())) return make_string_error("invoking '" + Twine(functionName) + diff --git a/lib/RefBackend/Runtime/Runtime.cpp b/lib/RefBackend/Runtime/Runtime.cpp index d92dc0de7..dce733ad7 100644 --- a/lib/RefBackend/Runtime/Runtime.cpp +++ b/lib/RefBackend/Runtime/Runtime.cpp @@ -131,7 +131,7 @@ Tensor *Tensor::createRaw(ArrayRef<std::int32_t> extents, ElementType type, auto *tensor = static_cast<Tensor *>( std::malloc(sizeof(Tensor) + extents.size() * sizeof(std::int32_t))); - tensor->refCount.store(0); + tensor->refCount = 0; tensor->elementType = type; tensor->rank = extents.size(); auto byteSize = getElementTypeByteSize(type) * totalElements(extents); @@ -172,8 +172,8 @@ static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor, } void refbackrt::invoke(ModuleDescriptor *moduleDescriptor, - StringRef functionName, ArrayRef<Ref<Tensor>> inputs, - MutableArrayRef<Ref<Tensor>> outputs) { + StringRef functionName, ArrayRef<RtValue> inputs, + MutableArrayRef<RtValue> outputs) { auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName); assert(descriptor && "unknown function name"); assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity"); @@ -191,7 +191,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor, // more complex though (and maybe impossible given the current abstractions). 
for (int i = 0, e = inputs.size(); i < e; i++) { inputUnrankedMemrefs[i] = - convertRefbackrtTensorToUnrankedMemref(inputs[i].get()); + convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get()); } // Create a type-erased list of "packed inputs" to pass to the // LLVM/C ABI wrapper function. Each packedInput pointer corresponds to @@ -224,7 +224,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor, Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor( outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor, elementType); - outputs[i] = Ref<Tensor>(tensor); + outputs[i] = RtValue(Ref<Tensor>(tensor)); } // Now, we just need to free all the UnrankedMemref's that we created. diff --git a/tools/npcomp-run-mlir/npcomp-run-mlir.cpp b/tools/npcomp-run-mlir/npcomp-run-mlir.cpp index 3f8757d75..c0d63d90b 100644 --- a/tools/npcomp-run-mlir/npcomp-run-mlir.cpp +++ b/tools/npcomp-run-mlir/npcomp-run-mlir.cpp @@ -58,14 +58,15 @@ convertAttrToTensor(Attribute attr) { return make_string_error("unhandled argument"); } -static Expected<SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>> +static Expected<SmallVector<refbackrt::RtValue, 6>> createInputs(ArrayRef<StringRef> argValues) { MLIRContext context; - SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> ret; + SmallVector<refbackrt::RtValue, 6> ret; for (auto argValue : argValues) { auto attr = parseAttribute(argValue, &context); if (!attr) return make_string_error(Twine("could not parse arg value: ") + argValue); + // TODO(brycearden): Handle multiple input types auto expectedTensor = convertAttrToTensor(attr); if (!expectedTensor) return expectedTensor.takeError(); @@ -111,11 +112,12 @@ static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) { attr.print(os); } -static void printOutputs(ArrayRef<refbackrt::Ref<refbackrt::Tensor>> outputs, +static void printOutputs(ArrayRef<refbackrt::RtValue> outputs, llvm::raw_ostream &os) { for (auto output : llvm::enumerate(outputs)) { + assert(output.value().isTensor() && "only tensor outputs are supported."); os << "output #" << output.index() << ": "; - printOutput(*output.value(), os); + printOutput(*output.value().toTensor().get(), os); os << "\n"; } }