Add some much-needed comments around refbackrt::invoke.

This code is really tricky, and was not commented.
pull/136/head
Sean Silva 2020-11-25 15:38:11 -08:00
parent 46aa6d0a24
commit 955fd3eeda
1 changed file with 37 additions and 0 deletions

View File

@ -178,22 +178,46 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
assert(descriptor && "unknown function name"); assert(descriptor && "unknown function name");
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity"); assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
assert(outputs.size() < kMaxArity && "number of outputs exceeds kMaxArity"); assert(outputs.size() < kMaxArity && "number of outputs exceeds kMaxArity");
// We haven't committed to using "vector" in this runtime code, so use
// a fixed-sized array.
std::array<UnrankedMemref, kMaxArity> inputUnrankedMemrefs; std::array<UnrankedMemref, kMaxArity> inputUnrankedMemrefs;
std::array<UnrankedMemref, kMaxArity> outputUnrankedMemrefs; std::array<UnrankedMemref, kMaxArity> outputUnrankedMemrefs;
std::array<void *, kMaxArity * 2> packedInputs; std::array<void *, kMaxArity * 2> packedInputs;
std::array<void *, kMaxArity> packedOutputs; std::array<void *, kMaxArity> packedOutputs;
// Deep-copy the refbackrt::Tensor's into UnrankedMemref's.
// TODO: Avoid the deep copy. It makes the later lifetime management code
// more complex though (and maybe impossible given the current abstractions).
for (int i = 0, e = inputs.size(); i < e; i++) { for (int i = 0, e = inputs.size(); i < e; i++) {
inputUnrankedMemrefs[i] = inputUnrankedMemrefs[i] =
convertRefbackrtTensorToUnrankedMemref(inputs[i].get()); convertRefbackrtTensorToUnrankedMemref(inputs[i].get());
} }
// Create a type-erased list of "packed inputs" to pass to the
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
// one LLVM/C ABI argument to the underlying function.
//
// The ABI lowering on the StandardToLLVM conversion side will
// "explode" the unranked memref descriptors on the underlying function
// into separate arguments for the rank and pointer-to-descriptor.
for (int i = 0, e = inputs.size(); i < e; i++) { for (int i = 0, e = inputs.size(); i < e; i++) {
packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank); packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor); packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
} }
// Create a type-erased list of "packed outputs" to pass to the
// LLVM/C ABI wrapper function.
//
// Due to how StandardToLLVM lowering works, each packedOutput pointer
// corresponds to a single UnrankedMemref (not "exploded").
for (int i = 0, e = outputs.size(); i < e; i++) { for (int i = 0, e = outputs.size(); i < e; i++) {
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]); packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
} }
// Actually invoke the function!
descriptor->functionPtr(packedInputs.data(), packedOutputs.data()); descriptor->functionPtr(packedInputs.data(), packedOutputs.data());
// Copy out the result data into refbackrt::Tensor's.
// TODO: Avoid needing to make a deep copy.
for (int i = 0, e = outputs.size(); i < e; i++) { for (int i = 0, e = outputs.size(); i < e; i++) {
// TODO: Have compiler emit the element type in the metadata. // TODO: Have compiler emit the element type in the metadata.
auto elementType = ElementType::F32; auto elementType = ElementType::F32;
@ -202,6 +226,15 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
elementType); elementType);
outputs[i] = Ref<Tensor>(tensor); outputs[i] = Ref<Tensor>(tensor);
} }
// Now, we just need to free all the UnrankedMemref's that we created.
// This is complicated by the fact that multiple input/output UnrankedMemref's
// can end up with the same backing buffer (`allocatedPtr`), and we need
// to avoid double-freeing.
// Output buffers might alias any other input or output buffer.
// Input buffers are guaranteed to not alias each other.
// Free the output buffers.
for (int i = 0, e = outputs.size(); i < e; i++) { for (int i = 0, e = outputs.size(); i < e; i++) {
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr; void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
// Multiple returned memrefs can point into the same underlying // Multiple returned memrefs can point into the same underlying
@ -215,6 +248,8 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
if (!bufferNeedsFreeing) if (!bufferNeedsFreeing)
std::free(allocatedPtr); std::free(allocatedPtr);
} }
// Free the input buffers.
for (int i = 0, e = inputs.size(); i < e; i++) { for (int i = 0, e = inputs.size(); i < e; i++) {
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr; void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
bool bufferNeedsFreeing = true; bool bufferNeedsFreeing = true;
@ -234,12 +269,14 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
std::free(allocatedPtr); std::free(allocatedPtr);
} }
// Free the output descriptors.
for (int i = 0, e = outputs.size(); i < e; i++) { for (int i = 0, e = outputs.size(); i < e; i++) {
// The LLVM lowering guarantees that each returned unranked memref // The LLVM lowering guarantees that each returned unranked memref
// descriptor is separately malloc'ed, so no need to do anything special // descriptor is separately malloc'ed, so no need to do anything special
// like we had to do for the allocatedPtr's. // like we had to do for the allocatedPtr's.
std::free(outputUnrankedMemrefs[i].descriptor); std::free(outputUnrankedMemrefs[i].descriptor);
} }
// Free the input descriptors.
for (int i = 0, e = inputs.size(); i < e; i++) { for (int i = 0, e = inputs.size(); i < e; i++) {
std::free(inputUnrankedMemrefs[i].descriptor); std::free(inputUnrankedMemrefs[i].descriptor);
} }