mirror of https://github.com/llvm/torch-mlir
[refbackrt] Update Invoke API to support more than just Tensor's (#181)
parent
8f9d4f917d
commit
e7a8fd76e2
|
@ -40,9 +40,9 @@ public:
|
|||
fromCompiledModule(mlir::ModuleOp module,
|
||||
llvm::ArrayRef<llvm::StringRef> sharedLibs);
|
||||
|
||||
llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
|
||||
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
|
||||
invoke(llvm::StringRef functionName,
|
||||
llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs);
|
||||
llvm::ArrayRef<refbackrt::RtValue> inputs);
|
||||
|
||||
private:
|
||||
JITModule();
|
||||
|
|
|
@ -27,6 +27,17 @@
|
|||
|
||||
namespace refbackrt {
|
||||
|
||||
struct RtValue;
|
||||
|
||||
// Base class for any RefCounted object type
|
||||
class RefTarget {
|
||||
protected:
|
||||
template <typename T> friend class Ref;
|
||||
mutable std::atomic<size_t> refCount;
|
||||
|
||||
constexpr RefTarget() noexcept : refCount(0) {}
|
||||
};
|
||||
|
||||
// Reference-counted handle to a type with a `refCount` member.
|
||||
template <typename T> class Ref {
|
||||
public:
|
||||
|
@ -36,7 +47,7 @@ public:
|
|||
Ref(T *rawPtr) {
|
||||
assert(rawPtr->refCount >= 0 && "expected non-negative refcount to start!");
|
||||
ptr = rawPtr;
|
||||
ptr->refCount += 1;
|
||||
incref(ptr);
|
||||
}
|
||||
Ref(const Ref &other) {
|
||||
ptr = other.ptr;
|
||||
|
@ -73,11 +84,14 @@ public:
|
|||
int debugGetRefCount() { return ptr->refCount; }
|
||||
|
||||
private:
|
||||
friend struct RtValue;
|
||||
static void incref(T *ptr) {
|
||||
if (!ptr)
|
||||
return;
|
||||
ptr->refCount += 1;
|
||||
}
|
||||
|
||||
friend struct RtValue;
|
||||
static void decref(T *ptr) {
|
||||
if (!ptr)
|
||||
return;
|
||||
|
@ -96,7 +110,7 @@ enum class ElementType : std::int32_t {
|
|||
std::int32_t getElementTypeByteSize(ElementType type);
|
||||
|
||||
// Representation of a tensor.
|
||||
class Tensor {
|
||||
class Tensor : public RefTarget {
|
||||
public:
|
||||
// Due to tail-allocated objects, this struct should never be directly
|
||||
// constructed.
|
||||
|
@ -132,9 +146,6 @@ private:
|
|||
auto *tail = reinterpret_cast<std::int32_t *>(this + 1);
|
||||
return MutableArrayRef<std::int32_t>(tail, rank);
|
||||
}
|
||||
// Reference count management.
|
||||
template <typename T> friend class Ref;
|
||||
std::atomic<int> refCount{0};
|
||||
|
||||
ElementType elementType;
|
||||
// The number of dimensions of this Tensor.
|
||||
|
@ -150,6 +161,153 @@ private:
|
|||
// Sizes are tail-allocated.
|
||||
};
|
||||
|
||||
// RtValue is a generic tagged union used to hold all value types
|
||||
// The tag determines the type, and the payload represents the stored
|
||||
// contents of an object. If an object is not trivially destructible,
|
||||
// then it must be refcounted and must have a refCount.
|
||||
#define NPCOMP_FORALL_PRIM_TAGS(_) \
|
||||
_(None) \
|
||||
_(Bool) \
|
||||
_(Int) \
|
||||
_(Double)
|
||||
|
||||
#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor)
|
||||
|
||||
#define NPCOMP_FORALL_TAGS(_) \
|
||||
NPCOMP_FORALL_PRIM_TAGS(_) \
|
||||
NPCOMP_FORALL_REF_TAGS(_)
|
||||
|
||||
struct RtValue final {
|
||||
|
||||
RtValue() : payload{0}, tag(Tag::None) {}
|
||||
|
||||
// Bool
|
||||
RtValue(bool b) : tag(Tag::Bool) { payload.asBool = b; }
|
||||
bool isBool() const { return Tag::Bool == tag; }
|
||||
bool toBool() const {
|
||||
assert(isBool());
|
||||
return payload.asBool;
|
||||
}
|
||||
|
||||
// Int
|
||||
RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; }
|
||||
RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {}
|
||||
bool isInt() const { return Tag::Int == tag; }
|
||||
bool toInt() const {
|
||||
assert(isInt());
|
||||
return payload.asInt;
|
||||
}
|
||||
|
||||
// Double
|
||||
RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; }
|
||||
bool isDouble() const { return Tag::Double == tag; }
|
||||
bool toDouble() const {
|
||||
assert(isDouble());
|
||||
return payload.asDouble;
|
||||
}
|
||||
|
||||
// Tensor
|
||||
RtValue(Ref<Tensor> tensor) : tag(Tag::Tensor) {
|
||||
payload.asVoidPtr = reinterpret_cast<void *>(tensor.takePtr());
|
||||
}
|
||||
bool isTensor() const { return Tag::Tensor == tag; }
|
||||
Ref<Tensor> toTensor() const {
|
||||
assert(isTensor());
|
||||
return Ref<Tensor>(reinterpret_cast<Tensor *>(payload.asVoidPtr));
|
||||
}
|
||||
|
||||
// Ref
|
||||
bool isRef() const {
|
||||
#define DEFINE_IS_REF(x) \
|
||||
if (is##x()) { \
|
||||
return true; \
|
||||
}
|
||||
NPCOMP_FORALL_REF_TAGS(DEFINE_IS_REF)
|
||||
#undef DEFINE_IS_REF
|
||||
return false;
|
||||
}
|
||||
|
||||
// RtValue (downcast)
|
||||
const RtValue &toRtValue() const { return *this; }
|
||||
RtValue &toRtValue() { return *this; }
|
||||
|
||||
// Stringify tag for debugging.
|
||||
StringRef tagKind() const {
|
||||
switch (tag) {
|
||||
#define DEFINE_CASE(x) \
|
||||
case Tag::x: \
|
||||
return #x;
|
||||
NPCOMP_FORALL_TAGS(DEFINE_CASE)
|
||||
#undef DEFINE_CASE
|
||||
}
|
||||
// TODO(brycearden): Print tag here
|
||||
return "InvalidTag!";
|
||||
}
|
||||
|
||||
RtValue(const RtValue &rhs) : RtValue(rhs.payload, rhs.tag) {
|
||||
if (isRef()) {
|
||||
#define DEFINE_INCREF(x) \
|
||||
if (is##x()) { \
|
||||
Ref<x>::incref(static_cast<x *>(payload.asVoidPtr)); \
|
||||
return; \
|
||||
}
|
||||
NPCOMP_FORALL_REF_TAGS(DEFINE_INCREF)
|
||||
#undef DEFINE_INCREF
|
||||
assert(false && "Unsupported RtValue type");
|
||||
}
|
||||
}
|
||||
RtValue(RtValue &&rhs) noexcept : RtValue() { swap(rhs); }
|
||||
|
||||
RtValue &operator=(RtValue &&rhs) & noexcept {
|
||||
RtValue(std::move(rhs)).swap(*this); // this also sets rhs to None
|
||||
return *this;
|
||||
}
|
||||
RtValue &operator=(RtValue const &rhs) & {
|
||||
RtValue(rhs).swap(*this);
|
||||
return *this;
|
||||
}
|
||||
|
||||
~RtValue() {
|
||||
if (isRef()) {
|
||||
#define DEFINE_DECREF(x) \
|
||||
if (is##x()) { \
|
||||
Ref<x>::decref(static_cast<x *>(payload.asVoidPtr)); \
|
||||
return; \
|
||||
}
|
||||
NPCOMP_FORALL_REF_TAGS(DEFINE_DECREF)
|
||||
#undef DEFINE_DECREF
|
||||
assert(false && "Unsupported RtValue type");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void swap(RtValue &rhs) {
|
||||
std::swap(payload, rhs.payload);
|
||||
std::swap(tag, rhs.tag);
|
||||
}
|
||||
|
||||
// NOTE: Runtime tags are intentionally private.
|
||||
// Please use the helper functions above to query information about the type
|
||||
// of a RtValue.
|
||||
enum class Tag : std::uint32_t {
|
||||
#define DEFINE_TAG(x) x,
|
||||
NPCOMP_FORALL_TAGS(DEFINE_TAG)
|
||||
#undef DEFINE_TAG
|
||||
};
|
||||
|
||||
union Payload {
|
||||
bool asBool;
|
||||
int64_t asInt;
|
||||
double asDouble;
|
||||
void *asVoidPtr;
|
||||
};
|
||||
|
||||
RtValue(Payload pl, Tag tag) : payload(pl), tag(tag) {}
|
||||
|
||||
Payload payload;
|
||||
Tag tag;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module loading.
|
||||
// This is the main entry point that users interact with.
|
||||
|
@ -172,7 +330,7 @@ constexpr static int kMaxArity = 20;
|
|||
// Low-level invocation API. The number of inputs and outputs should be correct
|
||||
// and match the results of getMetadata.
|
||||
void invoke(ModuleDescriptor *moduleDescriptor, StringRef functionName,
|
||||
ArrayRef<Ref<Tensor>> inputs, MutableArrayRef<Ref<Tensor>> outputs);
|
||||
ArrayRef<RtValue> inputs, MutableArrayRef<RtValue> outputs);
|
||||
|
||||
// Metadata for function `functionName`.
|
||||
//
|
||||
|
|
|
@ -22,6 +22,7 @@ using llvm::Twine;
|
|||
using refback::JITModule;
|
||||
using refbackrt::Ref;
|
||||
using refbackrt::Tensor;
|
||||
using refbackrt::RtValue;
|
||||
|
||||
template <typename T>
|
||||
static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) {
|
||||
|
@ -106,18 +107,18 @@ void npcomp::python::defineBackendRefJitModule(py::module &m) {
|
|||
[](JITModule &self, std::string functionName,
|
||||
std::vector<py::buffer> inputs) {
|
||||
// Prepare inputs.
|
||||
llvm::SmallVector<Ref<Tensor>, 4> inputTensors;
|
||||
inputTensors.reserve(inputs.size());
|
||||
llvm::SmallVector<RtValue, 4> inputValues;
|
||||
inputValues.reserve(inputs.size());
|
||||
for (py::buffer &inputBuffer : inputs) {
|
||||
inputTensors.push_back(copyBufferToTensor(inputBuffer));
|
||||
inputValues.push_back(copyBufferToTensor(inputBuffer));
|
||||
}
|
||||
|
||||
auto outputs = checkError(self.invoke(functionName, inputTensors),
|
||||
auto outputs = checkError(self.invoke(functionName, inputValues),
|
||||
"error invoking JIT function: ");
|
||||
std::vector<py::array> outputArrays;
|
||||
outputArrays.reserve(outputs.size());
|
||||
for (Ref<Tensor> &outputTensor : outputs) {
|
||||
outputArrays.push_back(wrapTensorAsArray(outputTensor));
|
||||
for (RtValue &outputTensor : outputs) {
|
||||
outputArrays.push_back(wrapTensorAsArray(outputTensor.toTensor()));
|
||||
}
|
||||
return outputArrays;
|
||||
},
|
||||
|
|
|
@ -73,14 +73,14 @@ static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
|
|||
return refbackrt::MutableArrayRef<T>(a.data(), a.size());
|
||||
}
|
||||
|
||||
llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
|
||||
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
|
||||
JITModule::invoke(llvm::StringRef functionName,
|
||||
llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs) {
|
||||
llvm::ArrayRef<refbackrt::RtValue> inputs) {
|
||||
refbackrt::FunctionMetadata metadata;
|
||||
if (refbackrt::failed(refbackrt::getMetadata(
|
||||
descriptor, toRefbackrt(functionName), metadata)))
|
||||
return make_string_error("unknown function: " + Twine(functionName));
|
||||
SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> outputs(
|
||||
SmallVector<refbackrt::RtValue, 6> outputs(
|
||||
metadata.numOutputs);
|
||||
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
|
||||
return make_string_error("invoking '" + Twine(functionName) +
|
||||
|
|
|
@ -131,7 +131,7 @@ Tensor *Tensor::createRaw(ArrayRef<std::int32_t> extents, ElementType type,
|
|||
auto *tensor = static_cast<Tensor *>(
|
||||
std::malloc(sizeof(Tensor) + extents.size() * sizeof(std::int32_t)));
|
||||
|
||||
tensor->refCount.store(0);
|
||||
tensor->refCount = 0;
|
||||
tensor->elementType = type;
|
||||
tensor->rank = extents.size();
|
||||
auto byteSize = getElementTypeByteSize(type) * totalElements(extents);
|
||||
|
@ -172,8 +172,8 @@ static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor,
|
|||
}
|
||||
|
||||
void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||
StringRef functionName, ArrayRef<Ref<Tensor>> inputs,
|
||||
MutableArrayRef<Ref<Tensor>> outputs) {
|
||||
StringRef functionName, ArrayRef<RtValue> inputs,
|
||||
MutableArrayRef<RtValue> outputs) {
|
||||
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
|
||||
assert(descriptor && "unknown function name");
|
||||
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
|
||||
|
@ -191,7 +191,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
// more complex though (and maybe impossible given the current abstractions).
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
inputUnrankedMemrefs[i] =
|
||||
convertRefbackrtTensorToUnrankedMemref(inputs[i].get());
|
||||
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
|
||||
}
|
||||
// Create a type-erased list of "packed inputs" to pass to the
|
||||
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
|
||||
|
@ -224,7 +224,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
|
||||
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
|
||||
elementType);
|
||||
outputs[i] = Ref<Tensor>(tensor);
|
||||
outputs[i] = RtValue(Ref<Tensor>(tensor));
|
||||
}
|
||||
|
||||
// Now, we just need to free all the UnrankedMemref's that we created.
|
||||
|
|
|
@ -58,14 +58,15 @@ convertAttrToTensor(Attribute attr) {
|
|||
return make_string_error("unhandled argument");
|
||||
}
|
||||
|
||||
static Expected<SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
|
||||
static Expected<SmallVector<refbackrt::RtValue, 6>>
|
||||
createInputs(ArrayRef<StringRef> argValues) {
|
||||
MLIRContext context;
|
||||
SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> ret;
|
||||
SmallVector<refbackrt::RtValue, 6> ret;
|
||||
for (auto argValue : argValues) {
|
||||
auto attr = parseAttribute(argValue, &context);
|
||||
if (!attr)
|
||||
return make_string_error(Twine("could not parse arg value: ") + argValue);
|
||||
// TODO(brycearden): Handle multiple input types
|
||||
auto expectedTensor = convertAttrToTensor(attr);
|
||||
if (!expectedTensor)
|
||||
return expectedTensor.takeError();
|
||||
|
@ -111,11 +112,12 @@ static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) {
|
|||
attr.print(os);
|
||||
}
|
||||
|
||||
static void printOutputs(ArrayRef<refbackrt::Ref<refbackrt::Tensor>> outputs,
|
||||
static void printOutputs(ArrayRef<refbackrt::RtValue> outputs,
|
||||
llvm::raw_ostream &os) {
|
||||
for (auto output : llvm::enumerate(outputs)) {
|
||||
assert(output.value().isTensor() && "only tensor outputs are supported.");
|
||||
os << "output #" << output.index() << ": ";
|
||||
printOutput(*output.value(), os);
|
||||
printOutput(*output.value().toTensor().get(), os);
|
||||
os << "\n";
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue