mirror of https://github.com/llvm/torch-mlir
Add some much-needed comments around refbackrt::invoke.
This code is really tricky, and was not commented.
pull/136/head
parent
46aa6d0a24
commit
955fd3eeda
|
@ -178,22 +178,46 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
assert(descriptor && "unknown function name");
|
assert(descriptor && "unknown function name");
|
||||||
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
|
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
|
||||||
assert(outputs.size() < kMaxArity && "number of outputs exceeds kMaxArity");
|
assert(outputs.size() < kMaxArity && "number of outputs exceeds kMaxArity");
|
||||||
|
|
||||||
|
// We haven't committed to using "vector" in this runtime code, so use
|
||||||
|
// a fixed-sized array.
|
||||||
std::array<UnrankedMemref, kMaxArity> inputUnrankedMemrefs;
|
std::array<UnrankedMemref, kMaxArity> inputUnrankedMemrefs;
|
||||||
std::array<UnrankedMemref, kMaxArity> outputUnrankedMemrefs;
|
std::array<UnrankedMemref, kMaxArity> outputUnrankedMemrefs;
|
||||||
std::array<void *, kMaxArity * 2> packedInputs;
|
std::array<void *, kMaxArity * 2> packedInputs;
|
||||||
std::array<void *, kMaxArity> packedOutputs;
|
std::array<void *, kMaxArity> packedOutputs;
|
||||||
|
|
||||||
|
// Deep-copy the refbackrt::Tensors into UnrankedMemrefs.
|
||||||
|
// TODO: Avoid the deep copy. It makes the later lifetime management code
|
||||||
|
// more complex though (and maybe impossible given the current abstractions).
|
||||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||||
inputUnrankedMemrefs[i] =
|
inputUnrankedMemrefs[i] =
|
||||||
convertRefbackrtTensorToUnrankedMemref(inputs[i].get());
|
convertRefbackrtTensorToUnrankedMemref(inputs[i].get());
|
||||||
}
|
}
|
||||||
|
// Create a type-erased list of "packed inputs" to pass to the
|
||||||
|
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
|
||||||
|
// one LLVM/C ABI argument to the underlying function.
|
||||||
|
//
|
||||||
|
// The ABI lowering on the StandardToLLVM conversion side will
|
||||||
|
// "explode" the unranked memref descriptors on the underlying function
|
||||||
|
// into separate arguments for the rank and pointer-to-descriptor.
|
||||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||||
packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
|
packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
|
||||||
packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
|
packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
|
||||||
}
|
}
|
||||||
|
// Create a type-erased list of "packed outputs" to pass to the
|
||||||
|
// LLVM/C ABI wrapper function.
|
||||||
|
//
|
||||||
|
// Due to how StandardToLLVM lowering works, each packedOutput pointer
|
||||||
|
// corresponds to a single UnrankedMemref (not "exploded").
|
||||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||||
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
|
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Actually invoke the function!
|
||||||
descriptor->functionPtr(packedInputs.data(), packedOutputs.data());
|
descriptor->functionPtr(packedInputs.data(), packedOutputs.data());
|
||||||
|
|
||||||
|
// Copy out the result data into refbackrt::Tensor's.
|
||||||
|
// TODO: Avoid needing to make a deep copy.
|
||||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||||
// TODO: Have compiler emit the element type in the metadata.
|
// TODO: Have compiler emit the element type in the metadata.
|
||||||
auto elementType = ElementType::F32;
|
auto elementType = ElementType::F32;
|
||||||
|
@ -202,6 +226,15 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
elementType);
|
elementType);
|
||||||
outputs[i] = Ref<Tensor>(tensor);
|
outputs[i] = Ref<Tensor>(tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now, we just need to free all the UnrankedMemref's that we created.
|
||||||
|
// This is complicated by the fact that multiple input/output UnrankedMemref's
|
||||||
|
// can end up with the same backing buffer (`allocatedPtr`), and we need
|
||||||
|
// to avoid double-freeing.
|
||||||
|
// Output buffers might alias any other input or output buffer.
|
||||||
|
// Input buffers are guaranteed to not alias each other.
|
||||||
|
|
||||||
|
// Free the output buffers.
|
||||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||||
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
||||||
// Multiple returned memrefs can point into the same underlying
|
// Multiple returned memrefs can point into the same underlying
|
||||||
|
@ -215,6 +248,8 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
if (!bufferNeedsFreeing)
|
if (!bufferNeedsFreeing)
|
||||||
std::free(allocatedPtr);
|
std::free(allocatedPtr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Free the input buffers.
|
||||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||||
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
||||||
bool bufferNeedsFreeing = true;
|
bool bufferNeedsFreeing = true;
|
||||||
|
@ -234,12 +269,14 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
std::free(allocatedPtr);
|
std::free(allocatedPtr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Free the output descriptors.
|
||||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||||
// The LLVM lowering guarantees that each returned unranked memref
|
// The LLVM lowering guarantees that each returned unranked memref
|
||||||
// descriptor is separately malloc'ed, so no need to do anything special
|
// descriptor is separately malloc'ed, so no need to do anything special
|
||||||
// like we had to do for the allocatedPtr's.
|
// like we had to do for the allocatedPtr's.
|
||||||
std::free(outputUnrankedMemrefs[i].descriptor);
|
std::free(outputUnrankedMemrefs[i].descriptor);
|
||||||
}
|
}
|
||||||
|
// Free the input descriptors.
|
||||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||||
std::free(inputUnrankedMemrefs[i].descriptor);
|
std::free(inputUnrankedMemrefs[i].descriptor);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue