From 955fd3eedae7c877264d4e497c58697f7a485b4d Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Wed, 25 Nov 2020 15:38:11 -0800 Subject: [PATCH] Add some much-needed comments around refbackrt::invoke. This code is really tricky, and was not commented. --- lib/RefBackend/Runtime/Runtime.cpp | 37 ++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/lib/RefBackend/Runtime/Runtime.cpp b/lib/RefBackend/Runtime/Runtime.cpp index c397f201c..d92dc0de7 100644 --- a/lib/RefBackend/Runtime/Runtime.cpp +++ b/lib/RefBackend/Runtime/Runtime.cpp @@ -178,22 +178,46 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor, assert(descriptor && "unknown function name"); assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity"); assert(outputs.size() < kMaxArity && "number of outputs exceeds kMaxArity"); + + // We haven't committed to using "vector" in this runtime code, so use + // a fixed-sized array. std::array inputUnrankedMemrefs; std::array outputUnrankedMemrefs; std::array packedInputs; std::array packedOutputs; + + // Deepcopy the refbackrt::Tensor's into UnrankedMemref's. + // TODO: Avoid the deep copy. It makes the later lifetime management code + // more complex though (and maybe impossible given the current abstractions). for (int i = 0, e = inputs.size(); i < e; i++) { inputUnrankedMemrefs[i] = convertRefbackrtTensorToUnrankedMemref(inputs[i].get()); } + // Create a type-erased list of "packed inputs" to pass to the + // LLVM/C ABI wrapper function. Each packedInput pointer corresponds to + // one LLVM/C ABI argument to the underlying function. + // + // The ABI lowering on StandardToLLVM conversion side will + // "explode" the unranked memref descriptors on the underlying function + // into separate arguments for the rank and pointer-to-descriptor. for (int i = 0, e = inputs.size(); i < e; i++) { packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank); packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor); } + // Create a type-erased list of "packed output" to pass to the + // LLVM/C ABI wrapper function. + // + // Due to how StandardToLLVM lowering works, each packedOutput pointer + // corresponds to a single UnrankedMemref (not "exploded"). for (int i = 0, e = outputs.size(); i < e; i++) { packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]); } + + // Actually invoke the function! descriptor->functionPtr(packedInputs.data(), packedOutputs.data()); + + // Copy out the result data into refbackrt::Tensor's. + // TODO: Avoid needing to make a deep copy. for (int i = 0, e = outputs.size(); i < e; i++) { // TODO: Have compiler emit the element type in the metadata. auto elementType = ElementType::F32; @@ -202,6 +226,15 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor, elementType); outputs[i] = Ref(tensor); } + + // Now, we just need to free all the UnrankedMemref's that we created. + // This is complicated by the fact that multiple input/output UnrankedMemref's + // can end up with the same backing buffer (`allocatedPtr`), and we need + // to avoid double-freeing. + // Output buffers might alias any other input or output buffer. + // Input buffers are guaranteed to not alias each other. + + // Free the output buffers. for (int i = 0, e = outputs.size(); i < e; i++) { void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr; // Multiple returned memrefs can point into the same underlying @@ -215,6 +248,8 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor, if (!bufferNeedsFreeing) std::free(allocatedPtr); } + + // Free the input buffers. for (int i = 0, e = inputs.size(); i < e; i++) { void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr; bool bufferNeedsFreeing = true; @@ -234,12 +269,14 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor, std::free(allocatedPtr); } + // Free the output descriptors. for (int i = 0, e = outputs.size(); i < e; i++) { // The LLVM lowering guarantees that each returned unranked memref // descriptor is separately malloc'ed, so no need to do anything special // like we had to do for the allocatedPtr's. std::free(outputUnrankedMemrefs[i].descriptor); } + // Free the input descriptors. for (int i = 0, e = inputs.size(); i < e; i++) { std::free(inputUnrankedMemrefs[i].descriptor); }