mirror of https://github.com/llvm/torch-mlir
Add some much-needed comments around refbackrt::invoke.
This code is really tricky, and was not commented.
pull/136/head
parent
46aa6d0a24
commit
955fd3eeda
|
@ -178,22 +178,46 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
assert(descriptor && "unknown function name");
|
||||
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
|
||||
assert(outputs.size() < kMaxArity && "number of outputs exceeds kMaxArity");
|
||||
|
||||
// We haven't committed to using "vector" in this runtime code, so use
|
||||
// a fixed-sized array.
|
||||
std::array<UnrankedMemref, kMaxArity> inputUnrankedMemrefs;
|
||||
std::array<UnrankedMemref, kMaxArity> outputUnrankedMemrefs;
|
||||
std::array<void *, kMaxArity * 2> packedInputs;
|
||||
std::array<void *, kMaxArity> packedOutputs;
|
||||
|
||||
// Deepcopy the refbackrt::Tensor's into UnrankedMemref's.
|
||||
// TODO: Avoid the deep copy. It makes the later lifetime management code
|
||||
// more complex though (and maybe impossible given the current abstractions).
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
inputUnrankedMemrefs[i] =
|
||||
convertRefbackrtTensorToUnrankedMemref(inputs[i].get());
|
||||
}
|
||||
// Create a type-erased list of "packed inputs" to pass to the
|
||||
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
|
||||
// one LLVM/C ABI argument to the underlying function.
|
||||
//
|
||||
// The ABI lowering on StandardToLLVM conversion side will
|
||||
// "explode" the unranked memref descriptors on the underlying function
|
||||
// into separate arguments for the rank and pointer-to-descriptor.
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
|
||||
packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
|
||||
}
|
||||
// Create a type-erased list of "packed output" to pass to the
|
||||
// LLVM/C ABI wrapper function.
|
||||
//
|
||||
// Due to how StandardToLLVM lowering works, each packedOutput pointer
|
||||
// corresponds to a single UnrankedMemref (not "exploded").
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
|
||||
}
|
||||
|
||||
// Actually invoke the function!
|
||||
descriptor->functionPtr(packedInputs.data(), packedOutputs.data());
|
||||
|
||||
// Copy out the result data into refbackrt::Tensor's.
|
||||
// TODO: Avoid needing to make a deep copy.
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
// TODO: Have compiler emit the element type in the metadata.
|
||||
auto elementType = ElementType::F32;
|
||||
|
@ -202,6 +226,15 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
elementType);
|
||||
outputs[i] = Ref<Tensor>(tensor);
|
||||
}
|
||||
|
||||
// Now, we just need to free all the UnrankedMemref's that we created.
|
||||
// This is complicated by the fact that multiple input/output UnrankedMemref's
|
||||
// can end up with the same backing buffer (`allocatedPtr`), and we need
|
||||
// to avoid double-freeing.
|
||||
// Output buffers might alias any other input or output buffer.
|
||||
// Input buffers are guaranteed to not alias each other.
|
||||
|
||||
// Free the output buffers.
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
||||
// Multiple returned memrefs can point into the same underlying
|
||||
|
@ -215,6 +248,8 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
if (!bufferNeedsFreeing)
|
||||
std::free(allocatedPtr);
|
||||
}
|
||||
|
||||
// Free the input buffers.
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
||||
bool bufferNeedsFreeing = true;
|
||||
|
@ -234,12 +269,14 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
std::free(allocatedPtr);
|
||||
}
|
||||
|
||||
// Free the output descriptors.
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
// The LLVM lowering guarantees that each returned unranked memref
|
||||
// descriptor is separately malloc'ed, so no need to do anything special
|
||||
// like we had to do for the allocatedPtr's.
|
||||
std::free(outputUnrankedMemrefs[i].descriptor);
|
||||
}
|
||||
// Free the input descriptors.
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
std::free(inputUnrankedMemrefs[i].descriptor);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue