mirror of https://github.com/llvm/torch-mlir
[refbackrt] Update Invoke API to support more than just Tensor's (#181)
parent
8f9d4f917d
commit
e7a8fd76e2
|
@ -40,9 +40,9 @@ public:
|
||||||
fromCompiledModule(mlir::ModuleOp module,
|
fromCompiledModule(mlir::ModuleOp module,
|
||||||
llvm::ArrayRef<llvm::StringRef> sharedLibs);
|
llvm::ArrayRef<llvm::StringRef> sharedLibs);
|
||||||
|
|
||||||
llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
|
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
|
||||||
invoke(llvm::StringRef functionName,
|
invoke(llvm::StringRef functionName,
|
||||||
llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs);
|
llvm::ArrayRef<refbackrt::RtValue> inputs);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
JITModule();
|
JITModule();
|
||||||
|
|
|
@ -27,6 +27,17 @@
|
||||||
|
|
||||||
namespace refbackrt {
|
namespace refbackrt {
|
||||||
|
|
||||||
|
struct RtValue;
|
||||||
|
|
||||||
|
// Base class for any RefCounted object type
|
||||||
|
class RefTarget {
|
||||||
|
protected:
|
||||||
|
template <typename T> friend class Ref;
|
||||||
|
mutable std::atomic<size_t> refCount;
|
||||||
|
|
||||||
|
constexpr RefTarget() noexcept : refCount(0) {}
|
||||||
|
};
|
||||||
|
|
||||||
// Reference-counted handle to a type with a `refCount` member.
|
// Reference-counted handle to a type with a `refCount` member.
|
||||||
template <typename T> class Ref {
|
template <typename T> class Ref {
|
||||||
public:
|
public:
|
||||||
|
@ -36,7 +47,7 @@ public:
|
||||||
Ref(T *rawPtr) {
|
Ref(T *rawPtr) {
|
||||||
assert(rawPtr->refCount >= 0 && "expected non-negative refcount to start!");
|
assert(rawPtr->refCount >= 0 && "expected non-negative refcount to start!");
|
||||||
ptr = rawPtr;
|
ptr = rawPtr;
|
||||||
ptr->refCount += 1;
|
incref(ptr);
|
||||||
}
|
}
|
||||||
Ref(const Ref &other) {
|
Ref(const Ref &other) {
|
||||||
ptr = other.ptr;
|
ptr = other.ptr;
|
||||||
|
@ -73,11 +84,14 @@ public:
|
||||||
int debugGetRefCount() { return ptr->refCount; }
|
int debugGetRefCount() { return ptr->refCount; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
friend struct RtValue;
|
||||||
static void incref(T *ptr) {
|
static void incref(T *ptr) {
|
||||||
if (!ptr)
|
if (!ptr)
|
||||||
return;
|
return;
|
||||||
ptr->refCount += 1;
|
ptr->refCount += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
friend struct RtValue;
|
||||||
static void decref(T *ptr) {
|
static void decref(T *ptr) {
|
||||||
if (!ptr)
|
if (!ptr)
|
||||||
return;
|
return;
|
||||||
|
@ -96,7 +110,7 @@ enum class ElementType : std::int32_t {
|
||||||
std::int32_t getElementTypeByteSize(ElementType type);
|
std::int32_t getElementTypeByteSize(ElementType type);
|
||||||
|
|
||||||
// Representation of a tensor.
|
// Representation of a tensor.
|
||||||
class Tensor {
|
class Tensor : public RefTarget {
|
||||||
public:
|
public:
|
||||||
// Due to tail-allocated objects, this struct should never be directly
|
// Due to tail-allocated objects, this struct should never be directly
|
||||||
// constructed.
|
// constructed.
|
||||||
|
@ -132,9 +146,6 @@ private:
|
||||||
auto *tail = reinterpret_cast<std::int32_t *>(this + 1);
|
auto *tail = reinterpret_cast<std::int32_t *>(this + 1);
|
||||||
return MutableArrayRef<std::int32_t>(tail, rank);
|
return MutableArrayRef<std::int32_t>(tail, rank);
|
||||||
}
|
}
|
||||||
// Reference count management.
|
|
||||||
template <typename T> friend class Ref;
|
|
||||||
std::atomic<int> refCount{0};
|
|
||||||
|
|
||||||
ElementType elementType;
|
ElementType elementType;
|
||||||
// The number of dimensions of this Tensor.
|
// The number of dimensions of this Tensor.
|
||||||
|
@ -150,6 +161,153 @@ private:
|
||||||
// Sizes are tail-allocated.
|
// Sizes are tail-allocated.
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// RtValue is a generic tagged union used to hold all value types
|
||||||
|
// The tag determines the type, and the payload represents the stored
|
||||||
|
// contents of an object. If an object is not trivially destructible,
|
||||||
|
// then it must be refcounted and must have a refCount.
|
||||||
|
#define NPCOMP_FORALL_PRIM_TAGS(_) \
|
||||||
|
_(None) \
|
||||||
|
_(Bool) \
|
||||||
|
_(Int) \
|
||||||
|
_(Double)
|
||||||
|
|
||||||
|
#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor)
|
||||||
|
|
||||||
|
#define NPCOMP_FORALL_TAGS(_) \
|
||||||
|
NPCOMP_FORALL_PRIM_TAGS(_) \
|
||||||
|
NPCOMP_FORALL_REF_TAGS(_)
|
||||||
|
|
||||||
|
struct RtValue final {
|
||||||
|
|
||||||
|
RtValue() : payload{0}, tag(Tag::None) {}
|
||||||
|
|
||||||
|
// Bool
|
||||||
|
RtValue(bool b) : tag(Tag::Bool) { payload.asBool = b; }
|
||||||
|
bool isBool() const { return Tag::Bool == tag; }
|
||||||
|
bool toBool() const {
|
||||||
|
assert(isBool());
|
||||||
|
return payload.asBool;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Int
|
||||||
|
RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; }
|
||||||
|
RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {}
|
||||||
|
bool isInt() const { return Tag::Int == tag; }
|
||||||
|
bool toInt() const {
|
||||||
|
assert(isInt());
|
||||||
|
return payload.asInt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Double
|
||||||
|
RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; }
|
||||||
|
bool isDouble() const { return Tag::Double == tag; }
|
||||||
|
bool toDouble() const {
|
||||||
|
assert(isDouble());
|
||||||
|
return payload.asDouble;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tensor
|
||||||
|
RtValue(Ref<Tensor> tensor) : tag(Tag::Tensor) {
|
||||||
|
payload.asVoidPtr = reinterpret_cast<void *>(tensor.takePtr());
|
||||||
|
}
|
||||||
|
bool isTensor() const { return Tag::Tensor == tag; }
|
||||||
|
Ref<Tensor> toTensor() const {
|
||||||
|
assert(isTensor());
|
||||||
|
return Ref<Tensor>(reinterpret_cast<Tensor *>(payload.asVoidPtr));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ref
|
||||||
|
bool isRef() const {
|
||||||
|
#define DEFINE_IS_REF(x) \
|
||||||
|
if (is##x()) { \
|
||||||
|
return true; \
|
||||||
|
}
|
||||||
|
NPCOMP_FORALL_REF_TAGS(DEFINE_IS_REF)
|
||||||
|
#undef DEFINE_IS_REF
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// RtValue (downcast)
|
||||||
|
const RtValue &toRtValue() const { return *this; }
|
||||||
|
RtValue &toRtValue() { return *this; }
|
||||||
|
|
||||||
|
// Stringify tag for debugging.
|
||||||
|
StringRef tagKind() const {
|
||||||
|
switch (tag) {
|
||||||
|
#define DEFINE_CASE(x) \
|
||||||
|
case Tag::x: \
|
||||||
|
return #x;
|
||||||
|
NPCOMP_FORALL_TAGS(DEFINE_CASE)
|
||||||
|
#undef DEFINE_CASE
|
||||||
|
}
|
||||||
|
// TODO(brycearden): Print tag here
|
||||||
|
return "InvalidTag!";
|
||||||
|
}
|
||||||
|
|
||||||
|
RtValue(const RtValue &rhs) : RtValue(rhs.payload, rhs.tag) {
|
||||||
|
if (isRef()) {
|
||||||
|
#define DEFINE_INCREF(x) \
|
||||||
|
if (is##x()) { \
|
||||||
|
Ref<x>::incref(static_cast<x *>(payload.asVoidPtr)); \
|
||||||
|
return; \
|
||||||
|
}
|
||||||
|
NPCOMP_FORALL_REF_TAGS(DEFINE_INCREF)
|
||||||
|
#undef DEFINE_INCREF
|
||||||
|
assert(false && "Unsupported RtValue type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RtValue(RtValue &&rhs) noexcept : RtValue() { swap(rhs); }
|
||||||
|
|
||||||
|
RtValue &operator=(RtValue &&rhs) & noexcept {
|
||||||
|
RtValue(std::move(rhs)).swap(*this); // this also sets rhs to None
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
RtValue &operator=(RtValue const &rhs) & {
|
||||||
|
RtValue(rhs).swap(*this);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
~RtValue() {
|
||||||
|
if (isRef()) {
|
||||||
|
#define DEFINE_DECREF(x) \
|
||||||
|
if (is##x()) { \
|
||||||
|
Ref<x>::decref(static_cast<x *>(payload.asVoidPtr)); \
|
||||||
|
return; \
|
||||||
|
}
|
||||||
|
NPCOMP_FORALL_REF_TAGS(DEFINE_DECREF)
|
||||||
|
#undef DEFINE_DECREF
|
||||||
|
assert(false && "Unsupported RtValue type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void swap(RtValue &rhs) {
|
||||||
|
std::swap(payload, rhs.payload);
|
||||||
|
std::swap(tag, rhs.tag);
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: Runtime tags are intentionally private.
|
||||||
|
// Please use the helper functions above to query information about the type
|
||||||
|
// of a RtValue.
|
||||||
|
enum class Tag : std::uint32_t {
|
||||||
|
#define DEFINE_TAG(x) x,
|
||||||
|
NPCOMP_FORALL_TAGS(DEFINE_TAG)
|
||||||
|
#undef DEFINE_TAG
|
||||||
|
};
|
||||||
|
|
||||||
|
union Payload {
|
||||||
|
bool asBool;
|
||||||
|
int64_t asInt;
|
||||||
|
double asDouble;
|
||||||
|
void *asVoidPtr;
|
||||||
|
};
|
||||||
|
|
||||||
|
RtValue(Payload pl, Tag tag) : payload(pl), tag(tag) {}
|
||||||
|
|
||||||
|
Payload payload;
|
||||||
|
Tag tag;
|
||||||
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Module loading.
|
// Module loading.
|
||||||
// This is the main entry point that users interact with.
|
// This is the main entry point that users interact with.
|
||||||
|
@ -172,7 +330,7 @@ constexpr static int kMaxArity = 20;
|
||||||
// Low-level invocation API. The number of inputs and outputs should be correct
|
// Low-level invocation API. The number of inputs and outputs should be correct
|
||||||
// and match the results of getMetadata.
|
// and match the results of getMetadata.
|
||||||
void invoke(ModuleDescriptor *moduleDescriptor, StringRef functionName,
|
void invoke(ModuleDescriptor *moduleDescriptor, StringRef functionName,
|
||||||
ArrayRef<Ref<Tensor>> inputs, MutableArrayRef<Ref<Tensor>> outputs);
|
ArrayRef<RtValue> inputs, MutableArrayRef<RtValue> outputs);
|
||||||
|
|
||||||
// Metadata for function `functionName`.
|
// Metadata for function `functionName`.
|
||||||
//
|
//
|
||||||
|
|
|
@ -22,6 +22,7 @@ using llvm::Twine;
|
||||||
using refback::JITModule;
|
using refback::JITModule;
|
||||||
using refbackrt::Ref;
|
using refbackrt::Ref;
|
||||||
using refbackrt::Tensor;
|
using refbackrt::Tensor;
|
||||||
|
using refbackrt::RtValue;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) {
|
static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) {
|
||||||
|
@ -106,18 +107,18 @@ void npcomp::python::defineBackendRefJitModule(py::module &m) {
|
||||||
[](JITModule &self, std::string functionName,
|
[](JITModule &self, std::string functionName,
|
||||||
std::vector<py::buffer> inputs) {
|
std::vector<py::buffer> inputs) {
|
||||||
// Prepare inputs.
|
// Prepare inputs.
|
||||||
llvm::SmallVector<Ref<Tensor>, 4> inputTensors;
|
llvm::SmallVector<RtValue, 4> inputValues;
|
||||||
inputTensors.reserve(inputs.size());
|
inputValues.reserve(inputs.size());
|
||||||
for (py::buffer &inputBuffer : inputs) {
|
for (py::buffer &inputBuffer : inputs) {
|
||||||
inputTensors.push_back(copyBufferToTensor(inputBuffer));
|
inputValues.push_back(copyBufferToTensor(inputBuffer));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outputs = checkError(self.invoke(functionName, inputTensors),
|
auto outputs = checkError(self.invoke(functionName, inputValues),
|
||||||
"error invoking JIT function: ");
|
"error invoking JIT function: ");
|
||||||
std::vector<py::array> outputArrays;
|
std::vector<py::array> outputArrays;
|
||||||
outputArrays.reserve(outputs.size());
|
outputArrays.reserve(outputs.size());
|
||||||
for (Ref<Tensor> &outputTensor : outputs) {
|
for (RtValue &outputTensor : outputs) {
|
||||||
outputArrays.push_back(wrapTensorAsArray(outputTensor));
|
outputArrays.push_back(wrapTensorAsArray(outputTensor.toTensor()));
|
||||||
}
|
}
|
||||||
return outputArrays;
|
return outputArrays;
|
||||||
},
|
},
|
||||||
|
|
|
@ -73,14 +73,14 @@ static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
|
||||||
return refbackrt::MutableArrayRef<T>(a.data(), a.size());
|
return refbackrt::MutableArrayRef<T>(a.data(), a.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
|
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
|
||||||
JITModule::invoke(llvm::StringRef functionName,
|
JITModule::invoke(llvm::StringRef functionName,
|
||||||
llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs) {
|
llvm::ArrayRef<refbackrt::RtValue> inputs) {
|
||||||
refbackrt::FunctionMetadata metadata;
|
refbackrt::FunctionMetadata metadata;
|
||||||
if (refbackrt::failed(refbackrt::getMetadata(
|
if (refbackrt::failed(refbackrt::getMetadata(
|
||||||
descriptor, toRefbackrt(functionName), metadata)))
|
descriptor, toRefbackrt(functionName), metadata)))
|
||||||
return make_string_error("unknown function: " + Twine(functionName));
|
return make_string_error("unknown function: " + Twine(functionName));
|
||||||
SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> outputs(
|
SmallVector<refbackrt::RtValue, 6> outputs(
|
||||||
metadata.numOutputs);
|
metadata.numOutputs);
|
||||||
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
|
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
|
||||||
return make_string_error("invoking '" + Twine(functionName) +
|
return make_string_error("invoking '" + Twine(functionName) +
|
||||||
|
|
|
@ -131,7 +131,7 @@ Tensor *Tensor::createRaw(ArrayRef<std::int32_t> extents, ElementType type,
|
||||||
auto *tensor = static_cast<Tensor *>(
|
auto *tensor = static_cast<Tensor *>(
|
||||||
std::malloc(sizeof(Tensor) + extents.size() * sizeof(std::int32_t)));
|
std::malloc(sizeof(Tensor) + extents.size() * sizeof(std::int32_t)));
|
||||||
|
|
||||||
tensor->refCount.store(0);
|
tensor->refCount = 0;
|
||||||
tensor->elementType = type;
|
tensor->elementType = type;
|
||||||
tensor->rank = extents.size();
|
tensor->rank = extents.size();
|
||||||
auto byteSize = getElementTypeByteSize(type) * totalElements(extents);
|
auto byteSize = getElementTypeByteSize(type) * totalElements(extents);
|
||||||
|
@ -172,8 +172,8 @@ static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor,
|
||||||
}
|
}
|
||||||
|
|
||||||
void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
StringRef functionName, ArrayRef<Ref<Tensor>> inputs,
|
StringRef functionName, ArrayRef<RtValue> inputs,
|
||||||
MutableArrayRef<Ref<Tensor>> outputs) {
|
MutableArrayRef<RtValue> outputs) {
|
||||||
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
|
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
|
||||||
assert(descriptor && "unknown function name");
|
assert(descriptor && "unknown function name");
|
||||||
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
|
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
|
||||||
|
@ -191,7 +191,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
// more complex though (and maybe impossible given the current abstractions).
|
// more complex though (and maybe impossible given the current abstractions).
|
||||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||||
inputUnrankedMemrefs[i] =
|
inputUnrankedMemrefs[i] =
|
||||||
convertRefbackrtTensorToUnrankedMemref(inputs[i].get());
|
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
|
||||||
}
|
}
|
||||||
// Create a type-erased list of "packed inputs" to pass to the
|
// Create a type-erased list of "packed inputs" to pass to the
|
||||||
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
|
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
|
||||||
|
@ -224,7 +224,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
|
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
|
||||||
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
|
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
|
||||||
elementType);
|
elementType);
|
||||||
outputs[i] = Ref<Tensor>(tensor);
|
outputs[i] = RtValue(Ref<Tensor>(tensor));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now, we just need to free all the UnrankedMemref's that we created.
|
// Now, we just need to free all the UnrankedMemref's that we created.
|
||||||
|
|
|
@ -58,14 +58,15 @@ convertAttrToTensor(Attribute attr) {
|
||||||
return make_string_error("unhandled argument");
|
return make_string_error("unhandled argument");
|
||||||
}
|
}
|
||||||
|
|
||||||
static Expected<SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
|
static Expected<SmallVector<refbackrt::RtValue, 6>>
|
||||||
createInputs(ArrayRef<StringRef> argValues) {
|
createInputs(ArrayRef<StringRef> argValues) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> ret;
|
SmallVector<refbackrt::RtValue, 6> ret;
|
||||||
for (auto argValue : argValues) {
|
for (auto argValue : argValues) {
|
||||||
auto attr = parseAttribute(argValue, &context);
|
auto attr = parseAttribute(argValue, &context);
|
||||||
if (!attr)
|
if (!attr)
|
||||||
return make_string_error(Twine("could not parse arg value: ") + argValue);
|
return make_string_error(Twine("could not parse arg value: ") + argValue);
|
||||||
|
// TODO(brycearden): Handle multiple input types
|
||||||
auto expectedTensor = convertAttrToTensor(attr);
|
auto expectedTensor = convertAttrToTensor(attr);
|
||||||
if (!expectedTensor)
|
if (!expectedTensor)
|
||||||
return expectedTensor.takeError();
|
return expectedTensor.takeError();
|
||||||
|
@ -111,11 +112,12 @@ static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) {
|
||||||
attr.print(os);
|
attr.print(os);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printOutputs(ArrayRef<refbackrt::Ref<refbackrt::Tensor>> outputs,
|
static void printOutputs(ArrayRef<refbackrt::RtValue> outputs,
|
||||||
llvm::raw_ostream &os) {
|
llvm::raw_ostream &os) {
|
||||||
for (auto output : llvm::enumerate(outputs)) {
|
for (auto output : llvm::enumerate(outputs)) {
|
||||||
|
assert(output.value().isTensor() && "only tensor outputs are supported.");
|
||||||
os << "output #" << output.index() << ": ";
|
os << "output #" << output.index() << ": ";
|
||||||
printOutput(*output.value(), os);
|
printOutput(*output.value().toTensor().get(), os);
|
||||||
os << "\n";
|
os << "\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue